# test_basics.py
1
import torch as th
2
3
from torch.autograd import Variable
import numpy as np
4
from dgl.graph import DGLGraph
5
import utils as U
6
7
8
9

# Feature width used for every node/edge feature tensor built by the tests
# in this file (e.g. th.randn(10, D) in generate_graph).
D = 5
# Set of mailbox tensor shapes observed by reduce_func (it adds
# tuple(msgs.shape) on every call). Tests clear it, run a message-passing
# routine, then compare it against the expected set of shapes.
reduce_msg_shapes = set()

10
11
12
13
def check_eq(a, b):
    """Assert that tensors ``a`` and ``b`` match in shape and in every element.

    Raises AssertionError when the shapes differ or when the number of
    element-wise equal positions is not the total element count of ``a``.
    """
    assert a.shape == b.shape
    num_elements = int(np.prod(list(a.shape)))
    num_matching = th.sum(a == b)
    assert num_matching == num_elements

14
15
16
17
def message_func(edges):
    """Forward the source-node feature 'h' of each edge as message 'm'.

    Asserts that the source feature is a rank-2 tensor whose second
    dimension equals the module-level feature width ``D``.
    """
    src_feat = edges.src['h']
    assert len(src_feat.shape) == 2
    assert src_feat.shape[1] == D
    return {'m' : src_feat}
18

19
20
def reduce_func(nodes):
    """Sum the mailbox messages 'm' along dim 1 and return them as 'accum'.

    Side effect: records ``tuple(msgs.shape)`` in the module-level
    ``reduce_msg_shapes`` set so tests can check how messages were batched.

    Asserts that the mailbox tensor is rank 3 with last dimension ``D``.
    NOTE(review): the layout is presumably (nodes, messages per node, D);
    confirm against the DGL mailbox contract.
    """
    msgs = nodes.mailbox['m']
    reduce_msg_shapes.add(tuple(msgs.shape))
    assert len(msgs.shape) == 3
    assert msgs.shape[2] == D
    return {'accum' : th.sum(msgs, 1)}
Minjie Wang's avatar
Minjie Wang committed
25

26
27
def apply_node_func(nodes):
    """Return the new node feature 'h' as the old 'h' plus 'accum'."""
    feats = nodes.data
    updated = feats['h'] + feats['accum']
    return {'h' : updated}
Minjie Wang's avatar
Minjie Wang committed
28

29
def generate_graph(grad=False):
    """Build the 10-node, 17-edge graph shared by the tests in this file.

    Node 0 is the source and node 9 is the sink: for every i in 1..8 the
    graph has an edge 0 -> i and an edge i -> 9, plus one back edge 9 -> 0.
    Node feature 'h' is a random (10, D) tensor and edge feature 'w' is a
    random (17, D) tensor.

    Parameters
    ----------
    grad : bool
        Passed as ``requires_grad`` for both feature tensors.

    Returns
    -------
    DGLGraph
        The populated graph, with zero-filling node and edge initializers set.
    """
    g = DGLGraph()
    g.add_nodes(10) # 10 nodes.
    # create a graph where 0 is the source and 9 is the sink
    # 17 edges
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    # add a back flow from 9 to 0
    g.add_edge(9, 0)
    ncol = Variable(th.randn(10, D), requires_grad=grad)
    ecol = Variable(th.randn(17, D), requires_grad=grad)
    g.ndata['h'] = ncol
    g.edata['w'] = ecol

    # The same zero-filling initializer is used for node and edge frames.
    def _zeros(shape, dtype, ctx):
        return th.zeros(shape, dtype=dtype, device=ctx)

    g.set_n_initializer(_zeros)
    g.set_e_initializer(_zeros)
    return g

def test_batch_setter_getter():
    """Check batched set/get/pop of node and edge features on the test graph.

    Covers whole-frame and partial access for nodes, and for edges the
    many-many, many-one and one-many (u, v) addressing forms.
    """
    def _pfc(x):
        # "Project first column": first feature of every row as a list.
        return list(x.numpy()[:,0])
    g = generate_graph()
    # set all nodes
    g.ndata['h'] = th.zeros((10, D))
    assert U.allclose(g.ndata['h'], th.zeros((10, D)))
    # pop nodes
    old_len = len(g.ndata)
    assert _pfc(g.pop_n_repr('h')) == [0.] * 10
    assert len(g.ndata) == old_len - 1
    g.ndata['h'] = th.zeros((10, D))
    # set partial nodes
    u = th.tensor([1, 3, 5])
    g.nodes[u].data['h'] = th.ones((3, D))
    assert _pfc(g.ndata['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
    # get partial nodes
    u = th.tensor([1, 2, 3])
    assert _pfc(g.nodes[u].data['h']) == [1., 0., 1.]

    # Edge table of the graph built by generate_graph:
    #   s, d, eid
    #   0, 1, 0
    #   1, 9, 1
    #   0, 2, 2
    #   2, 9, 3
    #   0, 3, 4
    #   3, 9, 5
    #   0, 4, 6
    #   4, 9, 7
    #   0, 5, 8
    #   5, 9, 9
    #   0, 6, 10
    #   6, 9, 11
    #   0, 7, 12
    #   7, 9, 13
    #   0, 8, 14
    #   8, 9, 15
    #   9, 0, 16
    # set all edges
    g.edata['l'] = th.zeros((17, D))
    assert _pfc(g.edata['l']) == [0.] * 17
    # pop edges
    old_len = len(g.edata)
    assert _pfc(g.pop_e_repr('l')) == [0.] * 17
    assert len(g.edata) == old_len - 1
    g.edata['l'] = th.zeros((17, D))
    # set partial edges (many-many)
    u = th.tensor([0, 0, 2, 5, 9])
    v = th.tensor([1, 3, 9, 9, 0])
    g.edges[u, v].data['l'] = th.ones((5, D))
    truth = [0.] * 17
    truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
    assert _pfc(g.edata['l']) == truth
    # set partial edges (many-one)
    u = th.tensor([3, 4, 6])
    v = th.tensor([9])
    g.edges[u, v].data['l'] = th.ones((3, D))
    truth[5] = truth[7] = truth[11] = 1.
    assert _pfc(g.edata['l']) == truth
    # set partial edges (one-many)
    u = th.tensor([0])
    v = th.tensor([4, 5, 6])
    g.edges[u, v].data['l'] = th.ones((3, D))
    truth[6] = truth[8] = truth[10] = 1.
    assert _pfc(g.edata['l']) == truth
    # get partial edges (many-many)
    u = th.tensor([0, 6, 0])
    v = th.tensor([6, 9, 7])
    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
    # get partial edges (many-one)
    u = th.tensor([5, 6, 7])
    v = th.tensor([9])
    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
    # get partial edges (one-many)
    u = th.tensor([0])
    v = th.tensor([3, 4, 5])
    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 1.]
126

127
128
def test_batch_setter_autograd():
    """Check that gradients flow correctly through a partial node-feature set.

    Rows 1, 2 and 8 of 'h' are overwritten with ``hh``. After backpropagating
    a constant gradient of 2, the original tensor must receive 0 gradient on
    the overwritten rows and 2 elsewhere, and ``hh`` must receive 2 on all
    of its rows.
    """
    g = generate_graph(grad=True)
    h1 = g.ndata['h']
    # partial set
    v = th.tensor([1, 2, 8])
    hh = Variable(th.zeros((len(v), D)), requires_grad=True)
    g.nodes[v].data['h'] = hh
    h2 = g.ndata['h']
    h2.backward(th.ones((10, D)) * 2)
    check_eq(h1.grad[:,0], th.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
    check_eq(hh.grad[:,0], th.tensor([2., 2., 2.]))

139
140
def test_batch_send():
    """Check that send() batches five edges for each (u, v) addressing form.

    The registered message function asserts that the source features it
    receives have shape (5, D) for the many-many, one-many and many-one
    sends.
    """
    g = generate_graph()
    def _fmsg(edges):
        assert edges.src['h'].shape == (5, D)
        return {'m' : edges.src['h']}
    g.register_message_func(_fmsg)
    # many-many send
    u = th.tensor([0, 0, 0, 0, 0])
    v = th.tensor([1, 2, 3, 4, 5])
    g.send((u, v))
    # one-many send
    u = th.tensor([0])
    v = th.tensor([1, 2, 3, 4, 5])
    g.send((u, v))
    # many-one send
    u = th.tensor([1, 2, 3, 4, 5])
    v = th.tensor([9])
    g.send((u, v))
157

158
def test_batch_recv():
    """Check the mailbox shapes seen by the reducer after send() + recv().

    Messages go to nodes 1, 2, 3 (one each) and to node 9 (three), and the
    reducer is expected to observe exactly the shapes (3, 1, D) and
    (1, 3, D).
    """
    # basic recv test
    g = generate_graph()
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_node_func)
    u = th.tensor([0, 0, 0, 4, 5, 6])
    v = th.tensor([1, 2, 3, 9, 9, 9])
    reduce_msg_shapes.clear()
    g.send((u, v))
    g.recv(th.unique(v))
    assert reduce_msg_shapes == {(1, 3, D), (3, 1, D)}
    reduce_msg_shapes.clear()

172
def test_apply_edges():
    """Check apply_edges() with a registered function and with an inline one.

    First doubles edge feature 'w' on all edges through the registered
    function, then zeroes 'w' on a chosen subset of (u, v) edges through an
    explicitly passed function.
    """
    def _upd(edges):
        return {'w' : edges.data['w'] * 2}
    g = generate_graph()
    g.register_apply_edge_func(_upd)
    old = g.edata['w']
    g.apply_edges()
    assert U.allclose(old * 2, g.edata['w'])
    u = th.tensor([0, 0, 0, 4, 5, 6])
    v = th.tensor([1, 2, 3, 9, 9, 9])
    g.apply_edges(lambda edges : {'w' : edges.data['w'] * 0.}, (u, v))
    eid = g.edge_ids(u, v)
    assert U.allclose(g.edata['w'][eid], th.zeros((6, D)))
185

186
187
def test_update_routines():
    """Check mailbox shapes for send_and_recv, pull, push and update_all.

    Each routine runs with the module-level message/reduce/apply functions
    registered. The set of mailbox shapes recorded by ``reduce_func`` is then
    compared against the expected batching for the test graph.
    """
    g = generate_graph()
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_node_func)

    # send_and_recv
    reduce_msg_shapes.clear()
    u = [0, 0, 0, 4, 5, 6]
    v = [1, 2, 3, 9, 9, 9]
    g.send_and_recv((u, v))
    assert reduce_msg_shapes == {(1, 3, D), (3, 1, D)}
    reduce_msg_shapes.clear()
    # Passing the pair as a list rather than a tuple is expected to be
    # rejected with ValueError.
    try:
        g.send_and_recv([u, v])
    except ValueError:
        pass
    else:
        assert False

    # pull
    v = th.tensor([1, 2, 3, 9])
    reduce_msg_shapes.clear()
    g.pull(v)
    assert reduce_msg_shapes == {(1, 8, D), (3, 1, D)}
    reduce_msg_shapes.clear()

    # push
    v = th.tensor([0, 1, 2, 3])
    g.push(v)
    assert reduce_msg_shapes == {(1, 3, D), (8, 1, D)}
    reduce_msg_shapes.clear()

    # update_all
    g.update_all()
    assert reduce_msg_shapes == {(1, 8, D), (9, 1, D)}
    reduce_msg_shapes.clear()

225
226
def test_reduce_0deg():
    """Check update_all() when some nodes have no incoming edges.

    Nodes 1..4 each have one edge into node 0 and receive nothing
    themselves. After update_all their features must be unchanged, while
    node 0 must equal its old feature plus the four incoming messages,
    i.e. the column-wise sum of all old features.
    """
    g = DGLGraph()
    g.add_nodes(5)
    g.add_edge(1, 0)
    g.add_edge(2, 0)
    g.add_edge(3, 0)
    g.add_edge(4, 0)
    def _message(edges):
        return {'m' : edges.src['h']}
    def _reduce(nodes):
        return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}
    old_repr = th.randn(5, 5)
    g.ndata['h'] = old_repr
    g.update_all(_message, _reduce)
    new_repr = g.ndata['h']

    assert U.allclose(new_repr[1:], old_repr[1:])
    assert U.allclose(new_repr[0], old_repr.sum(0))
243

244
def test_pull_0deg():
    """Check pull() on a two-node graph where node 0 has no incoming edge.

    The graph has the single edge 0 -> 1. Pulling on node 0 must leave both
    features unchanged; pulling on node 1 (alone or together with node 0)
    must replace node 1's feature with node 0's.
    """
    g = DGLGraph()
    g.add_nodes(2)
    g.add_edge(0, 1)
    def _message(edges):
        return {'m' : edges.src['h']}
    def _reduce(nodes):
        return {'h' : nodes.mailbox['m'].sum(1)}
    old_repr = th.randn(2, 5)
    g.ndata['h'] = old_repr

    g.pull(0, _message, _reduce)
    new_repr = g.ndata['h']
    assert U.allclose(new_repr[0], old_repr[0])
    assert U.allclose(new_repr[1], old_repr[1])

    g.pull(1, _message, _reduce)
    new_repr = g.ndata['h']
    assert U.allclose(new_repr[1], old_repr[0])

    old_repr = th.randn(2, 5)
    g.ndata['h'] = old_repr
    g.pull([0, 1], _message, _reduce)
    new_repr = g.ndata['h']
    assert U.allclose(new_repr[0], old_repr[0])
    assert U.allclose(new_repr[1], old_repr[0])
270

271
272
def _disabled_test_send_twice():
    """Check that a second send() on an edge replaces the earlier message.

    Currently disabled: the leading underscore keeps it out of test
    collection and it is not called from the ``__main__`` block.
    """
    # TODO(minjie): please re-enable this unittest after the send code problem is fixed.
    g = DGLGraph()
    g.add_nodes(3)
    g.add_edge(0, 1)
    g.add_edge(2, 1)
    def _message_a(edges):
        return {'a': edges.src['a']}
    def _message_b(edges):
        return {'a': edges.src['a'] * 3}
    def _reduce(nodes):
        return {'a': nodes.mailbox['a'].max(1)[0]}

    old_repr = th.randn(3, 5)
    # Two sends on the same edge 0 -> 1: only the second message (x3) is
    # expected to reach node 1.
    g.ndata['a'] = old_repr
    g.send((0, 1), _message_a)
    g.send((0, 1), _message_b)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], old_repr[0] * 3)

    # Sends on two different edges into node 1: both messages are reduced.
    g.ndata['a'] = old_repr
    g.send((0, 1), _message_a)
    g.send((2, 1), _message_b)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], th.stack([old_repr[0], old_repr[2] * 3], 0).max(0)[0])
298
299
300
301
302
303
304
305
306

def test_send_multigraph():
    """Check send/recv addressing by edge id and by (u, v) on a multigraph.

    The graph has three parallel edges 0 -> 1 and one edge 2 -> 1. Messages
    carry the edge feature 'a' (optionally scaled by 3) and node 1 reduces
    them with an element-wise max.
    NOTE(review): the expected values assume edge ids follow insertion order
    (0, 1, 2 for the 0 -> 1 edges and 3 for 2 -> 1); confirm against DGLGraph.
    """
    g = DGLGraph(multigraph=True)
    g.add_nodes(3)
    g.add_edge(0, 1)
    g.add_edge(0, 1)
    g.add_edge(0, 1)
    g.add_edge(2, 1)

    def _message_a(edges):
        return {'a': edges.data['a']}
    def _message_b(edges):
        return {'a': edges.data['a'] * 3}
    def _reduce(nodes):
        return {'a': nodes.mailbox['a'].max(1)[0]}

    def answer(*args):
        # Element-wise max over the given rows: the expected reduce result.
        return th.stack(args, 0).max(0)[0]

    # send by eid
    old_repr = th.randn(4, 5)
    g.ndata['a'] = th.zeros(3, 5)
    g.edata['a'] = old_repr
    g.send([0, 2], message_func=_message_a)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], answer(old_repr[0], old_repr[2]))

    g.ndata['a'] = th.zeros(3, 5)
    g.edata['a'] = old_repr
    g.send([0, 2, 3], message_func=_message_a)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3]))

    # send on multigraph
    g.ndata['a'] = th.zeros(3, 5)
    g.edata['a'] = old_repr
    g.send(([0, 2], [1, 1]), _message_a)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], old_repr.max(0)[0])

    # consecutive send and send_on
    g.ndata['a'] = th.zeros(3, 5)
    g.edata['a'] = old_repr
    g.send((2, 1), _message_a)
    g.send([0, 1], message_func=_message_b)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], answer(old_repr[0] * 3, old_repr[1] * 3, old_repr[3]))

    # consecutive send_on
    g.ndata['a'] = th.zeros(3, 5)
    g.edata['a'] = old_repr
    g.send(0, message_func=_message_a)
    g.send(1, message_func=_message_b)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], answer(old_repr[0], old_repr[1] * 3))

    # send_and_recv_on
    g.ndata['a'] = th.zeros(3, 5)
    g.edata['a'] = old_repr
    g.send_and_recv([0, 2, 3], message_func=_message_a, reduce_func=_reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3]))
    assert U.allclose(new_repr[[0, 2]], th.zeros(2, 5))
366

367
368
369
370
371
372
373
374
def test_dynamic_addition():
    """Check that feature frames grow when nodes or edges are added later.

    After features are set, adding more nodes or edges must leave every
    feature column with one row per node or edge. This must hold even when
    only one column ('h1') is explicitly assigned after the addition.
    """
    # Local sizes; D here intentionally shadows the module-level D.
    N = 3
    D = 1

    g = DGLGraph()

    # Test node addition
    g.add_nodes(N)
    g.ndata.update({'h1': th.randn(N, D),
                    'h2': th.randn(N, D)})
    g.add_nodes(3)
    assert g.ndata['h1'].shape[0] == g.ndata['h2'].shape[0] == N + 3

    # Test edge addition
    g.add_edge(0, 1)
    g.add_edge(1, 0)
    g.edata.update({'h1': th.randn(2, D),
                    'h2': th.randn(2, D)})
    assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 2

    g.add_edges([0, 2], [2, 0])
    g.edata['h1'] = th.randn(4, D)
    assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 4

    g.add_edge(1, 2)
    g.edges[4].data['h1'] = th.randn(1, D)
    assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 5
394

395

396
397
if __name__ == '__main__':
    # Run the enabled tests in file order when executed as a script.
    # _disabled_test_send_twice is not run here (see the TODO in its body).
    test_batch_setter_getter()
    test_batch_setter_autograd()
    test_batch_send()
    test_batch_recv()
    test_apply_edges()
    test_update_routines()
    test_reduce_0deg()
    test_pull_0deg()
    test_send_multigraph()
    test_dynamic_addition()