# test_basics.py: basic DGLGraph API tests run on the MXNet backend.
import os
# DGLBACKEND must be set before dgl is imported below.
os.environ['DGLBACKEND'] = 'mxnet'

import mxnet as mx
import numpy as np
import scipy.sparse as spsp

import dgl
from dgl.graph import DGLGraph

# Feature width used for every node/edge feature in these tests.
D: int = 5
# Mailbox shapes seen by reduce_func; tests clear it and then compare.
reduce_msg_shapes: set = set()

def check_eq(a, b):
    """Assert that NDArrays ``a`` and ``b`` have the same shape and elements.

    The comparison is exact (elementwise ``==``), not approximate.
    """
    assert a.shape == b.shape
    # Count the matching elements; every element must match.
    assert mx.nd.sum(a == b).asnumpy() == int(np.prod(list(a.shape)))


def message_func(edges):
    """Message UDF: send each edge's source feature ``h`` as message ``m``.

    Also asserts that the source features arrive as a 2-D batch whose
    second dimension is ``D``.
    """
    assert len(edges.src['h'].shape) == 2
    assert edges.src['h'].shape[1] == D
    return {'m' : edges.src['h']}


def reduce_func(nodes):
    """Reduce UDF: sum the mailbox messages ``m`` along axis 1.

    Each mailbox shape is recorded in the module-level ``reduce_msg_shapes``.
    The update-routine tests use it to check the shapes DGL produced.
    Going by those expected shapes, axis 0 indexes the receiving nodes and
    axis 1 the incoming messages.
    """
    msgs = nodes.mailbox['m']
    reduce_msg_shapes.add(tuple(msgs.shape))
    assert len(msgs.shape) == 3
    assert msgs.shape[2] == D
    return {'m' : mx.nd.sum(msgs, 1)}


def apply_node_func(nodes):
    """Apply UDF: add the reduced message ``m`` onto the node feature ``h``."""
    return {'h' : nodes.data['h'] + nodes.data['m']}


def _source_sink_edges():
    """Return ``(src, dst)`` lists for the 0 -> {1..8} -> 9 -> 0 test topology.

    Edges are listed in the order the mutable graph adds them:
    (0, 1), (1, 9), (0, 2), (2, 9), ..., (0, 8), (8, 9), then (9, 0).
    """
    src, dst = [], []
    for i in range(1, 9):
        src += [0, i]
        dst += [i, 9]
    # back flow from the sink (9) to the source (0)
    src.append(9)
    dst.append(0)
    return src, dst


def generate_graph(grad=False, readonly=False):
    """Build the 10-node, 17-edge test graph with random features.

    Node 0 is the source and node 9 the sink. There is an edge 0 -> i and
    an edge i -> 9 for each i in 1..8, plus a back edge 9 -> 0. Node
    feature ``h`` has shape (10, D) and edge feature ``w`` has shape
    (17, D); both are drawn with ``mx.nd.random.normal``. Node and edge
    initializers are set to ``dgl.init.zero_initializer``.

    Parameters
    ----------
    grad : bool
        If True, call ``attach_grad()`` on both feature arrays before they
        are stored on the graph.
    readonly : bool
        If True, build a readonly graph from a scipy CSR adjacency matrix.
        Otherwise, build a mutable graph edge by edge.
        NOTE(review): with CSR input the edge IDs may follow row-major order
        rather than the insertion order above. Confirm before relying on
        specific edge IDs.

    Returns
    -------
    DGLGraph
    """
    num_nodes = 10
    src, dst = _source_sink_edges()
    if readonly:
        ones = np.ones(shape=(len(src)))
        csr = spsp.csr_matrix((ones, (src, dst)), shape=(num_nodes, num_nodes))
        g = DGLGraph(csr, readonly=True)
    else:
        g = DGLGraph()
        g.add_nodes(num_nodes)
        for s, d in zip(src, dst):
            g.add_edge(s, d)
    # Features are drawn in the same order (nodes, then edges) for both kinds.
    ncol = mx.nd.random.normal(shape=(num_nodes, D))
    ecol = mx.nd.random.normal(shape=(len(src), D))
    if grad:
        ncol.attach_grad()
        ecol.attach_grad()
    g.ndata['h'] = ncol
    g.edata['w'] = ecol
    g.set_n_initializer(dgl.init.zero_initializer)
    g.set_e_initializer(dgl.init.zero_initializer)
    return g


def test_batch_setter_getter():
    """Exercise batched node/edge feature set, get and pop on a mutable graph.

    Only column 0 of each feature is inspected, so every check compares a
    plain Python list of floats.
    """
    def _pfc(x):
        # first column of an NDArray, as a Python list
        return list(x.asnumpy()[:,0])
    g = generate_graph()
    # set all nodes
    g.set_n_repr({'h' : mx.nd.zeros((10, D))})
    assert _pfc(g.ndata['h']) == [0.] * 10
    # pop nodes
    assert _pfc(g.pop_n_repr('h')) == [0.] * 10
    assert len(g.ndata) == 0
    g.set_n_repr({'h' : mx.nd.zeros((10, D))})
    # set partial nodes
    u = mx.nd.array([1, 3, 5], dtype='int64')
    g.set_n_repr({'h' : mx.nd.ones((3, D))}, u)
    assert _pfc(g.ndata['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
    # get partial nodes
    u = mx.nd.array([1, 2, 3], dtype='int64')
    assert _pfc(g.get_n_repr(u)['h']) == [1., 0., 1.]

    # Edge IDs of the mutable generate_graph() graph (insertion order):
    #   s, d, eid
    #   0, 1, 0
    #   1, 9, 1
    #   0, 2, 2
    #   2, 9, 3
    #   0, 3, 4
    #   3, 9, 5
    #   0, 4, 6
    #   4, 9, 7
    #   0, 5, 8
    #   5, 9, 9
    #   0, 6, 10
    #   6, 9, 11
    #   0, 7, 12
    #   7, 9, 13
    #   0, 8, 14
    #   8, 9, 15
    #   9, 0, 16

    # set all edges
    g.edata['l'] = mx.nd.zeros((17, D))
    assert _pfc(g.edata['l']) == [0.] * 17
    # pop edges
    old_len = len(g.edata)
    assert _pfc(g.pop_e_repr('l')) == [0.] * 17
    assert len(g.edata) == old_len - 1
    g.edata['l'] = mx.nd.zeros((17, D))
    # set partial edges (many-many)
    u = mx.nd.array([0, 0, 2, 5, 9], dtype='int64')
    v = mx.nd.array([1, 3, 9, 9, 0], dtype='int64')
    g.edges[u, v].data['l'] = mx.nd.ones((5, D))
    truth = [0.] * 17
    truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
    assert _pfc(g.edata['l']) == truth
    # set partial edges (many-one)
    u = mx.nd.array([3, 4, 6], dtype='int64')
    v = mx.nd.array([9], dtype='int64')
    g.edges[u, v].data['l'] = mx.nd.ones((3, D))
    truth[5] = truth[7] = truth[11] = 1.
    assert _pfc(g.edata['l']) == truth
    # set partial edges (one-many)
    u = mx.nd.array([0], dtype='int64')
    v = mx.nd.array([4, 5, 6], dtype='int64')
    g.edges[u, v].data['l'] = mx.nd.ones((3, D))
    truth[6] = truth[8] = truth[10] = 1.
    assert _pfc(g.edata['l']) == truth
    # get partial edges (many-many)
    u = mx.nd.array([0, 6, 0], dtype='int64')
    v = mx.nd.array([6, 9, 7], dtype='int64')
    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
    # get partial edges (many-one)
    u = mx.nd.array([5, 6, 7], dtype='int64')
    v = mx.nd.array([9], dtype='int64')
    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
    # get partial edges (one-many)
    u = mx.nd.array([0], dtype='int64')
    v = mx.nd.array([3, 4, 5], dtype='int64')
    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 1.]


def test_batch_setter_autograd():
    """Check gradient routing after a partial node-feature set.

    Rows 1, 2 and 8 of ``h`` are overwritten with ``hh`` under autograd.
    The expected gradients below show that those rows' gradient goes to
    ``hh`` and the remaining rows' gradient goes to the original ``h``.
    """
    with mx.autograd.record():
        g = generate_graph(grad=True, readonly=True)
        h1 = g.ndata['h']
        h1.attach_grad()
        # partial set
        v = mx.nd.array([1, 2, 8], dtype='int64')
        hh = mx.nd.zeros((len(v), D))
        hh.attach_grad()
        g.set_n_repr({'h' : hh}, v)
        h2 = g.ndata['h']
    h2.backward(mx.nd.ones((10, D)) * 2)
    check_eq(h1.grad[:,0], mx.nd.array([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
    check_eq(hh.grad[:,0], mx.nd.array([2., 2., 2.]))


def test_batch_send():
    """send() accepts many-many, one-many and many-one ``(u, v)`` edge batches.

    ``_fmsg`` asserts that every batch delivers exactly five source features.
    """
    g = generate_graph()
    def _fmsg(edges):
        assert edges.src['h'].shape == (5, D)
        return {'m' : edges.src['h']}
    g.register_message_func(_fmsg)
    # many-many send
    u = mx.nd.array([0, 0, 0, 0, 0], dtype='int64')
    v = mx.nd.array([1, 2, 3, 4, 5], dtype='int64')
    g.send((u, v))
    # one-many send
    u = mx.nd.array([0], dtype='int64')
    v = mx.nd.array([1, 2, 3, 4, 5], dtype='int64')
    g.send((u, v))
    # many-one send
    u = mx.nd.array([1, 2, 3, 4, 5], dtype='int64')
    v = mx.nd.array([9], dtype='int64')
    g.send((u, v))


def check_batch_recv(readonly):
    """Basic recv test: send a batch of messages along six edges.

    Only ``send`` is currently exercised. The recv step and its
    mailbox-shape assertion are disabled (see the NOTE below).

    Parameters
    ----------
    readonly : bool
        Passed through to ``generate_graph``.
    """
    g = generate_graph(readonly=readonly)
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_node_func)
    u = mx.nd.array([0, 0, 0, 4, 5, 6], dtype='int64')
    v = mx.nd.array([1, 2, 3, 9, 9, 9], dtype='int64')
    reduce_msg_shapes.clear()
    g.send((u, v))
    # NOTE(review): the recv half of this test is disabled. The original
    # line calls ``th.unique`` (PyTorch), and ``th`` is not imported under the
    # MXNet backend. An MXNet equivalent would be something like
    # ``mx.nd.array(np.unique(v.asnumpy()), dtype='int64')``. Re-enable once
    # a separate recv after send is confirmed to work for both graph kinds.
    #g.recv(th.unique(v))
    #assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
    #reduce_msg_shapes.clear()


def test_batch_recv():
    """Run check_batch_recv on a readonly graph and on a mutable graph."""
    check_batch_recv(True)
    check_batch_recv(False)


def test_apply_nodes():
    """apply_nodes runs the registered UDF on all nodes, or a given UDF on a subset."""
    def _double(nodes):
        return {'h' : nodes.data['h'] * 2}

    def _zero(nodes):
        return {'h' : nodes.data['h'] * 0.}

    graph = generate_graph()
    graph.register_apply_node_func(_double)
    before = graph.ndata['h']
    graph.apply_nodes()
    doubled = (before * 2).asnumpy()
    assert np.allclose(doubled, graph.ndata['h'].asnumpy())

    # Zero out a subset of nodes with an ad-hoc UDF.
    subset = mx.nd.array([0, 3, 4, 6], dtype=np.int64)
    graph.apply_nodes(_zero, subset)
    picked = graph.ndata['h'][subset].asnumpy()
    assert np.allclose(picked, np.zeros(shape=(4, D), dtype=picked.dtype))

def test_apply_edges():
    """apply_edges runs the registered UDF on all edges, or a given UDF on (u, v) pairs."""
    def _upd(edges):
        return {'w' : edges.data['w'] * 2}
    g = generate_graph()
    g.register_apply_edge_func(_upd)
    old = g.edata['w']
    g.apply_edges()
    assert np.allclose((old * 2).asnumpy(), g.edata['w'].asnumpy())
    u = mx.nd.array([0, 0, 0, 4, 5, 6], dtype=np.int64)
    v = mx.nd.array([1, 2, 3, 9, 9, 9], dtype=np.int64)
    g.apply_edges(lambda edges : {'w' : edges.data['w'] * 0.}, (u, v))
    # Look up the edge IDs of the (u, v) pairs to read back only those rows.
    eid = g.edge_ids(u, v)
    w = g.edata['w'][eid].asnumpy()
    assert np.allclose(w, np.zeros(shape=(6, D), dtype=w.dtype))


def check_update_routines(readonly):
    """Check the mailbox shapes reduce_func sees for each update routine.

    Covers send_and_recv, pull, push and update_all on the graph from
    ``generate_graph``. Each step clears ``reduce_msg_shapes`` first, so
    the set holds only the shapes produced by that one routine.

    Parameters
    ----------
    readonly : bool
        Passed through to ``generate_graph``.
    """
    g = generate_graph(readonly=readonly)
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_node_func)

    # send_and_recv
    reduce_msg_shapes.clear()
    u = mx.nd.array([0, 0, 0, 4, 5, 6], dtype='int64')
    v = mx.nd.array([1, 2, 3, 9, 9, 9], dtype='int64')
    g.send_and_recv((u, v))
    assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})

    # pull
    reduce_msg_shapes.clear()
    v = mx.nd.array([1, 2, 3, 9], dtype='int64')
    g.pull(v)
    assert(reduce_msg_shapes == {(1, 8, D), (3, 1, D)})

    # push
    reduce_msg_shapes.clear()
    v = mx.nd.array([0, 1, 2, 3], dtype='int64')
    g.push(v)
    assert(reduce_msg_shapes == {(1, 3, D), (8, 1, D)})

    # update_all
    reduce_msg_shapes.clear()
    g.update_all()
    assert(reduce_msg_shapes == {(1, 8, D), (9, 1, D)})
    reduce_msg_shapes.clear()


def test_update_routines():
    """Run the update-routine checks on readonly, then mutable, graphs."""
    for readonly in (True, False):
        check_update_routines(readonly)

def check_reduce_0deg(readonly):
    """update_all on a 5-node graph with edges 1..4 -> 0.

    Nodes 1-4 have no in-edges. The assertions expect them to end up
    holding the ``_init2`` initializer value (2), even though ``h`` was
    explicitly set beforehand. Node 0 receives the sum of its own ``h``
    and the four incoming messages.

    Parameters
    ----------
    readonly : bool
        Build the graph from a CSR matrix (readonly) or edge by edge.
    """
    if readonly:
        row_idx = list(range(1, 5))
        col_idx = [0] * len(row_idx)
        ones = np.ones(shape=(len(row_idx)))
        csr = spsp.csr_matrix((ones, (row_idx, col_idx)), shape=(5, 5))
        g = DGLGraph(csr, readonly=True)
    else:
        g = DGLGraph()
        g.add_nodes(5)
        g.add_edge(1, 0)
        g.add_edge(2, 0)
        g.add_edge(3, 0)
        g.add_edge(4, 0)
    def _message(edges):
        return {'m' : edges.src['h']}
    def _reduce(nodes):
        return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}
    def _init2(shape, dtype, ctx, ids):
        return 2 + mx.nd.zeros(shape, dtype=dtype, ctx=ctx)
    g.set_n_initializer(_init2, 'h')
    old_repr = mx.nd.random.normal(shape=(5, 5))
    g.set_n_repr({'h': old_repr})
    g.update_all(_message, _reduce)
    new_repr = g.ndata['h']

    # 0-deg nodes: initializer value
    assert np.allclose(new_repr[1:].asnumpy(), 2+np.zeros((4, 5)))
    # node 0: its own h plus the h of nodes 1..4, i.e. the column sum
    assert np.allclose(new_repr[0].asnumpy(), old_repr.sum(0).asnumpy())


def test_reduce_0deg():
    """Run check_reduce_0deg on a readonly graph and on a mutable graph."""
    check_reduce_0deg(True)
    check_reduce_0deg(False)


def test_recv_0deg_newfld():
    """recv on a 0-in-degree node when the reducer creates a new field.

    The graph has two nodes and one edge, 0 -> 1. The reducer writes
    ``h1`` rather than ``h``. Node 0 has no in-edges, so its ``h1`` must
    come from the registered initializer (case 1) or from a value already
    stored on the graph (case 2) before the apply UDF doubles it.
    """
    graph = DGLGraph()
    graph.add_nodes(2)
    graph.add_edge(0, 1)

    def _send_h(edges):
        return {'m' : edges.src['h']}

    def _sum_into_h1(nodes):
        return {'h1' : nodes.data['h'] + mx.nd.sum(nodes.mailbox['m'], 1)}

    def _double_h1(nodes):
        return {'h1' : nodes.data['h1'] * 2}

    def _fill_two(shape, dtype, ctx, ids):
        return 2 + mx.nd.zeros(shape=shape, dtype=dtype, ctx=ctx)

    graph.register_message_func(_send_h)
    graph.register_reduce_func(_sum_into_h1)
    graph.register_apply_node_func(_double_h1)

    # Case 1: recv on both the 0-deg node and the non-0-deg node.
    feat = mx.nd.random.normal(shape=(2, 5))
    graph.set_n_initializer(_fill_two, 'h1')
    graph.ndata['h'] = feat
    graph.send((0, 1))
    graph.recv([0, 1])
    result = graph.ndata.pop('h1')
    # node 0 (0-deg): initializer value 2, doubled by the apply UDF
    assert np.allclose(result[0].asnumpy(), np.full((5,), 4))
    # node 1: h[1] + h[0] from the message, doubled by the apply UDF
    assert np.allclose(result[1].asnumpy(), mx.nd.sum(feat, 0).asnumpy() * 2)

    # Case 2: recv only on the 0-deg node.
    feat = mx.nd.random.normal(shape=(2, 5))
    graph.ndata['h'] = feat
    graph.ndata['h1'] = mx.nd.full((2, 5), -1)  # this is necessary
    graph.send((0, 1))
    graph.recv(0)
    result = graph.ndata.pop('h1')
    # node 0: falls back to applying the UDF on the stored value (-1 * 2)
    assert np.allclose(result[0].asnumpy(), np.full((5,), -2))
    # node 1 was not received on, so it keeps the stored value
    assert np.allclose(result[1].asnumpy(), np.full((5,), -1))

def test_update_all_0deg():
    """update_all when some nodes (test#1) or all nodes (test#2) have no in-edges."""
    # test#1: edges 1..4 -> 0
    g = DGLGraph()
    g.add_nodes(5)
    g.add_edge(1, 0)
    g.add_edge(2, 0)
    g.add_edge(3, 0)
    g.add_edge(4, 0)
    def _message(edges):
        return {'m' : edges.src['h']}
    def _reduce(nodes):
        return {'h' : nodes.data['h'] + mx.nd.sum(nodes.mailbox['m'], 1)}
    def _apply(nodes):
        return {'h' : nodes.data['h'] * 2}
    def _init2(shape, dtype, ctx, ids):
        return 2 + mx.nd.zeros(shape, dtype=dtype, ctx=ctx)
    g.set_n_initializer(_init2, 'h')
    old_repr = mx.nd.random.normal(shape=(5, 5))
    g.ndata['h'] = old_repr
    g.update_all(_message, _reduce, _apply)
    new_repr = g.ndata['h']
    # The first row of new_repr should be the sum of all the node
    # features (then applied). The 0-deg nodes should be initialized by
    # the initializer and then applied with the UDF.
    assert np.allclose(new_repr[1:].asnumpy(), 2*(2+np.zeros((4,5))))
    assert np.allclose(new_repr[0].asnumpy(), 2 * mx.nd.sum(old_repr, 0).asnumpy())

    # test#2: graph with no edge
    g = DGLGraph()
    g.add_nodes(5)
    g.set_n_initializer(_init2, 'h')
    g.ndata['h'] = old_repr
    g.update_all(_message, _reduce, _apply)
    new_repr = g.ndata['h']
    # should fall back to apply on the existing features
    assert np.allclose(new_repr.asnumpy(), 2*old_repr.asnumpy())


def check_pull_0deg(readonly):
    """pull on a two-node graph 0 -> 1, targeting node 0, node 1, or both.

    Node 0 has no in-edges. The three cases check what pull does to the
    0-deg node and to the node that does receive a message.

    Parameters
    ----------
    readonly : bool
        Build the graph from a CSR matrix (readonly) or edge by edge.
    """
    if readonly:
        row_idx = [0]
        col_idx = [1]
        ones = np.ones(shape=(len(row_idx)))
        csr = spsp.csr_matrix((ones, (row_idx, col_idx)), shape=(2, 2))
        g = DGLGraph(csr, readonly=True)
    else:
        g = DGLGraph()
        g.add_nodes(2)
        g.add_edge(0, 1)
    def _message(edges):
        return {'m' : edges.src['h']}
    def _reduce(nodes):
        return {'h' : nodes.mailbox['m'].sum(1)}
    def _apply(nodes):
        return {'h' : nodes.data['h'] * 2}
    def _init2(shape, dtype, ctx, ids):
        return 2 + mx.nd.zeros(shape, dtype=dtype, ctx=ctx)
    g.set_n_initializer(_init2, 'h')
    old_repr = mx.nd.random.normal(shape=(2, 5))

    # test#1: pull only the 0-deg node
    g.ndata['h'] = old_repr
    g.pull(0, _message, _reduce, _apply)
    new_repr = g.ndata['h']
    # 0deg check: equal to apply_nodes
    assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy() * 2)
    # non-0deg check: untouched
    assert np.allclose(new_repr[1].asnumpy(), old_repr[1].asnumpy())

    # test#2: pull only the non-0-deg node
    g.ndata['h'] = old_repr
    g.pull(1, _message, _reduce, _apply)
    new_repr = g.ndata['h']
    # 0deg check: untouched
    assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy())
    # non-0deg check: received node 0's feature and got applied
    assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy() * 2)

    # test#3: pull both nodes
    g.ndata['h'] = old_repr
    g.pull([0, 1], _message, _reduce, _apply)
    new_repr = g.ndata['h']
    # 0deg check: initialized to 2, then doubled by apply. Compare against
    # a single (5,) row rather than relying on broadcasting.
    assert np.allclose(new_repr[0].asnumpy(), np.full((5,), 4))
    # non-0deg check: received node 0's feature and got applied
    assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy() * 2)


def test_pull_0deg():
    """Run check_pull_0deg on a readonly graph and on a mutable graph."""
    check_pull_0deg(True)
    check_pull_0deg(False)


if __name__ == '__main__':
    # Run every test in file order when executed as a script.
    test_batch_setter_getter()
    test_batch_setter_autograd()
    test_batch_send()
    test_batch_recv()
    test_apply_nodes()
    test_apply_edges()
    test_update_routines()
    test_reduce_0deg()
    test_recv_0deg_newfld()
    test_update_all_0deg()
    test_pull_0deg()