test_basics.py 12.2 KB
Newer Older
1
import torch as th
2
3
from torch.autograd import Variable
import numpy as np
4
import dgl
5
from dgl.graph import DGLGraph
6
import utils as U
7
8
9
10

# Feature width used for every node/edge tensor created by these tests.
D: int = 5
# Shapes (as tuples) of every mailbox tensor seen by ``reduce_func``; tests
# clear it, trigger message passing, then compare it with the expected set.
reduce_msg_shapes: set = set()

11
12
13
14
def check_eq(a, b):
    """Assert that tensors ``a`` and ``b`` have the same shape and equal elements."""
    assert a.shape == b.shape
    # With shapes already equal, every element must match; ``all()`` over an
    # empty comparison is True, just like "0 matches out of 0 elements".
    assert bool((a == b).all())

15
16
17
18
def message_func(edges):
    """Send each source node's ``h`` feature unchanged as message ``m``."""
    src_h = edges.src['h']
    # One row per edge in the batch, each row of width D.
    assert len(src_h.shape) == 2 and src_h.shape[1] == D
    return {'m': src_h}
19

20
21
def reduce_func(nodes):
    """Sum each node's incoming ``m`` messages into ``accum``.

    The mailbox shape is recorded in ``reduce_msg_shapes`` before validation,
    so tests can inspect how destination nodes were batched.
    """
    mailbox = nodes.mailbox['m']
    reduce_msg_shapes.add(tuple(mailbox.shape))
    # Expected layout: (nodes in this batch, messages per node, D).
    assert len(mailbox.shape) == 3
    assert mailbox.shape[2] == D
    return {'accum': mailbox.sum(1)}
Minjie Wang's avatar
Minjie Wang committed
26

27
28
def apply_node_func(nodes):
    """Return the new ``h``: the current ``h`` plus the reduced ``accum``."""
    feats = nodes.data
    updated = feats['h'] + feats['accum']
    return {'h': updated}
Minjie Wang's avatar
Minjie Wang committed
29

30
def generate_graph(grad=False):
    """Build the 10-node test graph shared by most tests in this file.

    Node 0 is the source and node 9 the sink: 0 -> i and i -> 9 for every
    i in 1..8, plus one back edge 9 -> 0, for 17 edges in total. Edges are
    added in the order 0->1, 1->9, 0->2, 2->9, ..., 0->8, 8->9, 9->0.

    Args:
        grad: when True, the random features ``ndata['h']`` (10 x D) and
            ``edata['w']`` (17 x D) are leaf tensors with
            ``requires_grad=True``.

    Returns:
        The populated ``DGLGraph``, with ``dgl.init.zero_initializer``
        registered as both the node and the edge feature initializer.
    """
    g = DGLGraph()
    g.add_nodes(10)
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    # back flow from the sink to the source
    g.add_edge(9, 0)
    # Tensors carry ``requires_grad`` directly; the old ``Variable`` wrapper is
    # deprecated. Draw order (nodes first, then edges) matches the original.
    g.ndata['h'] = th.randn(10, D, requires_grad=grad)
    g.edata['w'] = th.randn(17, D, requires_grad=grad)
    g.set_n_initializer(dgl.init.zero_initializer)
    g.set_e_initializer(dgl.init.zero_initializer)
    return g

def test_batch_setter_getter():
    """Whole and partial get/set/pop of node and edge features."""
    def _first_col(t):
        # Column 0 as a plain list, for compact comparisons.
        return list(t.numpy()[:, 0])

    g = generate_graph()

    # ---- nodes -------------------------------------------------------
    # set all nodes
    g.ndata['h'] = th.zeros((10, D))
    assert U.allclose(g.ndata['h'], th.zeros((10, D)))

    # pop nodes: the popped values come back and the field disappears
    n_fields = len(g.ndata)
    assert _first_col(g.pop_n_repr('h')) == [0.] * 10
    assert len(g.ndata) == n_fields - 1
    g.ndata['h'] = th.zeros((10, D))

    # set partial nodes
    nids = th.tensor([1, 3, 5])
    g.nodes[nids].data['h'] = th.ones((3, D))
    assert _first_col(g.ndata['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]

    # get partial nodes
    nids = th.tensor([1, 2, 3])
    assert _first_col(g.nodes[nids].data['h']) == [1., 0., 1.]

    # ---- edges -------------------------------------------------------
    # Edge ids of the generated graph (src -> dst : eid):
    #   0 -> 1 : 0     1 -> 9 : 1
    #   0 -> 2 : 2     2 -> 9 : 3
    #   0 -> 3 : 4     3 -> 9 : 5
    #   0 -> 4 : 6     4 -> 9 : 7
    #   0 -> 5 : 8     5 -> 9 : 9
    #   0 -> 6 : 10    6 -> 9 : 11
    #   0 -> 7 : 12    7 -> 9 : 13
    #   0 -> 8 : 14    8 -> 9 : 15
    #   9 -> 0 : 16

    # set all edges
    g.edata['l'] = th.zeros((17, D))
    assert _first_col(g.edata['l']) == [0.] * 17

    # pop edges
    n_fields = len(g.edata)
    assert _first_col(g.pop_e_repr('l')) == [0.] * 17
    assert len(g.edata) == n_fields - 1
    g.edata['l'] = th.zeros((17, D))

    expected = [0.] * 17

    # set partial edges (many-many)
    src = th.tensor([0, 0, 2, 5, 9])
    dst = th.tensor([1, 3, 9, 9, 0])
    g.edges[src, dst].data['l'] = th.ones((5, D))
    for eid in (0, 4, 3, 9, 16):
        expected[eid] = 1.
    assert _first_col(g.edata['l']) == expected

    # set partial edges (many-one)
    src = th.tensor([3, 4, 6])
    dst = th.tensor([9])
    g.edges[src, dst].data['l'] = th.ones((3, D))
    for eid in (5, 7, 11):
        expected[eid] = 1.
    assert _first_col(g.edata['l']) == expected

    # set partial edges (one-many)
    src = th.tensor([0])
    dst = th.tensor([4, 5, 6])
    g.edges[src, dst].data['l'] = th.ones((3, D))
    for eid in (6, 8, 10):
        expected[eid] = 1.
    assert _first_col(g.edata['l']) == expected

    # get partial edges (many-many)
    src = th.tensor([0, 6, 0])
    dst = th.tensor([6, 9, 7])
    assert _first_col(g.edges[src, dst].data['l']) == [1., 1., 0.]

    # get partial edges (many-one)
    src = th.tensor([5, 6, 7])
    dst = th.tensor([9])
    assert _first_col(g.edges[src, dst].data['l']) == [1., 1., 0.]

    # get partial edges (one-many)
    src = th.tensor([0])
    dst = th.tensor([3, 4, 5])
    assert _first_col(g.edges[src, dst].data['l']) == [1., 1., 1.]
127

128
129
def test_batch_setter_autograd():
    """Gradients reach both the original features and a partial overwrite."""
    g = generate_graph(grad=True)
    h1 = g.ndata['h']
    # partial set: rows 1, 2 and 8 are replaced by ``hh``
    v = th.tensor([1, 2, 8])
    # ``requires_grad`` on the factory replaces the deprecated Variable wrapper.
    hh = th.zeros((len(v), D), requires_grad=True)
    g.nodes[v].data['h'] = hh
    h2 = g.ndata['h']
    h2.backward(th.ones((10, D)) * 2)
    # Overwritten rows get their gradient through hh, so h1 sees 0 there.
    check_eq(h1.grad[:, 0], th.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
    check_eq(hh.grad[:, 0], th.tensor([2., 2., 2.]))

140
141
def test_batch_send():
    """send() accepts many-many, one-many and many-one (src, dst) specs."""
    g = generate_graph()

    def _fmsg(edges):
        # Every send below resolves to exactly five edges.
        h = edges.src['h']
        assert h.shape == (5, D)
        return {'m': h}

    g.register_message_func(_fmsg)

    cases = (
        ([0, 0, 0, 0, 0], [1, 2, 3, 4, 5]),  # many-many
        ([0], [1, 2, 3, 4, 5]),              # one-many
        ([1, 2, 3, 4, 5], [9]),              # many-one
    )
    for src, dst in cases:
        g.send((th.tensor(src), th.tensor(dst)))
158

159
def test_batch_recv():
    """recv() reduces pending messages, batching receivers by in-degree."""
    g = generate_graph()
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_node_func)

    src = th.tensor([0, 0, 0, 4, 5, 6])
    dst = th.tensor([1, 2, 3, 9, 9, 9])
    reduce_msg_shapes.clear()
    g.send((src, dst))
    g.recv(th.unique(dst))
    # Nodes 1-3 receive one message each; node 9 receives three.
    expected = {(1, 3, D), (3, 1, D)}
    assert reduce_msg_shapes == expected
    reduce_msg_shapes.clear()

173
174
175
176
177
178
179
180
181
182
183
184
def test_apply_nodes():
    """apply_nodes() with the registered UDF and with an explicit UDF + subset."""
    def _double(nodes):
        return {'h': nodes.data['h'] * 2}

    def _zero(nodes):
        return {'h': nodes.data['h'] * 0.}

    g = generate_graph()
    g.register_apply_node_func(_double)
    before = g.ndata['h']
    g.apply_nodes()
    assert U.allclose(before * 2, g.ndata['h'])

    nids = th.tensor([0, 3, 4, 6])
    g.apply_nodes(_zero, nids)
    assert U.allclose(g.ndata['h'][nids], th.zeros((4, D)))

185
def test_apply_edges():
    """apply_edges() on all edges, then on a (src, dst) subset."""
    def _double(edges):
        return {'w': edges.data['w'] * 2}

    def _zero(edges):
        return {'w': edges.data['w'] * 0.}

    g = generate_graph()
    g.register_apply_edge_func(_double)
    before = g.edata['w']
    g.apply_edges()
    assert U.allclose(before * 2, g.edata['w'])

    src = th.tensor([0, 0, 0, 4, 5, 6])
    dst = th.tensor([1, 2, 3, 9, 9, 9])
    g.apply_edges(_zero, (src, dst))
    touched = g.edge_ids(src, dst)
    assert U.allclose(g.edata['w'][touched], th.zeros((6, D)))
198

199
200
def test_update_routines():
    """send_and_recv, pull, push and update_all with the registered UDFs.

    After each call, the set of mailbox shapes recorded by ``reduce_func``
    shows how destination nodes were grouped by in-degree.
    """
    g = generate_graph()
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_node_func)

    # send_and_recv: nodes 1-3 get one message each, node 9 gets three
    reduce_msg_shapes.clear()
    u = [0, 0, 0, 4, 5, 6]
    v = [1, 2, 3, 9, 9, 9]
    g.send_and_recv((u, v))
    assert reduce_msg_shapes == {(1, 3, D), (3, 1, D)}
    reduce_msg_shapes.clear()

    # Passing a list [u, v] instead of a (u, v) tuple is expected to raise.
    # The failure is signalled from the ``else`` branch so the handler cannot
    # swallow it (a bare ``except:`` would also catch ``assert False``).
    try:
        g.send_and_recv([u, v])
    except Exception:
        pass
    else:
        raise AssertionError('send_and_recv([u, v]) was expected to raise')

    # pull: nodes 1-3 have in-degree 1, node 9 has in-degree 8
    v = th.tensor([1, 2, 3, 9])
    reduce_msg_shapes.clear()
    g.pull(v)
    assert reduce_msg_shapes == {(1, 8, D), (3, 1, D)}
    reduce_msg_shapes.clear()

    # push: node 0 reaches 1-8 (one message each); 1-3 reach node 9 (three)
    v = th.tensor([0, 1, 2, 3])
    reduce_msg_shapes.clear()
    g.push(v)
    assert reduce_msg_shapes == {(1, 3, D), (8, 1, D)}
    reduce_msg_shapes.clear()

    # update_all: nodes 0-8 have in-degree 1, node 9 has in-degree 8
    reduce_msg_shapes.clear()
    g.update_all()
    assert reduce_msg_shapes == {(1, 8, D), (9, 1, D)}
    reduce_msg_shapes.clear()

238
239
def test_reduce_0deg():
    """update_all on a star graph whose leaves (nodes 1-4) have in-degree 0."""
    g = DGLGraph()
    g.add_nodes(5)
    for leaf in range(1, 5):
        g.add_edge(leaf, 0)

    def _copy_src(edges):
        return {'m': edges.src['h']}

    def _add_mail(nodes):
        return {'h': nodes.data['h'] + nodes.mailbox['m'].sum(1)}

    def _init_twos(shape, dtype, ctx, ids):
        return 2 + th.zeros(shape, dtype=dtype, device=ctx)

    g.set_n_initializer(_init_twos, 'h')
    feats = th.randn(5, 5)
    g.ndata['h'] = feats
    g.update_all(_copy_src, _add_mail)
    result = g.ndata['h']

    # Leaves are expected to take the initializer value (2); node 0 ends up
    # with its own row plus rows 1-4, i.e. the column sum of all features.
    assert U.allclose(result[1:], 2 + th.zeros((4, 5)))
    assert U.allclose(result[0], feats.sum(0))
261

262
def test_pull_0deg():
    """pull() on a 2-node graph 0 -> 1, where node 0 has in-degree 0."""
    g = DGLGraph()
    g.add_nodes(2)
    g.add_edge(0, 1)

    def _copy_src(edges):
        return {'m': edges.src['h']}

    def _sum_mail(nodes):
        return {'h': nodes.mailbox['m'].sum(1)}

    # Pulling the 0-in-degree node is expected to leave every row unchanged.
    feats = th.randn(2, 5)
    g.ndata['h'] = feats
    g.pull(0, _copy_src, _sum_mail)
    out = g.ndata['h']
    assert U.allclose(out[0], feats[0])
    assert U.allclose(out[1], feats[1])

    # Node 1's only in-neighbour is node 0.
    g.pull(1, _copy_src, _sum_mail)
    out = g.ndata['h']
    assert U.allclose(out[1], feats[0])

    # Mixed batch: node 0 keeps its value, node 1 receives node 0's.
    feats = th.randn(2, 5)
    g.ndata['h'] = feats
    g.pull([0, 1], _copy_src, _sum_mail)
    out = g.ndata['h']
    assert U.allclose(out[0], feats[0])
    assert U.allclose(out[1], feats[0])
288

289
290
def _disabled_test_send_twice():
    # TODO(minjie): please re-enable this unittest after the send code problem is fixed.
    g = DGLGraph()
    g.add_nodes(3)
    g.add_edge(0, 1)
    g.add_edge(2, 1)

    def _msg_plain(edges):
        return {'a': edges.src['a']}

    def _msg_triple(edges):
        return {'a': edges.src['a'] * 3}

    def _max_mail(nodes):
        return {'a': nodes.mailbox['a'].max(1)[0]}

    feats = th.randn(3, 5)

    def _send_then_recv(*sends):
        # Reset features, issue each (edge, message_func) send, reduce at node 1.
        g.ndata['a'] = feats
        for edge, fn in sends:
            g.send(edge, fn)
        g.recv(1, _max_mail)
        return g.ndata['a']

    # Two sends on the same edge: the later message is expected to win.
    out = _send_then_recv(((0, 1), _msg_plain), ((0, 1), _msg_triple))
    assert U.allclose(out[1], feats[0] * 3)

    # Sends on different edges into node 1 are max-reduced together.
    out = _send_then_recv(((0, 1), _msg_plain), ((2, 1), _msg_triple))
    assert U.allclose(out[1], th.stack([feats[0], feats[2] * 3], 0).max(0)[0])
316
317
318
319
320
321
322
323
324

def test_send_multigraph():
    """send/recv by edge id and by (src, dst) on a graph with parallel edges.

    Edges: three parallel 0 -> 1 edges (ids 0, 1, 2) and one 2 -> 1 edge (id 3).
    """
    g = DGLGraph(multigraph=True)
    g.add_nodes(3)
    for _ in range(3):
        g.add_edge(0, 1)
    g.add_edge(2, 1)

    def _msg_plain(edges):
        return {'a': edges.data['a']}

    def _msg_triple(edges):
        return {'a': edges.data['a'] * 3}

    def _max_mail(nodes):
        return {'a': nodes.mailbox['a'].max(1)[0]}

    def _elementwise_max(*rows):
        return th.stack(rows, 0).max(0)[0]

    edge_feats = th.randn(4, 5)

    def _reset():
        g.ndata['a'] = th.zeros(3, 5)
        g.edata['a'] = edge_feats

    # send by eid
    _reset()
    g.send([0, 2], message_func=_msg_plain)
    g.recv(1, _max_mail)
    assert U.allclose(g.ndata['a'][1], _elementwise_max(edge_feats[0], edge_feats[2]))

    _reset()
    g.send([0, 2, 3], message_func=_msg_plain)
    g.recv(1, _max_mail)
    assert U.allclose(g.ndata['a'][1],
                      _elementwise_max(edge_feats[0], edge_feats[2], edge_feats[3]))

    # send on multigraph: (0, 1) and (2, 1) together cover all four edges
    _reset()
    g.send(([0, 2], [1, 1]), _msg_plain)
    g.recv(1, _max_mail)
    assert U.allclose(g.ndata['a'][1], edge_feats.max(0)[0])

    # consecutive send and send_on
    _reset()
    g.send((2, 1), _msg_plain)
    g.send([0, 1], message_func=_msg_triple)
    g.recv(1, _max_mail)
    assert U.allclose(g.ndata['a'][1],
                      _elementwise_max(edge_feats[0] * 3, edge_feats[1] * 3, edge_feats[3]))

    # consecutive send_on
    _reset()
    g.send(0, message_func=_msg_plain)
    g.send(1, message_func=_msg_triple)
    g.recv(1, _max_mail)
    assert U.allclose(g.ndata['a'][1], _elementwise_max(edge_feats[0], edge_feats[1] * 3))

    # send_and_recv_on: only node 1 receives, nodes 0 and 2 stay zero
    _reset()
    g.send_and_recv([0, 2, 3], message_func=_msg_plain, reduce_func=_max_mail)
    result = g.ndata['a']
    assert U.allclose(result[1],
                      _elementwise_max(edge_feats[0], edge_feats[2], edge_feats[3]))
    assert U.allclose(result[[0, 2]], th.zeros(2, 5))
384

385
386
387
388
389
390
391
392
def test_dynamic_addition():
    """Feature columns keep matching the graph size as nodes/edges are added."""
    num_nodes = 3
    dim = 1

    def _row_counts(store):
        # Row counts of the two feature columns, in a fixed order.
        return store['h1'].shape[0], store['h2'].shape[0]

    g = DGLGraph()

    # Test node addition: both columns should grow with the new nodes.
    g.add_nodes(num_nodes)
    g.ndata.update({'h1': th.randn(num_nodes, dim),
                    'h2': th.randn(num_nodes, dim)})
    g.add_nodes(3)
    assert _row_counts(g.ndata) == (num_nodes + 3, num_nodes + 3)

    # Test edge addition
    g.add_edge(0, 1)
    g.add_edge(1, 0)
    g.edata.update({'h1': th.randn(2, dim),
                    'h2': th.randn(2, dim)})
    assert _row_counts(g.edata) == (2, 2)

    # h2 is never reassigned below, yet its length is expected to follow the graph.
    g.add_edges([0, 2], [2, 0])
    g.edata['h1'] = th.randn(4, dim)
    assert _row_counts(g.edata) == (4, 4)

    g.add_edge(1, 2)
    g.edges[4].data['h1'] = th.randn(1, dim)
    assert _row_counts(g.edata) == (5, 5)
412

413

414
415
if __name__ == '__main__':
    # Run the enabled tests in file order (the disabled send-twice test is skipped).
    for _test in (
        test_batch_setter_getter,
        test_batch_setter_autograd,
        test_batch_send,
        test_batch_recv,
        test_apply_nodes,
        test_apply_edges,
        test_update_routines,
        test_reduce_0deg,
        test_pull_0deg,
        test_send_multigraph,
        test_dynamic_addition,
    ):
        _test()