test_basics.py 11.4 KB
Newer Older
1
import torch as th
2
3
from torch.autograd import Variable
import numpy as np
4
5
6
7
8
from dgl.graph import DGLGraph

# Feature width used for the 'h'/'w' features built in generate_graph and
# checked by message_func / reduce_func.
D = 5

# Mailbox shapes recorded by reduce_func. Tests clear it, run a message-passing
# routine, then compare it against the set of shapes they expect.
reduce_msg_shapes = set()

9
10
11
12
def check_eq(a, b):
    """Assert that tensors ``a`` and ``b`` share a shape and are element-wise equal."""
    assert a.shape == b.shape
    num_matching = (a == b).sum()
    assert num_matching == a.numel()

13
14
15
16
def message_func(edges):
    """Send each edge's source-node feature 'h' as message 'm'.

    Also checks that the source features arrive as a 2-D batch whose second
    dimension equals the module-level feature width ``D``.
    """
    src_h = edges.src['h']
    assert len(src_h.shape) == 2
    assert src_h.shape[1] == D
    return {'m': src_h}
17

18
19
def reduce_func(nodes):
    """Sum each node's mailbox 'm' over dimension 1 into 'accum'.

    The mailbox shape is recorded in ``reduce_msg_shapes`` *before* it is
    validated, so tests can inspect how destination nodes were batched.
    The mailbox must be 3-D with last dimension ``D``.
    """
    mailbox = nodes.mailbox['m']
    shape = tuple(mailbox.shape)
    reduce_msg_shapes.add(shape)
    assert len(shape) == 3
    assert shape[2] == D
    return {'accum': th.sum(mailbox, 1)}
Minjie Wang's avatar
Minjie Wang committed
24

25
26
def apply_node_func(nodes):
    """Return the node update 'h' := 'h' + 'accum'."""
    data = nodes.data
    updated = data['h'] + data['accum']
    return {'h': updated}
Minjie Wang's avatar
Minjie Wang committed
27

28
def generate_graph(grad=False):
    """Build the shared 10-node test graph with random 'h' / 'w' features.

    Node 0 fans out to nodes 1..8, each of which feeds node 9 (the sink),
    and a back edge 9 -> 0 closes the loop (17 edges in total). Edges are
    added in the order 0->1, 1->9, 0->2, 2->9, ..., 0->8, 8->9, 9->0, so
    edge ids are deterministic; tests index edge data by these ids.

    Args:
        grad: whether the node feature 'h' and the edge feature 'w' are
            created with ``requires_grad``.

    Returns:
        A DGLGraph with ndata['h'] of shape (10, D) and edata['w'] of
        shape (num_edges, D). It also has node/edge initializers that fill
        missing fields with zeros; the ``dtype`` argument is ignored, so
        those zeros use torch's default dtype.
    """
    g = DGLGraph()
    g.add_nodes(10)  # 10 nodes.
    # create a graph where 0 is the source and 9 is the sink
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    # add a back flow from 9 to 0
    g.add_edge(9, 0)
    # torch.autograd.Variable is deprecated: plain tensors carry requires_grad.
    ncol = th.randn(10, D, requires_grad=grad)
    # Size edge features from the graph (17 edges) rather than a magic number
    # that must be kept in sync with the loop above.
    ecol = th.randn(g.number_of_edges(), D, requires_grad=grad)
    g.ndata['h'] = ncol
    g.edata['w'] = ecol
    g.set_n_initializer(lambda shape, dtype : th.zeros(shape))
    g.set_e_initializer(lambda shape, dtype : th.zeros(shape))
    return g

def test_batch_setter_getter():
    """Whole-graph and partial get / set / pop of node and edge features.

    Edge ids produced by generate_graph:

        eid : (src, dst)
         0  : (0, 1)     1 : (1, 9)
         2  : (0, 2)     3 : (2, 9)
         4  : (0, 3)     5 : (3, 9)
         6  : (0, 4)     7 : (4, 9)
         8  : (0, 5)     9 : (5, 9)
        10  : (0, 6)    11 : (6, 9)
        12  : (0, 7)    13 : (7, 9)
        14  : (0, 8)    15 : (8, 9)
        16  : (9, 0)
    """
    def first_column(t):
        # Column 0 of a 2-D tensor as a plain list, for compact comparisons.
        return list(t.numpy()[:, 0])

    g = generate_graph()

    # ---------------- node features ----------------
    # set every node, then read it back
    g.ndata['h'] = th.zeros((10, D))
    assert th.allclose(g.ndata['h'], th.zeros((10, D)))

    # popping returns the field and removes exactly one entry
    n_fields_before = len(g.ndata)
    popped = g.pop_n_repr('h')
    assert first_column(popped) == [0.] * 10
    assert len(g.ndata) == n_fields_before - 1
    g.ndata['h'] = th.zeros((10, D))

    # write a subset of nodes
    g.nodes[th.tensor([1, 3, 5])].data['h'] = th.ones((3, D))
    assert first_column(g.ndata['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]

    # read a subset of nodes
    assert first_column(g.nodes[th.tensor([1, 2, 3])].data['h']) == [1., 0., 1.]

    # ---------------- edge features ----------------
    # set every edge
    g.edata['l'] = th.zeros((17, D))
    assert first_column(g.edata['l']) == [0.] * 17

    # popping returns the field and removes exactly one entry
    n_fields_before = len(g.edata)
    popped = g.pop_e_repr('l')
    assert first_column(popped) == [0.] * 17
    assert len(g.edata) == n_fields_before - 1
    g.edata['l'] = th.zeros((17, D))

    expected = [0.] * 17

    # write edges: many-many (paired src/dst)
    src, dst = th.tensor([0, 0, 2, 5, 9]), th.tensor([1, 3, 9, 9, 0])
    g.edges[src, dst].data['l'] = th.ones((5, D))
    for eid in (0, 4, 3, 9, 16):
        expected[eid] = 1.
    assert first_column(g.edata['l']) == expected

    # write edges: many-one
    src, dst = th.tensor([3, 4, 6]), th.tensor([9])
    g.edges[src, dst].data['l'] = th.ones((3, D))
    for eid in (5, 7, 11):
        expected[eid] = 1.
    assert first_column(g.edata['l']) == expected

    # write edges: one-many
    src, dst = th.tensor([0]), th.tensor([4, 5, 6])
    g.edges[src, dst].data['l'] = th.ones((3, D))
    for eid in (6, 8, 10):
        expected[eid] = 1.
    assert first_column(g.edata['l']) == expected

    # read edges: many-many
    src, dst = th.tensor([0, 6, 0]), th.tensor([6, 9, 7])
    assert first_column(g.edges[src, dst].data['l']) == [1., 1., 0.]

    # read edges: many-one
    src, dst = th.tensor([5, 6, 7]), th.tensor([9])
    assert first_column(g.edges[src, dst].data['l']) == [1., 1., 0.]

    # read edges: one-many
    src, dst = th.tensor([0]), th.tensor([3, 4, 5])
    assert first_column(g.edges[src, dst].data['l']) == [1., 1., 1.]
125

126
127
def test_batch_setter_autograd():
    """Gradients reach both the original node features and a partial overwrite.

    Rows 1, 2 and 8 are overwritten with ``hh``. Backpropagating an upstream
    gradient of 2 must therefore reach every row of ``hh``, and reach the
    original features ``h1`` only in the rows that were not overwritten.
    """
    g = generate_graph(grad=True)
    h1 = g.ndata['h']
    # partial set
    v = th.tensor([1, 2, 8])
    # torch.autograd.Variable is deprecated: plain tensors carry requires_grad.
    hh = th.zeros((len(v), D), requires_grad=True)
    g.nodes[v].data['h'] = hh
    h2 = g.ndata['h']
    h2.backward(th.ones((10, D)) * 2)
    check_eq(h1.grad[:, 0], th.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
    check_eq(hh.grad[:, 0], th.tensor([2., 2., 2.]))

138
139
def test_batch_send():
    """Batched ``send`` over many-many, one-many and many-one edge sets.

    Each case addresses exactly five edges, which the message function checks.
    """
    def five_edge_message(edges):
        src_h = edges.src['h']
        assert src_h.shape == (5, D)
        return {'m': src_h}

    g = generate_graph()
    g.register_message_func(five_edge_message)

    cases = (
        ([0, 0, 0, 0, 0], [1, 2, 3, 4, 5]),  # many-many
        ([0], [1, 2, 3, 4, 5]),              # one-many
        ([1, 2, 3, 4, 5], [9]),              # many-one
    )
    for src, dst in cases:
        g.send((th.tensor(src), th.tensor(dst)))
156

157
def test_batch_recv():
    """``send`` followed by ``recv`` on the unique destination nodes.

    Nodes 1, 2 and 3 each get one message, giving mailbox shape (3, 1, D).
    Node 9 gets three messages, giving (1, 3, D).
    """
    g = generate_graph()
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_node_func)

    src = th.tensor([0, 0, 0, 4, 5, 6])
    dst = th.tensor([1, 2, 3, 9, 9, 9])
    reduce_msg_shapes.clear()
    g.send((src, dst))
    g.recv(th.unique(dst))
    expected_shapes = {(1, 3, D), (3, 1, D)}
    assert reduce_msg_shapes == expected_shapes
    reduce_msg_shapes.clear()

171
def test_apply_edges():
    """``apply_edges`` over all edges (registered func), then over a subset."""
    def _double_w(edges):
        return {'w': edges.data['w'] * 2}

    def _zero_w(edges):
        return {'w': edges.data['w'] * 0.}

    g = generate_graph()
    g.register_apply_edge_func(_double_w)
    before = g.edata['w']
    g.apply_edges()
    assert th.allclose(before * 2, g.edata['w'])

    src = th.tensor([0, 0, 0, 4, 5, 6])
    dst = th.tensor([1, 2, 3, 9, 9, 9])
    g.apply_edges(_zero_w, (src, dst))
    eids = g.edge_ids(src, dst)
    assert th.allclose(g.edata['w'][eids], th.zeros((6, D)))

185
186
def test_update_routines():
    """Mailbox shapes seen by reduce_func for each high-level update routine."""
    g = generate_graph()
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_node_func)

    def _check_shapes(run, expected):
        # Empty the shape log, run one routine, compare, then leave it empty.
        reduce_msg_shapes.clear()
        run()
        assert reduce_msg_shapes == expected
        reduce_msg_shapes.clear()

    # send_and_recv
    src = th.tensor([0, 0, 0, 4, 5, 6])
    dst = th.tensor([1, 2, 3, 9, 9, 9])
    _check_shapes(lambda: g.send_and_recv((src, dst)), {(1, 3, D), (3, 1, D)})

    # pull
    pull_nodes = th.tensor([1, 2, 3, 9])
    _check_shapes(lambda: g.pull(pull_nodes), {(1, 8, D), (3, 1, D)})

    # push
    push_nodes = th.tensor([0, 1, 2, 3])
    _check_shapes(lambda: g.push(push_nodes), {(1, 3, D), (8, 1, D)})

    # update_all
    _check_shapes(g.update_all, {(1, 8, D), (9, 1, D)})

219
220
def test_reduce_0deg():
    """``update_all`` leaves zero in-degree nodes untouched.

    Nodes 1-4 each send their feature to node 0. Node 0 should end up with
    the sum of all five rows (its own feature plus four messages), and
    nodes 1-4 should keep their original features.
    """
    g = DGLGraph()
    g.add_nodes(5)
    for src in range(1, 5):
        g.add_edge(src, 0)

    def _message(edges):
        return {'m': edges.src['h']}

    def _reduce(nodes):
        own = nodes.data['h']
        return {'h': own + nodes.mailbox['m'].sum(1)}

    before = th.randn(5, 5)
    g.ndata['h'] = before
    g.update_all(_message, _reduce)
    after = g.ndata['h']

    assert th.allclose(after[1:], before[1:])
    assert th.allclose(after[0], before.sum(0), rtol=1e-3, atol=1e-3)
237

238
def test_pull_0deg():
    """``pull`` on a node with no in-edges leaves its feature unchanged."""
    g = DGLGraph()
    g.add_nodes(2)
    g.add_edge(0, 1)

    def _message(edges):
        return {'m': edges.src['h']}

    def _reduce(nodes):
        return {'h': nodes.mailbox['m'].sum(1)}

    # Node 0 has no in-edges: pulling it must be a no-op for both nodes.
    feats = th.randn(2, 5)
    g.ndata['h'] = feats
    g.pull(0, _message, _reduce)
    out = g.ndata['h']
    assert th.allclose(out[0], feats[0])
    assert th.allclose(out[1], feats[1])

    # Node 1's only in-edge is 0 -> 1, so it takes node 0's feature.
    g.pull(1, _message, _reduce)
    out = g.ndata['h']
    assert th.allclose(out[1], feats[0])

    # Pull both nodes at once on fresh features.
    feats = th.randn(2, 5)
    g.ndata['h'] = feats
    g.pull([0, 1], _message, _reduce)
    out = g.ndata['h']
    assert th.allclose(out[0], feats[0])
    assert th.allclose(out[1], feats[0])

265
266
def _disabled_test_send_twice():
    """Two ``send`` calls before a single ``recv`` (currently disabled)."""
    # TODO(minjie): please re-enable this unittest after the send code problem is fixed.
    g = DGLGraph()
    g.add_nodes(3)
    g.add_edge(0, 1)
    g.add_edge(2, 1)

    def _message_a(edges):
        return {'a': edges.src['a']}

    def _message_b(edges):
        return {'a': edges.src['a'] * 3}

    def _reduce(nodes):
        return {'a': nodes.mailbox['a'].max(1)[0]}

    feats = th.randn(3, 5)

    # Same edge twice: node 1 is expected to see only the second message.
    g.ndata['a'] = feats
    g.send((0, 1), _message_a)
    g.send((0, 1), _message_b)
    g.recv(1, _reduce)
    out = g.ndata['a']
    assert th.allclose(out[1], feats[0] * 3)

    # Two different edges into node 1: both messages are max-reduced.
    g.ndata['a'] = feats
    g.send((0, 1), _message_a)
    g.send((2, 1), _message_b)
    g.recv(1, _reduce)
    out = g.ndata['a']
    expected = th.stack([feats[0], feats[2] * 3], 0).max(0)[0]
    assert th.allclose(out[1], expected)

def test_send_multigraph():
    """send / recv on a multigraph, addressed by edge ids and by (src, dst).

    Edge ids 0, 1 and 2 are parallel 0 -> 1 edges and edge id 3 is 2 -> 1,
    so node 1 is the only receiver.
    """
    g = DGLGraph(multigraph=True)
    g.add_nodes(3)
    for _ in range(3):
        g.add_edge(0, 1)
    g.add_edge(2, 1)

    def _message_a(edges):
        return {'a': edges.data['a']}

    def _message_b(edges):
        return {'a': edges.data['a'] * 3}

    def _reduce(nodes):
        return {'a': nodes.mailbox['a'].max(1)[0]}

    def _elementwise_max(*tensors):
        return th.stack(tensors, 0).max(0)[0]

    def _reset(edge_feats):
        # Zero the node features and restore the shared edge features.
        g.ndata['a'] = th.zeros(3, 5)
        g.edata['a'] = edge_feats

    edge_feats = th.randn(4, 5)

    # send by eid
    _reset(edge_feats)
    g.send([0, 2], message_func=_message_a)
    g.recv(1, _reduce)
    out = g.ndata['a']
    assert th.allclose(out[1], _elementwise_max(edge_feats[0], edge_feats[2]))

    _reset(edge_feats)
    g.send([0, 2, 3], message_func=_message_a)
    g.recv(1, _reduce)
    out = g.ndata['a']
    assert th.allclose(out[1], _elementwise_max(edge_feats[0], edge_feats[2], edge_feats[3]))

    # send by (src, dst) on a multigraph: covers every parallel edge
    _reset(edge_feats)
    g.send(([0, 2], [1, 1]), _message_a)
    g.recv(1, _reduce)
    out = g.ndata['a']
    assert th.allclose(out[1], edge_feats.max(0)[0])

    # consecutive send by (src, dst) and send by eid
    _reset(edge_feats)
    g.send((2, 1), _message_a)
    g.send([0, 1], message_func=_message_b)
    g.recv(1, _reduce)
    out = g.ndata['a']
    assert th.allclose(out[1], _elementwise_max(edge_feats[0] * 3, edge_feats[1] * 3, edge_feats[3]))

    # consecutive send by eid
    _reset(edge_feats)
    g.send(0, message_func=_message_a)
    g.send(1, message_func=_message_b)
    g.recv(1, _reduce)
    out = g.ndata['a']
    assert th.allclose(out[1], _elementwise_max(edge_feats[0], edge_feats[1] * 3))

    # send_and_recv by eid; non-receiving nodes stay zero
    _reset(edge_feats)
    g.send_and_recv([0, 2, 3], message_func=_message_a, reduce_func=_reduce)
    out = g.ndata['a']
    assert th.allclose(out[1], _elementwise_max(edge_feats[0], edge_feats[2], edge_feats[3]))
    assert th.allclose(out[[0, 2]], th.zeros(2, 5))

361
362
363
364
365
366
367
368
def test_dynamic_addition():
    """Existing feature columns grow when nodes and edges are added."""
    num_nodes = 3
    dim = 1

    g = DGLGraph()

    # Node addition: both node fields must stretch to the new node count.
    g.add_nodes(num_nodes)
    g.ndata.update({'h1': th.randn(num_nodes, dim),
                    'h2': th.randn(num_nodes, dim)})
    g.add_nodes(3)
    assert g.ndata['h1'].shape[0] == g.ndata['h2'].shape[0] == num_nodes + 3

    # Edge addition: 'h2' is never reassigned below, so its length only
    # keeps up if the graph extends it as edges are added.
    g.add_edge(0, 1)
    g.add_edge(1, 0)
    g.edata.update({'h1': th.randn(2, dim),
                    'h2': th.randn(2, dim)})
    assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 2

    g.add_edges([0, 2], [2, 0])
    g.edata['h1'] = th.randn(4, dim)
    assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 4

    g.add_edge(1, 2)
    g.edges[4].data['h1'] = th.randn(1, dim)
    assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 5
388

389

390
391
if __name__ == '__main__':
    # Run every enabled test in file order; _disabled_test_send_twice is skipped.
    for _test_fn in (
        test_batch_setter_getter,
        test_batch_setter_autograd,
        test_batch_send,
        test_batch_recv,
        test_apply_edges,
        test_update_routines,
        test_reduce_0deg,
        test_pull_0deg,
        test_send_multigraph,
        test_dynamic_addition,
    ):
        _test_fn()