test_basics.py 10.1 KB
Newer Older
1
import torch as th
2
3
from torch.autograd import Variable
import numpy as np
4
5
6
7
8
from dgl.graph import DGLGraph

# Width of the feature tensors built by the tests in this file
# (e.g. ``th.randn(10, D)``); message_func/reduce_func assert against it.
D = 5
# Shapes of the message batches seen by reduce_func; tests clear this set,
# run a routine, and then compare it with the expected set of shapes.
reduce_msg_shapes = set()

9
10
11
12
def check_eq(a, b):
    """Assert that tensors ``a`` and ``b`` have the same shape and equal elements.

    Raises AssertionError when the shapes differ or when any element
    compares unequal.
    """
    assert a.shape == b.shape
    # Every element must match, so the number of matches equals the
    # total number of elements implied by the shape.
    expected_matches = int(np.prod(list(a.shape)))
    n_matches = th.sum(a == b)
    assert n_matches == expected_matches

13
14
15
16
17
18
def message_func(src, edge):
    """Forward the source feature 'h' as the message field 'm'.

    Asserts that ``src['h']`` is a 2-D tensor whose second dimension is D.
    ``edge`` is accepted but not used.
    """
    feat_shape = src['h'].shape
    assert len(feat_shape) == 2
    assert feat_shape[1] == D
    return {'m': src['h']}

def reduce_func(node, msgs):
    """Sum the incoming 'm' messages over dimension 1 and record their shape.

    The shape of the message tensor is added to the module-level
    ``reduce_msg_shapes`` set so that tests can inspect which batch
    shapes were passed to the reducer. ``node`` is accepted but not used.

    Returns a dict with key 'm' holding the summed messages.
    """
    msgs = msgs['m']
    reduce_msg_shapes.add(tuple(msgs.shape))
    # Messages are expected as a 3-D tensor whose last dimension is D.
    assert len(msgs.shape) == 3
    assert msgs.shape[2] == D
    return {'m': th.sum(msgs, 1)}

25
26
def apply_node_func(node):
    """Return a new 'h' field equal to the node's 'h' plus its 'm' field."""
    updated = node['h'] + node['m']
    return {'h': updated}
Minjie Wang's avatar
Minjie Wang committed
27

28
def generate_graph(grad=False):
    """Build the 10-node graph shared by the tests in this file.

    Node 0 has an edge to each of nodes 1..8, each of nodes 1..8 has an
    edge to node 9, and one extra edge goes from 9 back to 0 (17 edges in
    total). Random node features of shape (10, D) are stored under 'h'.

    grad: forwarded as ``requires_grad`` of the node feature Variable.

    Returns the constructed DGLGraph.
    """
    g = DGLGraph()
    g.add_nodes(10)  # 10 nodes.
    # create a graph where 0 is the source and 9 is the sink
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    # add a back flow from 9 to 0
    g.add_edge(9, 0)
    ncol = Variable(th.randn(10, D), requires_grad=grad)
    g.set_n_repr({'h': ncol})
    return g

def test_batch_setter_getter():
    """Exercise full and partial set/get/pop of node and edge representations."""
    def _pfc(x):
        # Project a 2-D tensor onto its first column as a plain list.
        return list(x.numpy()[:, 0])

    g = generate_graph()
    # set all nodes
    g.set_n_repr({'h': th.zeros((10, D))})
    assert _pfc(g.get_n_repr()['h']) == [0.] * 10
    # pop nodes
    assert _pfc(g.pop_n_repr('h')) == [0.] * 10
    assert len(g.get_n_repr()) == 0
    g.set_n_repr({'h': th.zeros((10, D))})
    # set partial nodes
    u = th.tensor([1, 3, 5])
    g.set_n_repr({'h': th.ones((3, D))}, u)
    assert _pfc(g.get_n_repr()['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
    # get partial nodes
    u = th.tensor([1, 2, 3])
    assert _pfc(g.get_n_repr(u)['h']) == [1., 0., 1.]

    # Edge table of the generated graph:
    #   s, d, eid
    #   0, 1, 0
    #   1, 9, 1
    #   0, 2, 2
    #   2, 9, 3
    #   0, 3, 4
    #   3, 9, 5
    #   0, 4, 6
    #   4, 9, 7
    #   0, 5, 8
    #   5, 9, 9
    #   0, 6, 10
    #   6, 9, 11
    #   0, 7, 12
    #   7, 9, 13
    #   0, 8, 14
    #   8, 9, 15
    #   9, 0, 16

    # set all edges
    g.set_e_repr({'l': th.zeros((17, D))})
    assert _pfc(g.get_e_repr()['l']) == [0.] * 17
    # pop edges
    assert _pfc(g.pop_e_repr('l')) == [0.] * 17
    assert len(g.get_e_repr()) == 0
    g.set_e_repr({'l': th.zeros((17, D))})
    # set partial edges (many-many)
    u = th.tensor([0, 0, 2, 5, 9])
    v = th.tensor([1, 3, 9, 9, 0])
    g.set_e_repr({'l': th.ones((5, D))}, u, v)
    truth = [0.] * 17
    truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
    assert _pfc(g.get_e_repr()['l']) == truth
    # set partial edges (many-one)
    u = th.tensor([3, 4, 6])
    v = th.tensor([9])
    g.set_e_repr({'l': th.ones((3, D))}, u, v)
    truth[5] = truth[7] = truth[11] = 1.
    assert _pfc(g.get_e_repr()['l']) == truth
    # set partial edges (one-many)
    u = th.tensor([0])
    v = th.tensor([4, 5, 6])
    g.set_e_repr({'l': th.ones((3, D))}, u, v)
    truth[6] = truth[8] = truth[10] = 1.
    assert _pfc(g.get_e_repr()['l']) == truth
    # get partial edges (many-many)
    u = th.tensor([0, 6, 0])
    v = th.tensor([6, 9, 7])
    assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 0.]
    # get partial edges (many-one)
    u = th.tensor([5, 6, 7])
    v = th.tensor([9])
    assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 0.]
    # get partial edges (one-many)
    u = th.tensor([0])
    v = th.tensor([3, 4, 5])
    assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 1.]
118

119
120
121
122
123
124
125
126
127
128
129
130
def test_batch_setter_autograd():
    """Check the gradients produced after partially overwriting node features."""
    graph = generate_graph(grad=True)
    feat_before = graph.get_n_repr()['h']
    # partial set: overwrite the rows of nodes 1, 2 and 8
    nodes = th.tensor([1, 2, 8])
    replacement = Variable(th.zeros((len(nodes), D)), requires_grad=True)
    graph.set_n_repr({'h': replacement}, nodes)
    feat_after = graph.get_n_repr()['h']
    upstream_grad = th.ones((10, D)) * 2
    feat_after.backward(upstream_grad)
    # The overwritten rows are expected to carry zero gradient in the
    # original features; the replacement rows receive the upstream value.
    expected_before = th.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.])
    expected_replacement = th.tensor([2., 2., 2.])
    check_eq(feat_before.grad[:, 0], expected_before)
    check_eq(replacement.grad[:, 0], expected_replacement)

131
132
133
134
135
def test_batch_send():
    """Send along five edges at a time using many-many, one-many and many-one endpoints."""
    g = generate_graph()

    def _fmsg(src, edge):
        # Each send below covers five edges, so five source rows are expected.
        assert src['h'].shape == (5, D)
        return {'m': src['h']}

    g.register_message_func(_fmsg)
    # many-many send
    u = th.tensor([0, 0, 0, 0, 0])
    v = th.tensor([1, 2, 3, 4, 5])
    g.send(u, v)
    # one-many send
    u = th.tensor([0])
    v = th.tensor([1, 2, 3, 4, 5])
    g.send(u, v)
    # many-one send
    u = th.tensor([1, 2, 3, 4, 5])
    v = th.tensor([9])
    g.send(u, v)
149

150
def test_batch_recv():
    """Send on six edges, receive on the destinations, and check the reducer batch shapes."""
    # basic recv test
    g = generate_graph()
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_node_func)
    u = th.tensor([0, 0, 0, 4, 5, 6])
    v = th.tensor([1, 2, 3, 9, 9, 9])
    reduce_msg_shapes.clear()
    g.send(u, v)
    g.recv(th.unique(v))
    # Node 9 receives three messages; nodes 1, 2 and 3 receive one each.
    assert reduce_msg_shapes == {(1, 3, D), (3, 1, D)}
    reduce_msg_shapes.clear()

164
165
def test_update_routines():
    """Run send_and_recv, pull, push and update_all and check the reducer batch shapes."""
    g = generate_graph()
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_node_func)

    # send_and_recv
    reduce_msg_shapes.clear()
    u = th.tensor([0, 0, 0, 4, 5, 6])
    v = th.tensor([1, 2, 3, 9, 9, 9])
    g.send_and_recv(u, v)
    assert reduce_msg_shapes == {(1, 3, D), (3, 1, D)}
    reduce_msg_shapes.clear()

    # pull
    v = th.tensor([1, 2, 3, 9])
    reduce_msg_shapes.clear()
    g.pull(v)
    assert reduce_msg_shapes == {(1, 8, D), (3, 1, D)}
    reduce_msg_shapes.clear()

    # push
    v = th.tensor([0, 1, 2, 3])
    reduce_msg_shapes.clear()
    g.push(v)
    assert reduce_msg_shapes == {(1, 3, D), (8, 1, D)}
    reduce_msg_shapes.clear()

    # update_all
    reduce_msg_shapes.clear()
    g.update_all()
    assert reduce_msg_shapes == {(1, 8, D), (9, 1, D)}
    reduce_msg_shapes.clear()

198
199
def test_reduce_0deg():
    """Check update_all on a graph where only node 0 has incoming edges.

    Nodes 1..4 each have an edge to node 0 and receive no messages, so
    their representations are expected to stay unchanged, while node 0
    ends up with the sum of all rows of the original representation.
    """
    g = DGLGraph()
    g.add_nodes(5)
    g.add_edge(1, 0)
    g.add_edge(2, 0)
    g.add_edge(3, 0)
    g.add_edge(4, 0)

    def _message(src, edge):
        return src

    def _reduce(node, msgs):
        assert msgs is not None
        return node + msgs.sum(1)

    old_repr = th.randn(5, 5)
    g.set_n_repr(old_repr)
    g.update_all(_message, _reduce)
    new_repr = g.get_n_repr()

    assert th.allclose(new_repr[1:], old_repr[1:])
    assert th.allclose(new_repr[0], old_repr.sum(0))

218
def test_pull_0deg():
    """Check pull on a two-node graph where node 0 has no incoming edge.

    The graph has the single edge 0 -> 1. Pulling on node 0 is expected
    to leave both representations unchanged; pulling on node 1 is
    expected to replace its representation with node 0's.
    """
    g = DGLGraph()
    g.add_nodes(2)
    g.add_edge(0, 1)

    def _message(src, edge):
        return src

    def _reduce(node, msgs):
        assert msgs is not None
        return msgs.sum(1)

    old_repr = th.randn(2, 5)
    g.set_n_repr(old_repr)
    g.pull(0, _message, _reduce)
    new_repr = g.get_n_repr()
    assert th.allclose(new_repr[0], old_repr[0])
    assert th.allclose(new_repr[1], old_repr[1])
    g.pull(1, _message, _reduce)
    new_repr = g.get_n_repr()
    assert th.allclose(new_repr[1], old_repr[0])

    old_repr = th.randn(2, 5)
    g.set_n_repr(old_repr)
    g.pull([0, 1], _message, _reduce)
    new_repr = g.get_n_repr()
    assert th.allclose(new_repr[0], old_repr[0])
    assert th.allclose(new_repr[1], old_repr[0])

245
246
def _disabled_test_send_twice():
    """Send twice before a single recv and check the max-reduced result.

    Disabled via the leading underscore in the name; see the TODO below.
    """
    # TODO(minjie): please re-enable this unittest after the send code problem is fixed.
    g = DGLGraph()
    g.add_nodes(3)
    g.add_edge(0, 1)
    g.add_edge(2, 1)

    def _message_a(src, edge):
        return {'a': src['a']}

    def _message_b(src, edge):
        return {'a': src['a'] * 3}

    def _reduce(node, msgs):
        assert msgs is not None
        return {'a': msgs['a'].max(1)[0]}

    old_repr = th.randn(3, 5)
    # Two sends on the same edge: the assertion expects only the second
    # message (source feature times 3) to reach node 1.
    g.set_n_repr({'a': old_repr})
    g.send(0, 1, _message_a)
    g.send(0, 1, _message_b)
    g.recv([1], _reduce)
    new_repr = g.get_n_repr()['a']
    assert th.allclose(new_repr[1], old_repr[0] * 3)

    # Sends on two different edges into node 1: both messages are reduced.
    g.set_n_repr({'a': old_repr})
    g.send(0, 1, _message_a)
    g.send(2, 1, _message_b)
    g.recv([1], _reduce)
    new_repr = g.get_n_repr()['a']
    assert th.allclose(new_repr[1], th.stack([old_repr[0], old_repr[2] * 3], 0).max(0)[0])

def test_send_multigraph():
    """Exercise send/recv by edge id and by endpoints on a multigraph.

    The graph has three parallel edges 0 -> 1 followed by one edge 2 -> 1.
    Messages carry the edge feature 'a' (optionally scaled by 3) and are
    reduced at the destination with an element-wise max.
    """
    graph = DGLGraph(multigraph=True)
    graph.add_nodes(3)
    for src_id, dst_id in [(0, 1), (0, 1), (0, 1), (2, 1)]:
        graph.add_edge(src_id, dst_id)

    def _message_a(src, edge):
        return {'a': edge['a']}

    def _message_b(src, edge):
        return {'a': edge['a'] * 3}

    def _reduce(node, msgs):
        assert msgs is not None
        return {'a': msgs['a'].max(1)[0]}

    def answer(*args):
        # Element-wise max over the given rows.
        return th.stack(args, 0).max(0)[0]

    old_repr = th.randn(4, 5)

    def _reset():
        # Zero the node features and restore the edge features.
        graph.set_n_repr({'a': th.zeros(3, 5)})
        graph.set_e_repr({'a': old_repr})

    def _node_feat():
        return graph.get_n_repr()['a']

    # send by eid
    _reset()
    graph.send(eid=[0, 2], message_func=_message_a)
    graph.recv([1], _reduce)
    result = _node_feat()
    expected = answer(old_repr[0], old_repr[2])
    assert th.allclose(result[1], expected)

    _reset()
    graph.send(eid=[0, 2, 3], message_func=_message_a)
    graph.recv([1], _reduce)
    result = _node_feat()
    expected = answer(old_repr[0], old_repr[2], old_repr[3])
    assert th.allclose(result[1], expected)

    # send on multigraph
    _reset()
    graph.send([0, 2], [1, 1], _message_a)
    graph.recv([1], _reduce)
    result = _node_feat()
    expected = old_repr.max(0)[0]
    assert th.allclose(result[1], expected)

    # consecutive send and send_on
    _reset()
    graph.send(2, 1, _message_a)
    graph.send(eid=[0, 1], message_func=_message_b)
    graph.recv([1], _reduce)
    result = _node_feat()
    expected = answer(old_repr[0] * 3, old_repr[1] * 3, old_repr[3])
    assert th.allclose(result[1], expected)

    # consecutive send_on
    _reset()
    graph.send(eid=0, message_func=_message_a)
    graph.send(eid=1, message_func=_message_b)
    graph.recv([1], _reduce)
    result = _node_feat()
    expected = answer(old_repr[0], old_repr[1] * 3)
    assert th.allclose(result[1], expected)

    # send_and_recv_on
    _reset()
    graph.send_and_recv(eid=[0, 2, 3], message_func=_message_a, reduce_func=_reduce)
    result = _node_feat()
    expected = answer(old_repr[0], old_repr[2], old_repr[3])
    assert th.allclose(result[1], expected)
    assert th.allclose(result[[0, 2]], th.zeros(2, 5))


344
345
# Run every enabled test when the file is executed as a script.
if __name__ == '__main__':
    test_batch_setter_getter()
    test_batch_setter_autograd()
    test_batch_send()
    test_batch_recv()
    test_update_routines()
    test_reduce_0deg()
    test_pull_0deg()
    # The send-twice test is currently disabled (it is defined as
    # _disabled_test_send_twice), so there is no test_send_twice to call
    # here; calling it would raise NameError.
    test_send_multigraph()