"notebooks/vscode:/vscode.git/clone" did not exist on "1a6dd73552c0825e5b6216261df43ff4122824eb"
test_basics.py 18 KB
Newer Older
1
import torch as th
2
3
from torch.autograd import Variable
import numpy as np
4
import dgl
5
from dgl.graph import DGLGraph
6
import utils as U
7
from collections import defaultdict as ddict
8
9
10
11

# Width of the random node ('h') and edge ('w') features built by
# generate_graph, and the size checked by the message/reduce helpers below.
D = 5
# Shapes of the mailbox tensors seen by reduce_func; tests clear it before a
# message-passing call and compare its contents afterwards.
reduce_msg_shapes = set()

12
13
14
15
def check_eq(a, b):
    """Assert that tensors ``a`` and ``b`` share a shape and match element-wise.

    Raises AssertionError when the shapes differ or when any element differs.
    """
    assert a.shape == b.shape
    # Number of elements implied by the shape (1 for a 0-d tensor).
    expected_matches = 1
    for extent in a.shape:
        expected_matches *= extent
    num_matches = th.sum(a == b)
    assert num_matches == expected_matches

16
17
18
19
def message_func(edges):
    """Send each edge's source feature 'h' as the message 'm'.

    Asserts that the batched source feature is 2-D with D columns.
    """
    src_shape = edges.src['h'].shape
    assert len(src_shape) == 2
    assert src_shape[1] == D
    return {'m': edges.src['h']}
20

21
22
def reduce_func(nodes):
    """Sum the mailbox messages 'm' over dimension 1 into 'accum'.

    Records the shape of every mailbox tensor it sees in the module-level
    ``reduce_msg_shapes`` set so tests can check how nodes were batched, and
    asserts that the mailbox is 3-D with D as its last dimension.
    """
    msgs = nodes.mailbox['m']
    reduce_msg_shapes.add(tuple(msgs.shape))
    assert len(msgs.shape) == 3
    assert msgs.shape[2] == D
    return {'accum' : th.sum(msgs, 1)}
Minjie Wang's avatar
Minjie Wang committed
27

28
29
def apply_node_func(nodes):
    """Return a new 'h' equal to the current 'h' plus 'accum'."""
    feats = nodes.data
    updated_h = feats['h'] + feats['accum']
    return {'h': updated_h}
Minjie Wang's avatar
Minjie Wang committed
30

31
def generate_graph(grad=False):
    """Build the shared 10-node, 17-edge test graph with random features.

    Node 0 is the source and node 9 the sink: edges 0->i and i->9 are added
    for i in 1..8, followed by one back edge 9->0. Random node features are
    stored under 'h' and edge features under 'w', each D wide.

    grad: forwarded as ``requires_grad`` for both feature tensors.

    Returns the populated DGLGraph with zero initializers set for node and
    edge features.
    """
    g = DGLGraph()
    g.add_nodes(10) # 10 nodes.
    # create a graph where 0 is the source and 9 is the sink
    # 17 edges
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    # add a back flow from 9 to 0
    g.add_edge(9, 0)
    ncol = Variable(th.randn(10, D), requires_grad=grad)
    ecol = Variable(th.randn(17, D), requires_grad=grad)
    g.ndata['h'] = ncol
    g.edata['w'] = ecol
    g.set_n_initializer(dgl.init.zero_initializer)
    g.set_e_initializer(dgl.init.zero_initializer)
    return g

def test_batch_setter_getter():
    """Exercise full and partial get/set/pop of node and edge features."""
    def _pfc(x):
        # First column of a 2-D tensor as a Python list, for easy comparison.
        return list(x.numpy()[:,0])
    g = generate_graph()
    # set all nodes
    g.ndata['h'] = th.zeros((10, D))
    assert U.allclose(g.ndata['h'], th.zeros((10, D)))
    # pop nodes
    old_len = len(g.ndata)
    assert _pfc(g.pop_n_repr('h')) == [0.] * 10
    assert len(g.ndata) == old_len - 1
    g.ndata['h'] = th.zeros((10, D))
    # set partial nodes
    u = th.tensor([1, 3, 5])
    g.nodes[u].data['h'] = th.ones((3, D))
    assert _pfc(g.ndata['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
    # get partial nodes
    u = th.tensor([1, 2, 3])
    assert _pfc(g.nodes[u].data['h']) == [1., 0., 1.]

    # Edge table of the graph built by generate_graph:
    #   s, d, eid
    #   0, 1, 0
    #   1, 9, 1
    #   0, 2, 2
    #   2, 9, 3
    #   0, 3, 4
    #   3, 9, 5
    #   0, 4, 6
    #   4, 9, 7
    #   0, 5, 8
    #   5, 9, 9
    #   0, 6, 10
    #   6, 9, 11
    #   0, 7, 12
    #   7, 9, 13
    #   0, 8, 14
    #   8, 9, 15
    #   9, 0, 16

    # set all edges
    g.edata['l'] = th.zeros((17, D))
    assert _pfc(g.edata['l']) == [0.] * 17
    # pop edges
    old_len = len(g.edata)
    assert _pfc(g.pop_e_repr('l')) == [0.] * 17
    assert len(g.edata) == old_len - 1
    g.edata['l'] = th.zeros((17, D))
    # set partial edges (many-many)
    u = th.tensor([0, 0, 2, 5, 9])
    v = th.tensor([1, 3, 9, 9, 0])
    g.edges[u, v].data['l'] = th.ones((5, D))
    truth = [0.] * 17
    truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
    assert _pfc(g.edata['l']) == truth
    # set partial edges (many-one)
    u = th.tensor([3, 4, 6])
    v = th.tensor([9])
    g.edges[u, v].data['l'] = th.ones((3, D))
    truth[5] = truth[7] = truth[11] = 1.
    assert _pfc(g.edata['l']) == truth
    # set partial edges (one-many)
    u = th.tensor([0])
    v = th.tensor([4, 5, 6])
    g.edges[u, v].data['l'] = th.ones((3, D))
    truth[6] = truth[8] = truth[10] = 1.
    assert _pfc(g.edata['l']) == truth
    # get partial edges (many-many)
    u = th.tensor([0, 6, 0])
    v = th.tensor([6, 9, 7])
    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
    # get partial edges (many-one)
    u = th.tensor([5, 6, 7])
    v = th.tensor([9])
    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
    # get partial edges (one-many)
    u = th.tensor([0])
    v = th.tensor([3, 4, 5])
    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 1.]
128

129
130
def test_batch_setter_autograd():
    """Check gradient flow through a partial node-feature assignment.

    Rows 1, 2 and 8 of 'h' are overwritten with ``hh``; after backward with a
    gradient of 2, the original tensor must get zero gradient on those rows
    and ``hh`` must receive the full gradient.
    """
    g = generate_graph(grad=True)
    h1 = g.ndata['h']
    # partial set
    v = th.tensor([1, 2, 8])
    hh = Variable(th.zeros((len(v), D)), requires_grad=True)
    g.nodes[v].data['h'] = hh
    h2 = g.ndata['h']
    h2.backward(th.ones((10, D)) * 2)
    check_eq(h1.grad[:,0], th.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
    check_eq(hh.grad[:,0], th.tensor([2., 2., 2.]))

141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
def test_nx_conversion():
    """Check feature round-trips between DGLGraph and networkx graphs."""

    def _check_nx_feature(nxg, nf, ef):
        # Compare the attributes stored on the networkx graph against the
        # expected node features ``nf`` and edge features ``ef``.
        node_count = len(nxg)
        edge_count = nxg.size()
        if node_count <= 0:
            assert len(nf) == 0
        else:
            per_key_rows = ddict(list)
            for nid, attr in nxg.nodes(data=True):
                assert len(attr) == len(nf)
                for key in nxg.nodes[nid]:
                    per_key_rows[key].append(attr[key].unsqueeze(0))
            for key in per_key_rows:
                stacked = th.cat(per_key_rows[key], dim=0)
                assert U.allclose(stacked, nf[key])
        if edge_count <= 0:
            assert len(ef) == 0
        else:
            # Rows are placed by the edge's 'id' attribute so the stacked
            # tensor follows edge-id order.
            per_key_slots = ddict(lambda: [0] * edge_count)
            for _src, _dst, attr in nxg.edges(data=True):
                assert len(attr) == len(ef) + 1 # extra id
                slot = attr['id']
                for key in ef:
                    per_key_slots[key][slot] = attr[key].unsqueeze(0)
            for key in per_key_slots:
                stacked = th.cat(per_key_slots[key], dim=0)
                assert U.allclose(stacked, ef[key])

    n1 = th.randn(5, 3)
    n2 = th.randn(5, 10)
    n3 = th.randn(5, 4)
    e1 = th.randn(4, 5)
    e2 = th.randn(4, 7)
    g = DGLGraph(multigraph=True)
    g.add_nodes(5)
    src_ids = [0,1,3,4]
    dst_ids = [2,4,0,3]
    g.add_edges(src_ids, dst_ids)
    g.ndata.update({'n1': n1, 'n2': n2, 'n3': n3})
    g.edata.update({'e1': e1, 'e2': e2})

    # DGLGraph -> networkx, exporting only a subset of the node features
    nxg = g.to_networkx(node_attrs=['n1', 'n3'], edge_attrs=['e1', 'e2'])
    assert len(nxg) == 5
    assert nxg.size() == 4
    expected_nf = {'n1': n1, 'n3': n3}
    expected_ef = {'e1': e1, 'e2': e2}
    _check_nx_feature(nxg, expected_nf, expected_ef)

    # networkx -> DGLGraph
    # the 'id' feature exercises copying a non-tensor attribute
    g.from_networkx(nxg, node_attrs=['n1'], edge_attrs=['e1', 'id'])
    assert g.number_of_nodes() == 5
    assert g.number_of_edges() == 4
    assert U.allclose(g.get_n_repr()['n1'], n1)
    assert U.allclose(g.get_e_repr()['e1'], e1)
    assert th.equal(g.get_e_repr()['id'], th.arange(4))

    g.pop_e_repr('id')

    # grow the DGLGraph and convert again
    new_n = th.randn(2, 3)
    new_e = th.randn(3, 5)
    g.add_nodes(2, data={'n1': new_n})
    # three more edges, one of which duplicates an existing edge
    g.add_edges([3, 6, 0], [4, 5, 2], data={'e1': new_e})
    n1 = th.cat((n1, new_n), dim=0)
    e1 = th.cat((e1, new_e), dim=0)
    nxg = g.to_networkx(node_attrs=['n1'], edge_attrs=['e1'])
    assert len(nxg) == 7
    assert nxg.size() == 7
    _check_nx_feature(nxg, {'n1': n1}, {'e1': e1})

213
214
def test_batch_send():
    """Send along many-many, one-many and many-one edge specifications.

    Each call addresses five edges, which the message function verifies by
    asserting the shape of the batched source features.
    """
    g = generate_graph()
    def _fmsg(edges):
        assert edges.src['h'].shape == (5, D)
        return {'m' : edges.src['h']}
    g.register_message_func(_fmsg)
    # many-many send
    u = th.tensor([0, 0, 0, 0, 0])
    v = th.tensor([1, 2, 3, 4, 5])
    g.send((u, v))
    # one-many send
    u = th.tensor([0])
    v = th.tensor([1, 2, 3, 4, 5])
    g.send((u, v))
    # many-one send
    u = th.tensor([1, 2, 3, 4, 5])
    v = th.tensor([9])
    g.send((u, v))
231

232
def test_batch_recv():
    """Send on six edges, recv on their destinations, and check the mailbox
    shapes recorded by reduce_func."""
    # basic recv test
    g = generate_graph()
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_node_func)
    u = th.tensor([0, 0, 0, 4, 5, 6])
    v = th.tensor([1, 2, 3, 9, 9, 9])
    reduce_msg_shapes.clear()
    g.send((u, v))
    g.recv(th.unique(v))
    # node 9 receives three messages; nodes 1, 2, 3 receive one each
    assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
    reduce_msg_shapes.clear()

246
247
248
249
250
251
252
253
254
255
256
257
def test_apply_nodes():
    """Apply a registered node UDF to all nodes, then an ad-hoc UDF to a subset."""
    def _double(nodes):
        return {'h' : nodes.data['h'] * 2}

    def _zero_out(nodes):
        return {'h' : nodes.data['h'] * 0.}

    g = generate_graph()
    g.register_apply_node_func(_double)
    before = g.ndata['h']
    g.apply_nodes()
    assert U.allclose(before * 2, g.ndata['h'])
    subset = th.tensor([0, 3, 4, 6])
    g.apply_nodes(_zero_out, subset)
    assert U.allclose(g.ndata['h'][subset], th.zeros((4, D)))

258
def test_apply_edges():
    """Apply a registered edge UDF to all edges, then an ad-hoc UDF to the
    edges selected by (u, v) endpoint pairs."""
    def _upd(edges):
        return {'w' : edges.data['w'] * 2}
    g = generate_graph()
    g.register_apply_edge_func(_upd)
    old = g.edata['w']
    g.apply_edges()
    assert U.allclose(old * 2, g.edata['w'])
    u = th.tensor([0, 0, 0, 4, 5, 6])
    v = th.tensor([1, 2, 3, 9, 9, 9])
    g.apply_edges(lambda edges : {'w' : edges.data['w'] * 0.}, (u, v))
    eid = g.edge_ids(u, v)
    assert U.allclose(g.edata['w'][eid], th.zeros((6, D)))
271

272
273
def test_update_routines():
    """Run send_and_recv, pull, push and update_all on the shared test graph
    and check the mailbox shapes recorded by reduce_func after each call."""
    g = generate_graph()
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_node_func)

    # send_and_recv
    reduce_msg_shapes.clear()
    u = [0, 0, 0, 4, 5, 6]
    v = [1, 2, 3, 9, 9, 9]
    g.send_and_recv((u, v))
    assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
    reduce_msg_shapes.clear()
    # Passing the endpoints as a list instead of a tuple is expected to fail.
    # The failure check lives in the ``else`` branch so that it cannot be
    # swallowed by the handler that accepts the expected error.
    try:
        g.send_and_recv([u, v])
    except Exception:
        pass
    else:
        raise AssertionError('send_and_recv([u, v]) was expected to raise')

    # pull
    v = th.tensor([1, 2, 3, 9])
    reduce_msg_shapes.clear()
    g.pull(v)
    assert(reduce_msg_shapes == {(1, 8, D), (3, 1, D)})
    reduce_msg_shapes.clear()

    # push
    v = th.tensor([0, 1, 2, 3])
    reduce_msg_shapes.clear()
    g.push(v)
    assert(reduce_msg_shapes == {(1, 3, D), (8, 1, D)})
    reduce_msg_shapes.clear()

    # update_all
    reduce_msg_shapes.clear()
    g.update_all()
    assert(reduce_msg_shapes == {(1, 8, D), (9, 1, D)})
    reduce_msg_shapes.clear()

311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
def test_recv_0deg():
    """Check recv on a two-node graph where node 0 has no incoming edge.

    The expected values are built with float fill values so their dtype
    matches the float features regardless of how ``th.full`` infers a dtype
    from an integer fill.
    """
    # test recv with 0deg nodes;
    g = DGLGraph()
    g.add_nodes(2)
    g.add_edge(0, 1)
    def _message(edges):
        return {'m' : edges.src['h']}
    def _reduce(nodes):
        return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}
    def _apply(nodes):
        return {'h' : nodes.data['h'] * 2}
    def _init2(shape, dtype, ctx, ids):
        # Initializer that fills with the constant 2.
        return 2 + th.zeros(shape, dtype=dtype, device=ctx)
    g.register_message_func(_message)
    g.register_reduce_func(_reduce)
    g.register_apply_node_func(_apply)
    g.set_n_initializer(_init2, 'h')
    # test#1: recv both 0deg and non-0deg nodes
    old = th.randn((2, 5))
    g.ndata['h'] = old
    g.send((0, 1))
    g.recv([0, 1])
    new = g.ndata.pop('h')
    # 0deg check: initialized with the func and got applied
    assert U.allclose(new[0], th.full((5,), 4.))
    # non-0deg check
    assert U.allclose(new[1], th.sum(old, 0) * 2)

    # test#2: recv only 0deg node is equal to apply
    old = th.randn((2, 5))
    g.ndata['h'] = old
    g.send((0, 1))
    g.recv(0)
    new = g.ndata.pop('h')
    # 0deg check: equal to apply_nodes
    assert U.allclose(new[0], 2 * old[0])
    # non-0deg check: untouched
    assert U.allclose(new[1], old[1])

def test_recv_0deg_newfld():
    """Check recv with a 0-degree node when the reducer writes a new field 'h1'.

    Constant tensors use float fill values so they share a dtype with the
    float features regardless of how ``th.full`` infers a dtype from an
    integer fill.
    """
    # test recv with 0deg nodes; the reducer also creates a new field
    g = DGLGraph()
    g.add_nodes(2)
    g.add_edge(0, 1)
    def _message(edges):
        return {'m' : edges.src['h']}
    def _reduce(nodes):
        return {'h1' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}
    def _apply(nodes):
        return {'h1' : nodes.data['h1'] * 2}
    def _init2(shape, dtype, ctx, ids):
        # Initializer that fills with the constant 2.
        return 2 + th.zeros(shape, dtype=dtype, device=ctx)
    g.register_message_func(_message)
    g.register_reduce_func(_reduce)
    g.register_apply_node_func(_apply)
    # test#1: recv both 0deg and non-0deg nodes
    old = th.randn((2, 5))
    g.set_n_initializer(_init2, 'h1')
    g.ndata['h'] = old
    g.send((0, 1))
    g.recv([0, 1])
    new = g.ndata.pop('h1')
    # 0deg check: initialized with the func and got applied
    assert U.allclose(new[0], th.full((5,), 4.))
    # non-0deg check
    assert U.allclose(new[1], th.sum(old, 0) * 2)

    # test#2: recv only 0deg node
    old = th.randn((2, 5))
    g.ndata['h'] = old
    g.ndata['h1'] = th.full((2, 5), -1.)  # this is necessary
    g.send((0, 1))
    g.recv(0)
    new = g.ndata.pop('h1')
    # 0deg check: fallback to apply
    assert U.allclose(new[0], th.full((5,), -2.))
    # non-0deg check: not changed
    assert U.allclose(new[1], th.full((5,), -1.))

def test_update_all_0deg():
    """Check update_all when some (or all) nodes have no incoming edge."""
    # test#1
    g = DGLGraph()
    g.add_nodes(5)
    g.add_edge(1, 0)
    g.add_edge(2, 0)
    g.add_edge(3, 0)
    g.add_edge(4, 0)
    def _message(edges):
        return {'m' : edges.src['h']}
    def _reduce(nodes):
        return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}
    def _apply(nodes):
        return {'h' : nodes.data['h'] * 2}
    def _init2(shape, dtype, ctx, ids):
        # Initializer that fills with the constant 2.
        return 2 + th.zeros(shape, dtype=dtype, device=ctx)
    g.set_n_initializer(_init2, 'h')
    old_repr = th.randn(5, 5)
    g.ndata['h'] = old_repr
    g.update_all(_message, _reduce, _apply)
    new_repr = g.ndata['h']
    # the first row of the new_repr should be the sum of all the node
    # features; while the 0-deg nodes should be initialized by the
    # initializer and applied with UDF.
    assert U.allclose(new_repr[1:], 2*(2+th.zeros((4,5))))
    assert U.allclose(new_repr[0], 2 * old_repr.sum(0))

    # test#2: graph with no edge
    g = DGLGraph()
    g.add_nodes(5)
    g.set_n_initializer(_init2, 'h')
    g.ndata['h'] = old_repr
    g.update_all(_message, _reduce, _apply)
    new_repr = g.ndata['h']
    # should fallback to apply
    assert U.allclose(new_repr, 2*old_repr)
426

427
def test_pull_0deg():
    """Check pull on a two-node graph where node 0 has no incoming edge.

    The expected constant tensor uses a float fill value so its dtype matches
    the float features regardless of how ``th.full`` infers a dtype from an
    integer fill.
    """
    g = DGLGraph()
    g.add_nodes(2)
    g.add_edge(0, 1)
    def _message(edges):
        return {'m' : edges.src['h']}
    def _reduce(nodes):
        return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}
    def _apply(nodes):
        return {'h' : nodes.data['h'] * 2}
    def _init2(shape, dtype, ctx, ids):
        # Initializer that fills with the constant 2.
        return 2 + th.zeros(shape, dtype=dtype, device=ctx)
    g.register_message_func(_message)
    g.register_reduce_func(_reduce)
    g.register_apply_node_func(_apply)
    g.set_n_initializer(_init2, 'h')
    # test#1: pull both 0deg and non-0deg nodes
    old = th.randn((2, 5))
    g.ndata['h'] = old
    g.pull([0, 1])
    new = g.ndata.pop('h')
    # 0deg check: initialized with the func and got applied
    assert U.allclose(new[0], th.full((5,), 4.))
    # non-0deg check
    assert U.allclose(new[1], th.sum(old, 0) * 2)

    # test#2: pull only 0deg node
    old = th.randn((2, 5))
    g.ndata['h'] = old
    g.pull(0)
    new = g.ndata.pop('h')
    # 0deg check: fallback to apply
    assert U.allclose(new[0], 2*old[0])
    # non-0deg check: not touched
    assert U.allclose(new[1], old[1])
462

463
464
def _disabled_test_send_twice():
    """Send twice before a single recv and check the reduced result.

    Currently disabled: the leading underscore keeps test collectors from
    picking it up, and the ``__main__`` runner does not call it.
    """
    # TODO(minjie): please re-enable this unittest after the send code problem is fixed.
    g = DGLGraph()
    g.add_nodes(3)
    g.add_edge(0, 1)
    g.add_edge(2, 1)
    def _message_a(edges):
        return {'a': edges.src['a']}
    def _message_b(edges):
        return {'a': edges.src['a'] * 3}
    def _reduce(nodes):
        return {'a': nodes.mailbox['a'].max(1)[0]}

    old_repr = th.randn(3, 5)
    # two sends on the same edge: the expected value is the second message
    g.ndata['a'] = old_repr
    g.send((0, 1), _message_a)
    g.send((0, 1), _message_b)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], old_repr[0] * 3)

    # sends on two different edges into node 1: element-wise max of both
    g.ndata['a'] = old_repr
    g.send((0, 1), _message_a)
    g.send((2, 1), _message_b)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], th.stack([old_repr[0], old_repr[2] * 3], 0).max(0)[0])
490
491
492
493
494
495
496
497
498

def test_send_multigraph():
    """Exercise send/recv on a multigraph with three parallel 0->1 edges and
    one 2->1 edge, addressing edges by id and by endpoint pairs."""
    g = DGLGraph(multigraph=True)
    g.add_nodes(3)
    g.add_edge(0, 1)
    g.add_edge(0, 1)
    g.add_edge(0, 1)
    g.add_edge(2, 1)

    def _message_a(edges):
        return {'a': edges.data['a']}
    def _message_b(edges):
        return {'a': edges.data['a'] * 3}
    def _reduce(nodes):
        return {'a': nodes.mailbox['a'].max(1)[0]}

    def answer(*args):
        # Element-wise maximum over the given rows.
        return th.stack(args, 0).max(0)[0]

    # send by eid
    old_repr = th.randn(4, 5)
    g.ndata['a'] = th.zeros(3, 5)
    g.edata['a'] = old_repr
    g.send([0, 2], message_func=_message_a)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], answer(old_repr[0], old_repr[2]))

    g.ndata['a'] = th.zeros(3, 5)
    g.edata['a'] = old_repr
    g.send([0, 2, 3], message_func=_message_a)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3]))

    # send on multigraph
    g.ndata['a'] = th.zeros(3, 5)
    g.edata['a'] = old_repr
    g.send(([0, 2], [1, 1]), _message_a)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], old_repr.max(0)[0])

    # consecutive send and send_on
    g.ndata['a'] = th.zeros(3, 5)
    g.edata['a'] = old_repr
    g.send((2, 1), _message_a)
    g.send([0, 1], message_func=_message_b)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], answer(old_repr[0] * 3, old_repr[1] * 3, old_repr[3]))

    # consecutive send_on
    g.ndata['a'] = th.zeros(3, 5)
    g.edata['a'] = old_repr
    g.send(0, message_func=_message_a)
    g.send(1, message_func=_message_b)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], answer(old_repr[0], old_repr[1] * 3))

    # send_and_recv_on
    g.ndata['a'] = th.zeros(3, 5)
    g.edata['a'] = old_repr
    g.send_and_recv([0, 2, 3], message_func=_message_a, reduce_func=_reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3]))
    assert U.allclose(new_repr[[0, 2]], th.zeros(2, 5))
558

559
560
561
562
563
564
565
566
def test_dynamic_addition():
    """Check that feature columns keep matching row counts as nodes and
    edges are added after features were already set."""
    N = 3
    D = 1  # local feature width; shadows the module-level D on purpose

    g = DGLGraph()

    # Test node addition
    g.add_nodes(N)
    g.ndata.update({'h1': th.randn(N, D),
                    'h2': th.randn(N, D)})
    g.add_nodes(3)
    assert g.ndata['h1'].shape[0] == g.ndata['h2'].shape[0] == N + 3

    # Test edge addition
    g.add_edge(0, 1)
    g.add_edge(1, 0)
    g.edata.update({'h1': th.randn(2, D),
                    'h2': th.randn(2, D)})
    assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 2

    g.add_edges([0, 2], [2, 0])
    g.edata['h1'] = th.randn(4, D)
    assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 4

    g.add_edge(1, 2)
    g.edges[4].data['h1'] = th.randn(1, D)
    assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 5
586

587

588
# Run every enabled test in sequence when the file is executed as a script.
# _disabled_test_send_twice is intentionally left out.
if __name__ == '__main__':
    test_nx_conversion()
    test_batch_setter_getter()
    test_batch_setter_autograd()
    test_batch_send()
    test_batch_recv()
    test_apply_nodes()
    test_apply_edges()
    test_update_routines()
    test_recv_0deg()
    test_recv_0deg_newfld()
    test_update_all_0deg()
    test_pull_0deg()
    test_send_multigraph()
    test_dynamic_addition()