"vscode:/vscode.git/clone" did not exist on "8331da46837be40f96fbd24de6a6fb2da28acd11"
test_heterograph-pickle.py 6.82 KB
Newer Older
1
2
3
4
5
6
7
8
import io
import pickle
import unittest

import backend as F

import dgl
import dgl.function as fn
9
import networkx as nx
10
import pytest
11
import scipy.sparse as ssp
Gan Quan's avatar
Gan Quan committed
12
13
from dgl.graph_index import create_graph_index
from dgl.utils import toindex
14
15
16
17
18
19
20
from utils import (
    assert_is_identical,
    assert_is_identical_hetero,
    check_graph_equal,
    get_cases,
    parametrize_idtype,
)
21

22

23
def _assert_is_identical_nodeflow(nf1, nf2):
    """Assert that two NodeFlow-like objects have the same structure and data.

    Compares, in order: the node count, the source/destination arrays
    returned by ``all_edges()``, then the size, feature keys and feature
    values of every layer, and finally the same three properties of every
    block.

    Parameters
    ----------
    nf1, nf2
        Objects exposing ``num_nodes()``, ``all_edges()``, ``num_layers``,
        ``layer_size(i)``, ``layers[i].data``, ``num_blocks``,
        ``block_size(i)`` and ``blocks[i].data``.

    Raises
    ------
    AssertionError
        On the first property that differs.
    """
    assert nf1.num_nodes() == nf2.num_nodes()
    src, dst = nf1.all_edges()
    src2, dst2 = nf2.all_edges()
    assert F.array_equal(src, src2)
    assert F.array_equal(dst, dst2)

    # Per-layer (node side) comparison.
    assert nf1.num_layers == nf2.num_layers
    for i in range(nf1.num_layers):
        assert nf1.layer_size(i) == nf2.layer_size(i)
        assert nf1.layers[i].data.keys() == nf2.layers[i].data.keys()
        for k in nf1.layers[i].data:
            assert F.allclose(nf1.layers[i].data[k], nf2.layers[i].data[k])

    # Per-block (edge side) comparison.
    assert nf1.num_blocks == nf2.num_blocks
    for i in range(nf1.num_blocks):
        assert nf1.block_size(i) == nf2.block_size(i)
        assert nf1.blocks[i].data.keys() == nf2.blocks[i].data.keys()
        for k in nf1.blocks[i].data:
            assert F.allclose(nf1.blocks[i].data[k], nf2.blocks[i].data[k])

43

44
def _assert_is_identical_batchedgraph(bg1, bg2):
    """Assert that two batched graphs are identical, batch metadata included.

    Delegates the graph comparison to ``assert_is_identical`` and then
    compares ``batch_size``, ``batch_num_nodes`` and ``batch_num_edges``.

    Raises
    ------
    AssertionError
        If any of the compared properties differ.
    """
    assert_is_identical(bg1, bg2)
    assert bg1.batch_size == bg2.batch_size
    # NOTE(review): these are compared as plain attributes here, whereas
    # _assert_is_identical_batchedhetero calls batch_num_nodes(...) and
    # batch_num_edges(...) as methods -- confirm which form the graph type
    # passed to this helper actually provides.
    assert bg1.batch_num_nodes == bg2.batch_num_nodes
    assert bg1.batch_num_edges == bg2.batch_num_edges

50

51
def _assert_is_identical_batchedhetero(bg1, bg2):
    """Assert that two batched heterographs are identical.

    Delegates the graph comparison to ``assert_is_identical_hetero`` and
    then compares ``batch_num_nodes`` for every node type of ``bg1`` and
    ``batch_num_edges`` for every canonical edge type of ``bg1``.

    Raises
    ------
    AssertionError
        If any of the compared properties differ.
    """
    assert_is_identical_hetero(bg1, bg2)
    for ntype in bg1.ntypes:
        assert bg1.batch_num_nodes(ntype) == bg2.batch_num_nodes(ntype)
    for canonical_etype in bg1.canonical_etypes:
        assert bg1.batch_num_edges(canonical_etype) == bg2.batch_num_edges(
            canonical_etype
        )

60

61
62
63
64
def _assert_is_identical_index(i1, i2):
    """Assert that two index objects carry the same slice data and values."""
    slice_a, slice_b = i1.slice_data(), i2.slice_data()
    assert slice_a == slice_b
    values_a = i1.tousertensor()
    values_b = i2.tousertensor()
    assert F.array_equal(values_a, values_b)

65

Gan Quan's avatar
Gan Quan committed
66
67
68
69
70
71
72
73
74
def _reconstruct_pickle(obj):
    """Return a copy of ``obj`` produced by pickling and then unpickling it."""
    payload = pickle.dumps(obj)
    return pickle.loads(payload)

75

Gan Quan's avatar
Gan Quan committed
76
def test_pickling_index():
    """Check that list-backed and slice-backed indices survive pickling."""
    # normal index
    list_index = toindex([1, 2, 3])
    list_index.tousertensor()
    list_index.todgltensor()  # construct a dgl tensor which is unpicklable
    restored_list_index = _reconstruct_pickle(list_index)
    _assert_is_identical_index(list_index, restored_list_index)

    # slice index
    slice_index = toindex(slice(5, 10))
    restored_slice_index = _reconstruct_pickle(slice_index)
    _assert_is_identical_index(slice_index, restored_slice_index)
Gan Quan's avatar
Gan Quan committed
88

89

Gan Quan's avatar
Gan Quan committed
90
def test_pickling_graph_index():
    """Check that a graph index keeps its nodes and edges through pickling.

    Builds a 3-node index with edges 0->1 and 0->2, round-trips it through
    pickle, and compares the node count and the edge endpoints.
    """
    gi = create_graph_index(None, False)
    gi.add_nodes(3)
    src_idx = toindex([0, 0])
    dst_idx = toindex([1, 2])
    gi.add_edges(src_idx, dst_idx)

    restored = _reconstruct_pickle(gi)

    assert restored.num_nodes() == gi.num_nodes()
    # edges() returns three values; the third is not needed here.
    restored_src, restored_dst, _ = restored.edges()
    assert F.array_equal(src_idx.tousertensor(), restored_src.tousertensor())
    assert F.array_equal(dst_idx.tousertensor(), restored_dst.tousertensor())
Gan Quan's avatar
Gan Quan committed
103
104
105


def _global_message_func(nodes):
    """Return a dict holding the ``"x"`` entry of ``nodes.data`` under ``"x"``."""
    return {"x": nodes.data["x"]}

Gan Quan's avatar
Gan Quan committed
108

109
@unittest.skipIf(F._default_context_str == "gpu", reason="GPU not implemented")
@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(exclude=["dglgraph", "two_hetero_batch"])
)
def test_pickling_graph(g, idtype):
    """Round-trip each parametrized graph case through pickle.

    The graph is first cast to ``idtype``; the restored copy is then
    compared against it with feature checking enabled.
    """
    original = g.astype(idtype)
    restored = _reconstruct_pickle(original)
    check_graph_equal(original, restored, check_feature=True)
118

119
120

@unittest.skipIf(F._default_context_str == "gpu", reason="GPU not implemented")
def test_pickling_batched_heterograph():
    """Check that a batch of two heterographs survives a pickle round trip."""

    # copied from test_heterograph.create_test_heterograph()
    def build_graph():
        # Fresh literals on every call so the two graphs share no inputs.
        return dgl.heterograph(
            {
                ("user", "follows", "user"): ([0, 1], [1, 2]),
                ("user", "plays", "game"): ([0, 1, 2, 1], [0, 0, 1, 1]),
                ("user", "wishes", "game"): ([0, 2], [1, 0]),
                ("developer", "develops", "game"): ([0, 1], [0, 1]),
            }
        )

    g = build_graph()
    g2 = build_graph()

    # Same assignment order as before: all features of g, then all of g2.
    for graph in (g, g2):
        graph.nodes["user"].data["u_h"] = F.randn((3, 4))
        graph.nodes["game"].data["g_h"] = F.randn((2, 5))
        graph.edges["plays"].data["p_h"] = F.randn((4, 6))

    bg = dgl.batch([g, g2])
    new_bg = _reconstruct_pickle(bg)
    check_graph_equal(bg, new_bg)
150

151
152
153
154
155

@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="GPU edge_subgraph w/ relabeling not implemented",
)
def test_pickling_subgraph():
    """Check that pickled subgraphs are far smaller than the pickled parent.

    A 10000-node / 100000-edge random graph with node and edge features is
    pickled once. A node subgraph and an edge subgraph are then pickled
    after each feature access, and every pickle must be more than 50 times
    smaller than the parent's.
    """
    g = dgl.rand_graph(10000, 100000)
    g.ndata["x"] = F.randn((10000, 4))
    g.edata["x"] = F.randn((100000, 5))
    parent_size = len(pickle.dumps(g))

    def assert_much_smaller(subgraph):
        # TODO(BarclayII): How should I test that the size of the subgraph pickle file should not
        # be as large as the size of the original pickle file?
        assert parent_size > len(pickle.dumps(subgraph)) * 50

    sg = g.subgraph([0, 1])
    _ = sg.ndata["x"]  # materialize
    assert_much_smaller(sg)
    _ = sg.edata["x"]  # materialize
    assert_much_smaller(sg)

    sg = g.edge_subgraph([0])
    _ = sg.edata["x"]  # materialize
    assert_much_smaller(sg)
    _ = sg.ndata["x"]  # materialize
    assert_much_smaller(sg)

192
193
194
195
196
197

@unittest.skipIf(F._default_context_str != "gpu", reason="Need GPU for pin")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TensorFlow create graph on gpu when unpickle",
)
@parametrize_idtype
def test_pickling_is_pinned(idtype):
    """Check that pickled and deep-copied copies of a pinned graph are pinned.

    Runs the same sequence on a random homogeneous graph and on a small
    heterograph, both created on the CPU device with the given ``idtype``.
    """
    from copy import deepcopy

    g = dgl.rand_graph(10, 20, idtype=idtype, device=F.cpu())
    hg = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1], [1, 2]),
            ("user", "plays", "game"): ([0, 1, 2, 1], [0, 0, 1, 1]),
            ("user", "wishes", "game"): ([0, 2], [1, 0]),
            ("developer", "develops", "game"): ([0, 1], [0, 1]),
        },
        idtype=idtype,
        device=F.cpu(),
    )
    for graph in [g, hg]:
        assert not graph.is_pinned()
        graph.pin_memory_()
        assert graph.is_pinned()

        # A pickled copy must come back pinned.
        pickled_copy = _reconstruct_pickle(graph)
        assert pickled_copy.is_pinned()
        pickled_copy.unpin_memory_()

        # A deep copy must be pinned as well.
        deep_copy = deepcopy(graph)
        assert deep_copy.is_pinned()
        deep_copy.unpin_memory_()

        graph.unpin_memory_()


226
if __name__ == "__main__":
    # Run this file through pytest rather than calling the tests by hand:
    # several tests are parametrized and cannot be called without arguments,
    # and the previous hand-written list named tests that are not defined in
    # this file. The process exit code mirrors pytest's result.
    raise SystemExit(pytest.main([__file__]))