"""Pickle round-trip tests for DGL indices, frames, graphs, NodeFlows and heterographs."""
import networkx as nx
import scipy.sparse as ssp
import dgl
import dgl.contrib as contrib
from dgl.frame import Frame, FrameRef, Column
from dgl.graph_index import create_graph_index
from dgl.utils import toindex
import backend as F
import dgl.function as fn
import pickle
import io
import unittest
def _assert_is_identical(g, g2):
    """Assert that two homogeneous graphs have the same structure and features.

    Compares the read-only flag, the node count, the edge endpoints in
    edge-ID order, and every node and edge feature tensor.
    """
    assert g.is_readonly == g2.is_readonly
    assert g.number_of_nodes() == g2.number_of_nodes()

    # Comparing endpoints in edge-ID order also checks that edge IDs are
    # preserved, not just the edge set.
    src, dst = g.all_edges(order='eid')
    src2, dst2 = g2.all_edges(order='eid')
    assert F.array_equal(src, src2)
    assert F.array_equal(dst, dst2)

    # Compare the feature names up front so a missing or renamed feature
    # fails as an assertion rather than a KeyError inside the loops below.
    assert set(g.ndata.keys()) == set(g2.ndata.keys())
    assert set(g.edata.keys()) == set(g2.edata.keys())
    for k in g.ndata:
        assert F.allclose(g.ndata[k], g2.ndata[k])
    for k in g.edata:
        assert F.allclose(g.edata[k], g2.edata[k])

def _assert_is_identical_hetero(g, g2):
    """Assert that two heterographs have the same schema, structure and features.

    Checks the read-only flag, the node/edge type lists and the metagraph.
    For every node type it checks the node count and node features; for every
    canonical edge type it checks the edge endpoints (in edge-ID order) and
    edge features.
    """
    assert g.is_readonly == g2.is_readonly
    assert g.ntypes == g2.ntypes
    assert g.canonical_etypes == g2.canonical_etypes

    # check if two metagraphs are identical: equal edge counts, and every
    # keyed metagraph edge of g exists in g2 with the same attributes.
    # Build each edge view once instead of re-querying g2.metagraph per edge.
    meta_edges = g.metagraph.edges(keys=True)
    meta_edges2 = g2.metagraph.edges(keys=True)
    assert len(meta_edges) == len(meta_edges2)
    for edges, features in meta_edges.items():
        assert meta_edges2[edges] == features

    # check if node ID spaces and feature spaces are equal
    for ntype in g.ntypes:
        assert g.number_of_nodes(ntype) == g2.number_of_nodes(ntype)
        assert len(g.nodes[ntype].data) == len(g2.nodes[ntype].data)
        for k in g.nodes[ntype].data:
            assert F.allclose(g.nodes[ntype].data[k], g2.nodes[ntype].data[k])

    # check if edge ID spaces and feature spaces are equal
    for etype in g.canonical_etypes:
        src, dst = g.all_edges(etype=etype, order='eid')
        src2, dst2 = g2.all_edges(etype=etype, order='eid')
        assert F.array_equal(src, src2)
        assert F.array_equal(dst, dst2)
        # mirror the node check: g2 must not carry extra edge features
        assert len(g.edges[etype].data) == len(g2.edges[etype].data)
        for k in g.edges[etype].data:
            assert F.allclose(g.edges[etype].data[k], g2.edges[etype].data[k])

def _assert_is_identical_nodeflow(nf1, nf2):
    """Assert that two NodeFlows agree in structure and in layer/block data.

    Checks the read-only flag, the node count and edge endpoints, then the
    size and every feature tensor of each layer and each block.
    """
    assert nf1.is_readonly == nf2.is_readonly
    assert nf1.number_of_nodes() == nf2.number_of_nodes()
    src1, dst1 = nf1.all_edges()
    src2, dst2 = nf2.all_edges()
    assert F.array_equal(src1, src2)
    assert F.array_equal(dst1, dst2)

    # per-layer node data
    assert nf1.num_layers == nf2.num_layers
    for layer_id in range(nf1.num_layers):
        assert nf1.layer_size(layer_id) == nf2.layer_size(layer_id)
        layer_data1 = nf1.layers[layer_id].data
        layer_data2 = nf2.layers[layer_id].data
        assert layer_data1.keys() == layer_data2.keys()
        for name in layer_data1:
            assert F.allclose(layer_data1[name], layer_data2[name])

    # per-block edge data
    assert nf1.num_blocks == nf2.num_blocks
    for block_id in range(nf1.num_blocks):
        assert nf1.block_size(block_id) == nf2.block_size(block_id)
        block_data1 = nf1.blocks[block_id].data
        block_data2 = nf2.blocks[block_id].data
        assert block_data1.keys() == block_data2.keys()
        for name in block_data1:
            assert F.allclose(block_data1[name], block_data2[name])

def _assert_is_identical_batchedgraph(bg1, bg2):
    """Assert two batched graphs match as graphs and in their batch layout."""
    _assert_is_identical(bg1, bg2)
    # batch_size, then per-graph node counts, then per-graph edge counts
    for attr_name in ('batch_size', 'batch_num_nodes', 'batch_num_edges'):
        assert getattr(bg1, attr_name) == getattr(bg2, attr_name)

def _assert_is_identical_batchedhetero(bg1, bg2):
    """Assert two batched heterographs match, including per-type batch sizes."""
    _assert_is_identical_hetero(bg1, bg2)
    for nty in bg1.ntypes:
        expected_nodes = bg1.batch_num_nodes(nty)
        assert expected_nodes == bg2.batch_num_nodes(nty)
    for cetype in bg1.canonical_etypes:
        expected_edges = bg1.batch_num_edges(cetype)
        assert expected_edges == bg2.batch_num_edges(cetype)

def _assert_is_identical_index(i1, i2):
    """Assert two Index objects share the same slice form and the same values."""
    same_slice = i1.slice_data() == i2.slice_data()
    assert same_slice
    user_tensor1 = i1.tousertensor()
    user_tensor2 = i2.tousertensor()
    assert F.array_equal(user_tensor1, user_tensor2)

def _reconstruct_pickle(obj):
    f = io.BytesIO()
    pickle.dump(obj, f)
    f.seek(0)
    obj = pickle.load(f)
    f.close()

    return obj

def test_pickling_index():
    """List-backed and slice-backed Index objects survive a pickle round trip."""
    # normal index
    i = toindex([1, 2, 3])
    i.tousertensor()
    # cache a dgl tensor, which is itself unpicklable; pickling the Index
    # must still succeed
    i.todgltensor()
    i2 = _reconstruct_pickle(i)
    _assert_is_identical_index(i, i2)

    # slice index
    i = toindex(slice(5, 10))
    i2 = _reconstruct_pickle(i)
    _assert_is_identical_index(i, i2)

def test_pickling_graph_index():
    """A mutable graph index keeps its nodes and edges across a pickle round trip."""
    gi = create_graph_index(None, False)
    gi.add_nodes(3)
    src_idx = toindex([0, 0])
    dst_idx = toindex([1, 2])
    gi.add_edges(src_idx, dst_idx)

    gi2 = _reconstruct_pickle(gi)

    assert gi2.number_of_nodes() == gi.number_of_nodes()
    # edges() yields (src, dst, eid); only the endpoints are compared here
    src_idx2, dst_idx2, _ = gi2.edges()
    assert F.array_equal(src_idx.tousertensor(), src_idx2.tousertensor())
    assert F.array_equal(dst_idx.tousertensor(), dst_idx2.tousertensor())


def test_pickling_frame():
    """Column and Frame objects keep their tensors across a pickle round trip."""
    x = F.randn((3, 7))
    y = F.randn((3, 5))

    c = Column(x)
    c2 = _reconstruct_pickle(c)
    assert F.allclose(c.data, c2.data)

    fr = Frame({'x': x, 'y': y})
    fr2 = _reconstruct_pickle(fr)
    assert F.allclose(fr2['x'].data, x)
    assert F.allclose(fr2['y'].data, y)

    # an empty frame must round-trip too and come back without columns
    fr = Frame()
    fr2 = _reconstruct_pickle(fr)
    assert len(fr2.keys()) == 0


def _global_message_func(nodes):
    return {'x': nodes.data['x']}

def test_pickling_graph():
    """DGLGraph pickling preserves structure, features, UDFs and custom attributes.

    Covers a mutable graph with registered message/reduce functions, a batched
    graph containing a partially-set node feature, and read-only and
    multigraph variants.
    """
    # graph structures and frames are pickled
    g = dgl.DGLGraph()
    g.add_nodes(3)
    src = F.tensor([0, 0])
    dst = F.tensor([1, 2])
    g.add_edges(src, dst)

    x = F.randn((3, 7))
    y = F.randn((3, 5))
    a = F.randn((2, 6))
    b = F.randn((2, 4))

    g.ndata['x'] = x
    g.ndata['y'] = y
    g.edata['a'] = a
    g.edata['b'] = b

    # registered functions are pickled: the module-level UDF must come back
    # as the same function, and the builtin reducer with the same fields
    g.register_message_func(_global_message_func)
    reduce_func = fn.sum('x', 'x')
    g.register_reduce_func(reduce_func)

    # custom attributes should be pickled
    g.foo = 2

    new_g = _reconstruct_pickle(g)

    _assert_is_identical(g, new_g)
    assert new_g.foo == 2
    assert new_g._message_func == _global_message_func
    assert isinstance(new_g._reduce_func, type(reduce_func))
    assert new_g._reduce_func._name == 'sum'
    assert new_g._reduce_func.msg_field == 'x'
    assert new_g._reduce_func.out_field == 'x'

    # test batched graph with partial set case: feature 'y' is assigned to
    # only three of g2's four nodes
    g2 = dgl.DGLGraph()
    g2.add_nodes(4)
    src2 = F.tensor([0, 1])
    dst2 = F.tensor([2, 3])
    g2.add_edges(src2, dst2)

    x2 = F.randn((4, 7))
    y2 = F.randn((3, 5))
    a2 = F.randn((2, 6))
    b2 = F.randn((2, 4))

    g2.ndata['x'] = x2
    g2.nodes[[0, 1, 3]].data['y'] = y2
    g2.edata['a'] = a2
    g2.edata['b'] = b2

    bg = dgl.batch([g, g2])

    bg2 = _reconstruct_pickle(bg)

    _assert_is_identical(bg, bg2)
    new_g, new_g2 = dgl.unbatch(bg2)
    _assert_is_identical(g, new_g)
    _assert_is_identical(g2, new_g2)

    # readonly graph
    g = dgl.DGLGraph([(0, 1), (1, 2)], readonly=True)
    new_g = _reconstruct_pickle(g)
    _assert_is_identical(g, new_g)

    # multigraph (duplicate edge 0 -> 1)
    g = dgl.DGLGraph([(0, 1), (0, 1), (1, 2)])
    new_g = _reconstruct_pickle(g)
    _assert_is_identical(g, new_g)

    # readonly multigraph
    g = dgl.DGLGraph([(0, 1), (0, 1), (1, 2)], readonly=True)
    new_g = _reconstruct_pickle(g)
    _assert_is_identical(g, new_g)

def test_pickling_nodeflow():
    """A NodeFlow carrying features copied from its parent survives pickling."""
    parent = dgl.DGLGraph([(0, 1), (1, 2), (2, 3), (3, 0)], readonly=True)
    parent.ndata['x'] = F.randn((4, 5))
    parent.edata['y'] = F.randn((4, 3))
    nf = contrib.sampling.sampler.create_full_nodeflow(parent, 5)
    # pull the parent's features into the NodeFlow so they are pickled too
    nf.copy_from_parent()
    restored = _reconstruct_pickle(nf)
    _assert_is_identical_nodeflow(nf, restored)
def test_pickling_batched_graph():
    """A batch of path graphs keeps its batch layout and features through pickle."""
    members = [dgl.DGLGraph(nx.path_graph(num_nodes)) for num_nodes in range(5, 10)]
    bg = dgl.batch(members)
    # 5 + 6 + 7 + 8 + 9 = 35 nodes; the paths have 4 + 5 + 6 + 7 + 8 = 30
    # undirected edges, presumably stored in both directions, giving 60
    bg.ndata['x'] = F.randn((35, 5))
    bg.edata['y'] = F.randn((60, 3))
    restored = _reconstruct_pickle(bg)
    _assert_is_identical_batchedgraph(bg, restored)

def _create_test_heterograph():
    """Build the user/game/developer heterograph shared by the heterograph tests.

    Copied from ``test_heterograph.create_test_heterograph()``. It combines a
    'follows' user graph with 'plays', 'wishes' and 'develops' bipartite
    relations built from an edge list, a scipy COO matrix and a networkx
    graph, so all three input formats are covered. No features are attached.
    """
    plays_spmat = ssp.coo_matrix(([1, 1, 1, 1], ([0, 1, 2, 1], [0, 0, 1, 1])))
    wishes_nx = nx.DiGraph()
    wishes_nx.add_nodes_from(['u0', 'u1', 'u2'], bipartite=0)
    wishes_nx.add_nodes_from(['g0', 'g1'], bipartite=1)
    wishes_nx.add_edge('u0', 'g1', id=0)
    wishes_nx.add_edge('u2', 'g0', id=1)

    follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
    plays_g = dgl.bipartite(plays_spmat, 'user', 'plays', 'game')
    wishes_g = dgl.bipartite(wishes_nx, 'user', 'wishes', 'game')
    develops_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game')
    return dgl.hetero_from_relations([follows_g, plays_g, wishes_g, develops_g])


def _add_test_features(g):
    """Attach random 'user'/'game' node features and 'plays' edge features to ``g``."""
    g.nodes['user'].data['u_h'] = F.randn((3, 4))
    g.nodes['game'].data['g_h'] = F.randn((2, 5))
    g.edges['plays'].data['p_h'] = F.randn((4, 6))


def test_pickling_heterograph():
    """A heterograph with node and edge features survives a pickle round trip."""
    g = _create_test_heterograph()
    _add_test_features(g)

    new_g = _reconstruct_pickle(g)
    _assert_is_identical_hetero(g, new_g)


def test_pickling_batched_heterograph():
    """A batch of two featured heterographs keeps per-type batch sizes through pickle."""
    g = _create_test_heterograph()
    g2 = _create_test_heterograph()
    _add_test_features(g)
    _add_test_features(g2)

    bg = dgl.batch_hetero([g, g2])
    new_bg = _reconstruct_pickle(bg)
    _assert_is_identical_batchedhetero(bg, new_bg)


@unittest.skipIf(dgl.backend.backend_name != "pytorch", reason="Only test for pytorch format file")
def test_pickling_heterograph_index_compatibility():
    """A heterograph index loaded from a stored pickle rebuilds the reference graph.

    NOTE(review): the fixture path is relative to the current working
    directory, so this appears to assume the tests run from the repository
    root - confirm against the CI invocation.
    """
    g = _create_test_heterograph()

    # the ``with`` block closes the file; no explicit close() is needed
    with open("tests/compute/hetero_pickle_old.pkl", "rb") as f:
        gi = pickle.load(f)
    new_g = dgl.DGLHeteroGraph(gi, g.ntypes, g.etypes)
    _assert_is_identical_hetero(g, new_g)
if __name__ == '__main__':
    # Run every test when the module is executed directly.
    test_pickling_index()
    test_pickling_graph_index()
    test_pickling_frame()
    test_pickling_graph()
    test_pickling_nodeflow()
    test_pickling_batched_graph()
    test_pickling_heterograph()
    test_pickling_batched_heterograph()
    # Mirror the unittest.skipIf guard: the stored fixture targets pytorch.
    if dgl.backend.backend_name == "pytorch":
        test_pickling_heterograph_index_compatibility()