# test_basics.py
import numpy as np
import torch as th
from torch.autograd import Variable

import dgl
from dgl.graph import DGLGraph

import utils as U

# Feature width of every node/edge tensor built by ``generate_graph``.
D = 5
# Mailbox shapes recorded by ``reduce_func``.  Tests clear this set, trigger
# a reduce, and compare the collected shapes with the expected ones.
reduce_msg_shapes = set()
def check_eq(a, b):
    """Assert that tensors ``a`` and ``b`` have the same shape and elements.

    Raises:
        AssertionError: if the shapes differ or any element differs.
    """
    assert a.shape == b.shape, 'shape mismatch: {} vs {}'.format(a.shape, b.shape)
    # Direct "every element matches" check; also holds for zero-size tensors.
    assert bool((a == b).all()), 'tensors differ:\n{}\n{}'.format(a, b)
def message_func(edges):
    """Send each edge's source feature ``h`` as the message ``m``.

    Also asserts that the batched source features are 2-D with ``D`` columns.
    """
    src_h = edges.src['h']
    assert len(src_h.shape) == 2
    assert src_h.shape[1] == D
    return {'m' : src_h}
def reduce_func(nodes):
    """Sum each node's mailbox messages ``m`` into ``accum``.

    The mailbox shape is recorded in the module-level ``reduce_msg_shapes`` so
    tests can check how destination nodes were grouped into reduce calls.
    """
    msgs = nodes.mailbox['m']
    reduce_msg_shapes.add(tuple(msgs.shape))
    # The mailbox must be 3-D with the feature axis last (width D); the
    # per-node sum runs over axis 1, the message axis.
    assert len(msgs.shape) == 3
    assert msgs.shape[2] == D
    return {'accum' : th.sum(msgs, 1)}
def apply_node_func(nodes):
    """Return the new node feature ``h``: the old ``h`` plus the reduced ``accum``."""
    return {'h' : nodes.data['h'] + nodes.data['accum']}
def generate_graph(grad=False):
    """Build the 10-node, 17-edge test graph shared by the tests below.

    Node 0 has an edge to each of nodes 1..8, each of those has an edge to
    node 9, and one back edge 9 -> 0 is added last.  Edge ids follow
    insertion order.  The node feature ``h`` and edge feature ``w`` are
    random ``D``-wide tensors.  Both frames use ``dgl.init.zero_initializer``.

    Args:
        grad: whether ``h`` and ``w`` should require gradients.

    Returns:
        The populated ``DGLGraph``.
    """
    g = DGLGraph()
    g.add_nodes(10)
    # create a graph where 0 is the source and 9 is the sink
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    # add a back flow from 9 to 0 (8 * 2 + 1 = 17 edges in total)
    g.add_edge(9, 0)

    # ``Variable`` is a deprecated no-op wrapper; request grads directly.
    g.ndata['h'] = th.randn(g.number_of_nodes(), D, requires_grad=grad)
    g.edata['w'] = th.randn(g.number_of_edges(), D, requires_grad=grad)

    g.set_n_initializer(dgl.init.zero_initializer)
    g.set_e_initializer(dgl.init.zero_initializer)
    return g

def test_batch_setter_getter():
    """Whole-frame and partial get/set/pop of node and edge features."""
    def _pfc(x):
        # first column of a 2-D tensor as a plain Python list
        return list(x.numpy()[:, 0])
    g = generate_graph()
    # set all nodes
    g.ndata['h'] = th.zeros((10, D))
    assert U.allclose(g.ndata['h'], th.zeros((10, D)))
    # pop nodes
    old_len = len(g.ndata)
    assert _pfc(g.pop_n_repr('h')) == [0.] * 10
    assert len(g.ndata) == old_len - 1
    g.ndata['h'] = th.zeros((10, D))
    # set partial nodes
    u = th.tensor([1, 3, 5])
    g.nodes[u].data['h'] = th.ones((3, D))
    assert _pfc(g.ndata['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
    # get partial nodes
    u = th.tensor([1, 2, 3])
    assert _pfc(g.nodes[u].data['h']) == [1., 0., 1.]

    # Edge ids assigned by generate_graph:
    #   src, dst, eid
    #   0, 1, 0
    #   1, 9, 1
    #   0, 2, 2
    #   2, 9, 3
    #   0, 3, 4
    #   3, 9, 5
    #   0, 4, 6
    #   4, 9, 7
    #   0, 5, 8
    #   5, 9, 9
    #   0, 6, 10
    #   6, 9, 11
    #   0, 7, 12
    #   7, 9, 13
    #   0, 8, 14
    #   8, 9, 15
    #   9, 0, 16

    # set all edges
    g.edata['l'] = th.zeros((17, D))
    assert _pfc(g.edata['l']) == [0.] * 17
    # pop edges
    old_len = len(g.edata)
    assert _pfc(g.pop_e_repr('l')) == [0.] * 17
    assert len(g.edata) == old_len - 1
    g.edata['l'] = th.zeros((17, D))

    # set partial edges (many-many)
    u = th.tensor([0, 0, 2, 5, 9])
    v = th.tensor([1, 3, 9, 9, 0])
    g.edges[u, v].data['l'] = th.ones((5, D))
    truth = [0.] * 17
    truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
    assert _pfc(g.edata['l']) == truth

    # set partial edges (many-one)
    u = th.tensor([3, 4, 6])
    v = th.tensor([9])
    g.edges[u, v].data['l'] = th.ones((3, D))
    truth[5] = truth[7] = truth[11] = 1.
    assert _pfc(g.edata['l']) == truth

    # set partial edges (one-many)
    u = th.tensor([0])
    v = th.tensor([4, 5, 6])
    g.edges[u, v].data['l'] = th.ones((3, D))
    truth[6] = truth[8] = truth[10] = 1.
    assert _pfc(g.edata['l']) == truth

    # get partial edges (many-many)
    u = th.tensor([0, 6, 0])
    v = th.tensor([6, 9, 7])
    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]

    # get partial edges (many-one)
    u = th.tensor([5, 6, 7])
    v = th.tensor([9])
    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]

    # get partial edges (one-many)
    u = th.tensor([0])
    v = th.tensor([3, 4, 5])
    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 1.]
def test_batch_setter_autograd():
    """Gradients reach both the original rows and the partially-set rows."""
    g = generate_graph(grad=True)
    h1 = g.ndata['h']
    # partial set
    v = th.tensor([1, 2, 8])
    hh = th.zeros((len(v), D), requires_grad=True)
    g.nodes[v].data['h'] = hh
    h2 = g.ndata['h']
    h2.backward(th.ones((10, D)) * 2)
    # rows 1, 2 and 8 were overwritten by hh, so h1 gets no gradient there
    check_eq(h1.grad[:, 0], th.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
    check_eq(hh.grad[:, 0], th.tensor([2., 2., 2.]))
def test_batch_send():
    """send accepts many-many, one-many and many-one (src, dst) specs."""
    g = generate_graph()
    def _fmsg(edges):
        # every variant below addresses exactly five edges
        assert edges.src['h'].shape == (5, D)
        return {'m' : edges.src['h']}
    g.register_message_func(_fmsg)
    # many-many send
    u = th.tensor([0, 0, 0, 0, 0])
    v = th.tensor([1, 2, 3, 4, 5])
    g.send((u, v))
    # one-many send
    u = th.tensor([0])
    v = th.tensor([1, 2, 3, 4, 5])
    g.send((u, v))
    # many-one send
    u = th.tensor([1, 2, 3, 4, 5])
    v = th.tensor([9])
    g.send((u, v))
def test_batch_recv():
    """recv reduces nodes grouped by how many pending messages they hold."""
    # basic recv test
    g = generate_graph()
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_node_func)
    u = th.tensor([0, 0, 0, 4, 5, 6])
    v = th.tensor([1, 2, 3, 9, 9, 9])
    reduce_msg_shapes.clear()
    g.send((u, v))
    g.recv(th.unique(v))
    # nodes 1, 2, 3 each hold one message; node 9 holds three
    assert reduce_msg_shapes == {(1, 3, D), (3, 1, D)}
    reduce_msg_shapes.clear()
def test_apply_nodes():
    """apply_nodes with the registered UDF on all nodes, then a lambda on a subset."""
    def _upd(nodes):
        return {'h' : nodes.data['h'] * 2}
    g = generate_graph()
    g.register_apply_node_func(_upd)
    old = g.ndata['h']
    g.apply_nodes()
    assert U.allclose(old * 2, g.ndata['h'])
    u = th.tensor([0, 3, 4, 6])
    g.apply_nodes(lambda nodes : {'h' : nodes.data['h'] * 0.}, u)
    assert U.allclose(g.ndata['h'][u], th.zeros((4, D)))
def test_apply_edges():
    """apply_edges with the registered UDF on all edges, then a lambda on (u, v) pairs."""
    def _upd(edges):
        return {'w' : edges.data['w'] * 2}
    g = generate_graph()
    g.register_apply_edge_func(_upd)
    old = g.edata['w']
    g.apply_edges()
    assert U.allclose(old * 2, g.edata['w'])
    u = th.tensor([0, 0, 0, 4, 5, 6])
    v = th.tensor([1, 2, 3, 9, 9, 9])
    g.apply_edges(lambda edges : {'w' : edges.data['w'] * 0.}, (u, v))
    eid = g.edge_ids(u, v)
    assert U.allclose(g.edata['w'][eid], th.zeros((6, D)))
def test_update_routines():
    """send_and_recv / pull / push / update_all produce the expected mailbox shapes."""
    g = generate_graph()
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_node_func)

    # send_and_recv
    reduce_msg_shapes.clear()
    u = [0, 0, 0, 4, 5, 6]
    v = [1, 2, 3, 9, 9, 9]
    g.send_and_recv((u, v))
    assert reduce_msg_shapes == {(1, 3, D), (3, 1, D)}
    reduce_msg_shapes.clear()
    # The test expects a list ``[u, v]`` (instead of a ``(u, v)`` tuple) to be
    # rejected.  The "did not raise" failure lives in ``else`` so the handler
    # cannot swallow it, unlike ``assert False`` inside a bare ``except``.
    try:
        g.send_and_recv([u, v])
    except Exception:
        pass
    else:
        raise AssertionError('send_and_recv([u, v]) was expected to raise')

    # pull
    v = th.tensor([1, 2, 3, 9])
    reduce_msg_shapes.clear()
    g.pull(v)
    assert reduce_msg_shapes == {(1, 8, D), (3, 1, D)}
    reduce_msg_shapes.clear()

    # push
    v = th.tensor([0, 1, 2, 3])
    reduce_msg_shapes.clear()
    g.push(v)
    assert reduce_msg_shapes == {(1, 3, D), (8, 1, D)}
    reduce_msg_shapes.clear()

    # update_all
    reduce_msg_shapes.clear()
    g.update_all()
    assert reduce_msg_shapes == {(1, 8, D), (9, 1, D)}
    reduce_msg_shapes.clear()
def test_recv_0deg():
    """recv on zero in-degree nodes: initializer + apply, or plain apply."""
    g = DGLGraph()
    g.add_nodes(2)
    g.add_edge(0, 1)
    def _message(edges):
        return {'m' : edges.src['h']}
    def _reduce(nodes):
        return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}
    def _apply(nodes):
        return {'h' : nodes.data['h'] * 2}
    def _init2(shape, dtype, ctx, ids):
        return 2 + th.zeros(shape, dtype=dtype, device=ctx)
    g.register_message_func(_message)
    g.register_reduce_func(_reduce)
    g.register_apply_node_func(_apply)
    g.set_n_initializer(_init2, 'h')
    # test#1: recv both 0deg and non-0deg nodes
    old = th.randn((2, 5))
    g.ndata['h'] = old
    g.send((0, 1))
    g.recv([0, 1])
    new = g.ndata.pop('h')
    # 0deg check: initialized with the func and got applied.
    # Float fill value: newer PyTorch infers an integer dtype from an int.
    assert U.allclose(new[0], th.full((5,), 4.))
    # non-0deg check
    assert U.allclose(new[1], th.sum(old, 0) * 2)

    # test#2: recv only 0deg node is equal to apply
    old = th.randn((2, 5))
    g.ndata['h'] = old
    g.send((0, 1))
    g.recv(0)
    new = g.ndata.pop('h')
    # 0deg check: equal to apply_nodes
    assert U.allclose(new[0], 2 * old[0])
    # non-0deg check: untouched
    assert U.allclose(new[1], old[1])

def test_recv_0deg_newfld():
    """recv with zero in-degree nodes when the reducer creates a new field ``h1``."""
    g = DGLGraph()
    g.add_nodes(2)
    g.add_edge(0, 1)
    def _message(edges):
        return {'m' : edges.src['h']}
    def _reduce(nodes):
        return {'h1' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}
    def _apply(nodes):
        return {'h1' : nodes.data['h1'] * 2}
    def _init2(shape, dtype, ctx, ids):
        return 2 + th.zeros(shape, dtype=dtype, device=ctx)
    g.register_message_func(_message)
    g.register_reduce_func(_reduce)
    g.register_apply_node_func(_apply)
    # test#1: recv both 0deg and non-0deg nodes
    old = th.randn((2, 5))
    g.set_n_initializer(_init2, 'h1')
    g.ndata['h'] = old
    g.send((0, 1))
    g.recv([0, 1])
    new = g.ndata.pop('h1')
    # 0deg check: initialized with the func and got applied.
    # Float fill values keep every tensor floating-point on newer PyTorch.
    assert U.allclose(new[0], th.full((5,), 4.))
    # non-0deg check
    assert U.allclose(new[1], th.sum(old, 0) * 2)

    # test#2: recv only 0deg node
    old = th.randn((2, 5))
    g.ndata['h'] = old
    g.ndata['h1'] = th.full((2, 5), -1.)  # this is necessary
    g.send((0, 1))
    g.recv(0)
    new = g.ndata.pop('h1')
    # 0deg check: fallback to apply
    assert U.allclose(new[0], th.full((5,), -2.))
    # non-0deg check: not changed
    assert U.allclose(new[1], th.full((5,), -1.))

def test_update_all_0deg():
    """update_all with zero in-degree nodes, and on a graph with no edges."""
    # test#1: nodes 1..4 each have an edge into node 0
    g = DGLGraph()
    g.add_nodes(5)
    for src in range(1, 5):
        g.add_edge(src, 0)
    def _message(edges):
        return {'m' : edges.src['h']}
    def _reduce(nodes):
        return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}
    def _apply(nodes):
        return {'h' : nodes.data['h'] * 2}
    def _init2(shape, dtype, ctx, ids):
        return 2 + th.zeros(shape, dtype=dtype, device=ctx)
    g.set_n_initializer(_init2, 'h')
    old_repr = th.randn(5, 5)
    g.ndata['h'] = old_repr
    g.update_all(_message, _reduce, _apply)
    new_repr = g.ndata['h']
    # the first row of the new_repr should be the sum of all the node
    # features; while the 0-deg nodes should be initialized by the
    # initializer and applied with UDF.
    assert U.allclose(new_repr[1:], 2 * (2 + th.zeros((4, 5))))
    assert U.allclose(new_repr[0], 2 * old_repr.sum(0))

    # test#2: graph with no edge
    g = DGLGraph()
    g.add_nodes(5)
    g.set_n_initializer(_init2, 'h')
    g.ndata['h'] = old_repr
    g.update_all(_message, _reduce, _apply)
    new_repr = g.ndata['h']
    # should fallback to apply
    assert U.allclose(new_repr, 2 * old_repr)
def test_pull_0deg():
    """pull on zero in-degree nodes: initializer + apply, or plain apply."""
    g = DGLGraph()
    g.add_nodes(2)
    g.add_edge(0, 1)
    def _message(edges):
        return {'m' : edges.src['h']}
    def _reduce(nodes):
        return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}
    def _apply(nodes):
        return {'h' : nodes.data['h'] * 2}
    def _init2(shape, dtype, ctx, ids):
        return 2 + th.zeros(shape, dtype=dtype, device=ctx)
    g.register_message_func(_message)
    g.register_reduce_func(_reduce)
    g.register_apply_node_func(_apply)
    g.set_n_initializer(_init2, 'h')
    # test#1: pull both 0deg and non-0deg nodes
    old = th.randn((2, 5))
    g.ndata['h'] = old
    g.pull([0, 1])
    new = g.ndata.pop('h')
    # 0deg check: initialized with the func and got applied.
    # Float fill value: newer PyTorch infers an integer dtype from an int.
    assert U.allclose(new[0], th.full((5,), 4.))
    # non-0deg check
    assert U.allclose(new[1], th.sum(old, 0) * 2)

    # test#2: pull only 0deg node
    old = th.randn((2, 5))
    g.ndata['h'] = old
    g.pull(0)
    new = g.ndata.pop('h')
    # 0deg check: fallback to apply
    assert U.allclose(new[0], 2 * old[0])
    # non-0deg check: not touched
    assert U.allclose(new[1], old[1])
def _disabled_test_send_twice():
    """Sending twice before recv.

    Resending on the same edge is expected to replace the earlier message.
    Messages from different edges into the same node are all kept.
    """
    # TODO(minjie): please re-enable this unittest after the send code problem is fixed.
    g = DGLGraph()
    g.add_nodes(3)
    g.add_edge(0, 1)
    g.add_edge(2, 1)
    def _message_a(edges):
        return {'a': edges.src['a']}
    def _message_b(edges):
        return {'a': edges.src['a'] * 3}
    def _reduce(nodes):
        return {'a': nodes.mailbox['a'].max(1)[0]}

    old_repr = th.randn(3, 5)
    # same edge twice: only the second (x3) message should survive
    g.ndata['a'] = old_repr
    g.send((0, 1), _message_a)
    g.send((0, 1), _message_b)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], old_repr[0] * 3)

    # two different edges: both messages reach the reducer
    g.ndata['a'] = old_repr
    g.send((0, 1), _message_a)
    g.send((2, 1), _message_b)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], th.stack([old_repr[0], old_repr[2] * 3], 0).max(0)[0])

def test_send_multigraph():
    """send/recv on a multigraph, addressing edges by id and by (src, dst)."""
    # edge ids: 0, 1, 2 are 0 -> 1 (parallel edges); 3 is 2 -> 1
    g = DGLGraph(multigraph=True)
    g.add_nodes(3)
    g.add_edge(0, 1)
    g.add_edge(0, 1)
    g.add_edge(0, 1)
    g.add_edge(2, 1)

    def _message_a(edges):
        return {'a': edges.data['a']}
    def _message_b(edges):
        return {'a': edges.data['a'] * 3}
    def _reduce(nodes):
        return {'a': nodes.mailbox['a'].max(1)[0]}

    def answer(*args):
        # element-wise max over the given tensors, mirroring _reduce
        return th.stack(args, 0).max(0)[0]

    # send by eid
    old_repr = th.randn(4, 5)
    g.ndata['a'] = th.zeros(3, 5)
    g.edata['a'] = old_repr
    g.send([0, 2], message_func=_message_a)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], answer(old_repr[0], old_repr[2]))

    g.ndata['a'] = th.zeros(3, 5)
    g.edata['a'] = old_repr
    g.send([0, 2, 3], message_func=_message_a)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3]))

    # send on multigraph: (src, dst) pairs cover every parallel edge
    g.ndata['a'] = th.zeros(3, 5)
    g.edata['a'] = old_repr
    g.send(([0, 2], [1, 1]), _message_a)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], old_repr.max(0)[0])

    # consecutive send and send_on
    g.ndata['a'] = th.zeros(3, 5)
    g.edata['a'] = old_repr
    g.send((2, 1), _message_a)
    g.send([0, 1], message_func=_message_b)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], answer(old_repr[0] * 3, old_repr[1] * 3, old_repr[3]))

    # consecutive send_on
    g.ndata['a'] = th.zeros(3, 5)
    g.edata['a'] = old_repr
    g.send(0, message_func=_message_a)
    g.send(1, message_func=_message_b)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], answer(old_repr[0], old_repr[1] * 3))

    # send_and_recv_on
    g.ndata['a'] = th.zeros(3, 5)
    g.edata['a'] = old_repr
    g.send_and_recv([0, 2, 3], message_func=_message_a, reduce_func=_reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3]))
    assert U.allclose(new_repr[[0, 2]], th.zeros(2, 5))
def test_dynamic_addition():
    """Feature columns keep pace with the graph as nodes and edges are added."""
    N = 3
    # Local feature width; named ``dim`` so it does not shadow module-level D.
    dim = 1

    g = DGLGraph()

    # Test node addition
    g.add_nodes(N)
    g.ndata.update({'h1': th.randn(N, dim),
                    'h2': th.randn(N, dim)})
    g.add_nodes(3)
    assert g.ndata['h1'].shape[0] == g.ndata['h2'].shape[0] == N + 3

    # Test edge addition
    g.add_edge(0, 1)
    g.add_edge(1, 0)
    g.edata.update({'h1': th.randn(2, dim),
                    'h2': th.randn(2, dim)})
    assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 2

    g.add_edges([0, 2], [2, 0])
    g.edata['h1'] = th.randn(4, dim)
    assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 4

    g.add_edge(1, 2)
    g.edges[4].data['h1'] = th.randn(1, dim)
    assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 5
if __name__ == '__main__':
    # Run every enabled test directly; _disabled_test_send_twice stays off
    # until its TODO is resolved.
    test_batch_setter_getter()
    test_batch_setter_autograd()
    test_batch_send()
    test_batch_recv()
    test_apply_nodes()
    test_apply_edges()
    test_update_routines()
    test_recv_0deg()
    test_recv_0deg_newfld()
    test_update_all_0deg()
    test_pull_0deg()
    test_send_multigraph()
    test_dynamic_addition()