# test_pickle.py
# Round-trip pickling tests for DGL index, graph index, frame and graph objects.
import dgl
from dgl.frame import Frame, FrameRef, Column
from dgl.graph_index import create_graph_index
from dgl.utils import toindex
import backend as F
import dgl.function as fn
import pickle
import io

def _reconstruct_pickle(obj):
    f = io.BytesIO()
    pickle.dump(obj, f)
    f.seek(0)
    obj = pickle.load(f)
    f.close()

    return obj

def test_pickling_index():
    """Pickle an Index after its user-tensor and DGL-tensor forms are built.

    ``todgltensor()`` produces a DGL tensor that is itself unpicklable, so
    this checks that the Index still round-trips and yields the same user
    tensor afterwards.
    """
    i = toindex([1, 2, 3])
    i.tousertensor()
    i.todgltensor()  # construct a dgl tensor which is unpicklable

    i2 = _reconstruct_pickle(i)

    assert F.array_equal(i2.tousertensor(), i.tousertensor())


def test_pickling_graph_index():
    """Pickle a 3-node, 2-edge GraphIndex and check nodes and edges survive."""
    gi = create_graph_index()
    gi.add_nodes(3)
    src_idx = toindex([0, 0])
    dst_idx = toindex([1, 2])
    gi.add_edges(src_idx, dst_idx)

    gi2 = _reconstruct_pickle(gi)

    assert gi2.number_of_nodes() == gi.number_of_nodes()
    # edges() returns a 3-tuple; only the endpoint indices are compared and
    # the third element (presumably edge ids) is ignored.
    src_idx2, dst_idx2, _ = gi2.edges()
    assert F.array_equal(src_idx.tousertensor(), src_idx2.tousertensor())
    assert F.array_equal(dst_idx.tousertensor(), dst_idx2.tousertensor())


def test_pickling_frame():
    """Pickle a Column, a populated Frame, and an empty Frame."""
    x = F.randn((3, 7))
    y = F.randn((3, 5))

    c = Column(x)

    c2 = _reconstruct_pickle(c)
    assert F.allclose(c.data, c2.data)

    fr = Frame({'x': x, 'y': y})

    fr2 = _reconstruct_pickle(fr)
    assert F.allclose(fr2['x'].data, x)
    assert F.allclose(fr2['y'].data, y)

    # An empty Frame must also round-trip and come back with no columns.
    fr = Frame()
    fr2 = _reconstruct_pickle(fr)
    assert len(fr2) == 0


def _assert_is_identical(g, g2):
    """Assert that two graphs have the same structure and feature data.

    Compares the node count, the edge endpoints returned by ``all_edges()``
    (element-wise, so edge order matters), and that both graphs carry the
    same node/edge feature names with numerically close values.
    """
    assert g.number_of_nodes() == g2.number_of_nodes()
    src, dst = g.all_edges()
    src2, dst2 = g2.all_edges()
    assert F.array_equal(src, src2)
    assert F.array_equal(dst, dst2)

    # Compare the feature-name sets, not just their sizes: equal counts with
    # different names would otherwise surface as a KeyError in the loops
    # below rather than as a clear assertion failure.
    assert set(g.ndata) == set(g2.ndata)
    assert set(g.edata) == set(g2.edata)
    for k in g.ndata:
        assert F.allclose(g.ndata[k], g2.ndata[k])
    for k in g.edata:
        assert F.allclose(g.edata[k], g2.edata[k])


def _global_message_func(nodes):
    return {'x': nodes.data['x']}

def test_pickling_graph():
    """Pickle DGLGraphs together with their features, UDFs and attributes.

    Covers three things:
    1. structure plus node/edge frames of a plain graph;
    2. registered message/reduce functions and a custom Python attribute;
    3. a batched graph in which one node feature was only partially set,
       checked both as a whole and after unbatching.
    """
    # graph structures and frames are pickled
    g = dgl.DGLGraph()
    g.add_nodes(3)
    src = F.tensor([0, 0])
    dst = F.tensor([1, 2])
    g.add_edges(src, dst)

    x = F.randn((3, 7))
    y = F.randn((3, 5))
    a = F.randn((2, 6))
    b = F.randn((2, 4))

    g.ndata['x'] = x
    g.ndata['y'] = y
    g.edata['a'] = a
    g.edata['b'] = b

    # registered functions are pickled; _global_message_func is module-level
    # so pickle can reference it by name.
    g.register_message_func(_global_message_func)
    reduce_func = fn.sum('x', 'x')
    g.register_reduce_func(reduce_func)

    # custom attributes should be pickled
    g.foo = 2

    new_g = _reconstruct_pickle(g)

    _assert_is_identical(g, new_g)
    assert new_g.foo == 2
    assert new_g._message_func == _global_message_func
    assert isinstance(new_g._reduce_func, type(reduce_func))
    assert new_g._reduce_func._name == 'sum'
    assert new_g._reduce_func.reduce_op == F.sum
    assert new_g._reduce_func.msg_field == 'x'
    assert new_g._reduce_func.out_field == 'x'

    # test batched graph with partial set case
    g2 = dgl.DGLGraph()
    g2.add_nodes(4)
    src2 = F.tensor([0, 1])
    dst2 = F.tensor([2, 3])
    g2.add_edges(src2, dst2)

    x2 = F.randn((4, 7))
    y2 = F.randn((3, 5))
    a2 = F.randn((2, 6))
    b2 = F.randn((2, 4))

    g2.ndata['x'] = x2
    # 'y' is written only for nodes 0, 1 and 3; node 2's value is left to
    # the frame's initializer (TODO confirm the initializer's fill value).
    g2.nodes[[0, 1, 3]].data['y'] = y2
    g2.edata['a'] = a2
    g2.edata['b'] = b2

    bg = dgl.batch([g, g2])

    bg2 = _reconstruct_pickle(bg)

    _assert_is_identical(bg, bg2)
    # Distinct names so the earlier unpickled `new_g` is not silently rebound.
    unbatched_g, unbatched_g2 = dgl.unbatch(bg2)
    _assert_is_identical(g, unbatched_g)
    _assert_is_identical(g2, unbatched_g2)


if __name__ == '__main__':
    # Run every pickling test in declaration order when executed directly.
    for _test in (test_pickling_index,
                  test_pickling_graph_index,
                  test_pickling_frame,
                  test_pickling_graph):
        _test()