# test_basics.py -- basic DGLGraph API tests (PyTorch backend): batched
# feature setters/getters, send/recv, apply_edges, update routines,
# zero in-degree handling, multigraph sends and dynamic node/edge addition.
1
import torch as th
2
3
from torch.autograd import Variable
import numpy as np
4
from dgl.graph import DGLGraph
5
import utils as U
6
7
8
9

# Feature width of the random node ('h') and edge ('w') features created by
# generate_graph(); message/reduce helpers assert against it.
D = 5
# Mailbox shapes seen by reduce_func. Tests clear this set, run a
# message-passing routine, then compare it with the expected shapes.
reduce_msg_shapes = set()


def check_eq(a, b):
    """Assert that tensors ``a`` and ``b`` have the same shape and equal elements.

    The comparison is exact element-wise equality (no tolerance).

    Raises:
        AssertionError: if the shapes differ or any element differs.
    """
    assert a.shape == b.shape, 'shape mismatch: %s vs %s' % (a.shape, b.shape)
    # Every element must match; equivalent to counting matches against numel.
    assert bool((a == b).all()), 'tensors differ:\n%s\n%s' % (a, b)


def message_func(edges):
    """Send each source node's 'h' feature along its edge as message 'm'.

    Also checks that the batched source features are 2-D with width ``D``.
    """
    src_feat = edges.src['h']
    assert len(src_feat.shape) == 2
    assert src_feat.shape[1] == D
    return {'m' : src_feat}


def reduce_func(nodes):
    """Sum the incoming 'm' messages of each node into feature 'accum'.

    The mailbox shape is recorded in the module-level ``reduce_msg_shapes``
    before validation, so tests can check how the nodes were batched.
    The mailbox must be 3-D with last dimension ``D``.
    """
    msgs = nodes.mailbox['m']
    reduce_msg_shapes.add(tuple(msgs.shape))
    assert len(msgs.shape) == 3
    assert msgs.shape[2] == D
    # Dim 1 is the per-node message axis (e.g. a node with 3 in-messages
    # yields a (1, 3, D) mailbox in the tests below).
    return {'accum' : th.sum(msgs, 1)}


def apply_node_func(nodes):
    """Add the reduced 'accum' feature into the node feature 'h'."""
    return {'h' : nodes.data['h'] + nodes.data['accum']}


def generate_graph(grad=False):
    """Build the 10-node, 17-edge graph shared by most tests.

    Node 0 fans out to nodes 1..8, each of which feeds node 9, and a back
    edge 9 -> 0 closes the loop. Edges are added in the order
    (0,1), (1,9), (0,2), (2,9), ..., (0,8), (8,9), (9,0), and edge ids
    follow that order.

    Args:
        grad: if True, the random features ndata['h'] and edata['w']
            require gradients.

    Returns:
        A DGLGraph with ndata['h'] of shape (10, D), edata['w'] of shape
        (17, D), and zero-valued node/edge feature initializers.
    """
    g = DGLGraph()
    g.add_nodes(10)  # 10 nodes.
    # 0 is the source and 9 is the sink: 8 * 2 = 16 edges ...
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    # ... plus a back flow from 9 to 0, for 17 edges in total.
    g.add_edge(9, 0)
    ncol = Variable(th.randn(10, D), requires_grad=grad)
    ecol = Variable(th.randn(17, D), requires_grad=grad)
    g.ndata['h'] = ncol
    g.edata['w'] = ecol
    g.set_n_initializer(lambda shape, dtype : th.zeros(shape))
    g.set_e_initializer(lambda shape, dtype : th.zeros(shape))
    return g


def test_batch_setter_getter():
    """Whole/partial set, get and pop of node and edge features."""
    def _pfc(x):
        # First feature column of x as a plain Python list.
        return list(x.numpy()[:, 0])

    g = generate_graph()

    # set all nodes
    g.ndata['h'] = th.zeros((10, D))
    assert U.allclose(g.ndata['h'], th.zeros((10, D)))
    # pop nodes
    old_len = len(g.ndata)
    assert _pfc(g.pop_n_repr('h')) == [0.] * 10
    assert len(g.ndata) == old_len - 1
    g.ndata['h'] = th.zeros((10, D))
    # set partial nodes
    u = th.tensor([1, 3, 5])
    g.nodes[u].data['h'] = th.ones((3, D))
    assert _pfc(g.ndata['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
    # get partial nodes
    u = th.tensor([1, 2, 3])
    assert _pfc(g.nodes[u].data['h']) == [1., 0., 1.]

    # Edge ids of the graph built by generate_graph():
    #   s, d, eid
    #   0, 1, 0
    #   1, 9, 1
    #   0, 2, 2
    #   2, 9, 3
    #   0, 3, 4
    #   3, 9, 5
    #   0, 4, 6
    #   4, 9, 7
    #   0, 5, 8
    #   5, 9, 9
    #   0, 6, 10
    #   6, 9, 11
    #   0, 7, 12
    #   7, 9, 13
    #   0, 8, 14
    #   8, 9, 15
    #   9, 0, 16

    # set all edges
    g.edata['l'] = th.zeros((17, D))
    assert _pfc(g.edata['l']) == [0.] * 17
    # pop edges
    old_len = len(g.edata)
    assert _pfc(g.pop_e_repr('l')) == [0.] * 17
    assert len(g.edata) == old_len - 1
    g.edata['l'] = th.zeros((17, D))

    # set partial edges (many-many)
    u = th.tensor([0, 0, 2, 5, 9])
    v = th.tensor([1, 3, 9, 9, 0])
    g.edges[u, v].data['l'] = th.ones((5, D))
    truth = [0.] * 17
    truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
    assert _pfc(g.edata['l']) == truth

    # set partial edges (many-one)
    u = th.tensor([3, 4, 6])
    v = th.tensor([9])
    g.edges[u, v].data['l'] = th.ones((3, D))
    truth[5] = truth[7] = truth[11] = 1.
    assert _pfc(g.edata['l']) == truth

    # set partial edges (one-many)
    u = th.tensor([0])
    v = th.tensor([4, 5, 6])
    g.edges[u, v].data['l'] = th.ones((3, D))
    truth[6] = truth[8] = truth[10] = 1.
    assert _pfc(g.edata['l']) == truth

    # get partial edges (many-many)
    u = th.tensor([0, 6, 0])
    v = th.tensor([6, 9, 7])
    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]

    # get partial edges (many-one)
    u = th.tensor([5, 6, 7])
    v = th.tensor([9])
    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]

    # get partial edges (one-many)
    u = th.tensor([0])
    v = th.tensor([3, 4, 5])
    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 1.]


def test_batch_setter_autograd():
    """Gradients flow to both the original features and a partially set slice."""
    g = generate_graph(grad=True)
    h1 = g.ndata['h']

    # partial set
    v = th.tensor([1, 2, 8])
    hh = Variable(th.zeros((len(v), D)), requires_grad=True)
    g.nodes[v].data['h'] = hh
    h2 = g.ndata['h']
    h2.backward(th.ones((10, D)) * 2)
    # Rows 1, 2 and 8 were overwritten by hh, so h1 gets no gradient there.
    check_eq(h1.grad[:, 0], th.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
    check_eq(hh.grad[:, 0], th.tensor([2., 2., 2.]))


def test_batch_send():
    """send() accepts many-many, one-many and many-one (src, dst) batches."""
    g = generate_graph()

    def _fmsg(edges):
        # Every batch sent below covers exactly five edges.
        assert edges.src['h'].shape == (5, D)
        return {'m' : edges.src['h']}

    g.register_message_func(_fmsg)
    # many-many send
    u = th.tensor([0, 0, 0, 0, 0])
    v = th.tensor([1, 2, 3, 4, 5])
    g.send((u, v))
    # one-many send
    u = th.tensor([0])
    v = th.tensor([1, 2, 3, 4, 5])
    g.send((u, v))
    # many-one send
    u = th.tensor([1, 2, 3, 4, 5])
    v = th.tensor([9])
    g.send((u, v))


def test_batch_recv():
    """recv() reduces pending messages, grouping nodes by message count."""
    # basic recv test
    g = generate_graph()
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_node_func)

    u = th.tensor([0, 0, 0, 4, 5, 6])
    v = th.tensor([1, 2, 3, 9, 9, 9])
    reduce_msg_shapes.clear()
    g.send((u, v))
    g.recv(th.unique(v))
    # Nodes 1, 2, 3 receive one message each; node 9 receives three.
    assert reduce_msg_shapes == {(1, 3, D), (3, 1, D)}
    reduce_msg_shapes.clear()


def test_apply_edges():
    """apply_edges() updates every edge, or only the given (u, v) edges."""
    def _upd(edges):
        return {'w' : edges.data['w'] * 2}

    g = generate_graph()
    g.register_apply_edge_func(_upd)
    old = g.edata['w']
    g.apply_edges()
    assert U.allclose(old * 2, g.edata['w'])

    # Apply a one-off function to a subset of edges only.
    u = th.tensor([0, 0, 0, 4, 5, 6])
    v = th.tensor([1, 2, 3, 9, 9, 9])
    g.apply_edges(lambda edges : {'w' : edges.data['w'] * 0.}, (u, v))
    eid = g.edge_ids(u, v)
    assert U.allclose(g.edata['w'][eid], th.zeros((6, D)))


def test_update_routines():
    """send_and_recv / pull / push / update_all reduce with the expected batches."""
    g = generate_graph()
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_node_func)

    # send_and_recv: nodes 1-3 get one message each, node 9 gets three.
    reduce_msg_shapes.clear()
    u = th.tensor([0, 0, 0, 4, 5, 6])
    v = th.tensor([1, 2, 3, 9, 9, 9])
    g.send_and_recv((u, v))
    assert reduce_msg_shapes == {(1, 3, D), (3, 1, D)}

    # pull: nodes 1-3 have in-degree 1, node 9 has in-degree 8.
    reduce_msg_shapes.clear()
    v = th.tensor([1, 2, 3, 9])
    g.pull(v)
    assert reduce_msg_shapes == {(1, 8, D), (3, 1, D)}

    # push: node 0 reaches nodes 1-8 (one message each); 1-3 reach node 9.
    reduce_msg_shapes.clear()
    v = th.tensor([0, 1, 2, 3])
    g.push(v)
    assert reduce_msg_shapes == {(1, 3, D), (8, 1, D)}

    # update_all: nodes 0-8 have in-degree 1, node 9 has in-degree 8.
    reduce_msg_shapes.clear()
    g.update_all()
    assert reduce_msg_shapes == {(1, 8, D), (9, 1, D)}
    reduce_msg_shapes.clear()


def test_reduce_0deg():
    """update_all() leaves nodes without in-edges untouched."""
    g = DGLGraph()
    g.add_nodes(5)
    # Nodes 1-4 all point to node 0; none of them has an in-edge.
    g.add_edge(1, 0)
    g.add_edge(2, 0)
    g.add_edge(3, 0)
    g.add_edge(4, 0)

    def _message(edges):
        return {'m' : edges.src['h']}
    def _reduce(nodes):
        return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}

    old_repr = th.randn(5, 5)
    g.ndata['h'] = old_repr
    g.update_all(_message, _reduce)
    new_repr = g.ndata['h']

    # Nodes 1-4 keep their features; node 0 adds its own to its 4 neighbours'.
    assert U.allclose(new_repr[1:], old_repr[1:])
    assert U.allclose(new_repr[0], old_repr.sum(0))


def test_pull_0deg():
    """pull() on a node without in-edges leaves its feature unchanged."""
    g = DGLGraph()
    g.add_nodes(2)
    g.add_edge(0, 1)

    def _message(edges):
        return {'m' : edges.src['h']}
    def _reduce(nodes):
        return {'h' : nodes.mailbox['m'].sum(1)}

    old_repr = th.randn(2, 5)
    g.ndata['h'] = old_repr

    # Node 0 has no in-edges: nothing changes.
    g.pull(0, _message, _reduce)
    new_repr = g.ndata['h']
    assert U.allclose(new_repr[0], old_repr[0])
    assert U.allclose(new_repr[1], old_repr[1])

    # Node 1 receives node 0's feature.
    g.pull(1, _message, _reduce)
    new_repr = g.ndata['h']
    assert U.allclose(new_repr[1], old_repr[0])

    # Pulling both at once: node 0 unchanged, node 1 gets node 0's feature.
    old_repr = th.randn(2, 5)
    g.ndata['h'] = old_repr
    g.pull([0, 1], _message, _reduce)
    new_repr = g.ndata['h']
    assert U.allclose(new_repr[0], old_repr[0])
    assert U.allclose(new_repr[1], old_repr[0])


def _disabled_test_send_twice():
    """Two sends on the same edge before recv.

    The assertions expect the later message to replace the earlier one.
    Disabled (no ``test_`` prefix) until the send issue below is fixed.
    """
    # TODO(minjie): please re-enable this unittest after the send code problem is fixed.
    g = DGLGraph()
    g.add_nodes(3)
    g.add_edge(0, 1)
    g.add_edge(2, 1)

    def _message_a(edges):
        return {'a': edges.src['a']}
    def _message_b(edges):
        return {'a': edges.src['a'] * 3}
    def _reduce(nodes):
        return {'a': nodes.mailbox['a'].max(1)[0]}

    old_repr = th.randn(3, 5)

    # Same edge twice: only the second (x3) message should reach node 1.
    g.ndata['a'] = old_repr
    g.send((0, 1), _message_a)
    g.send((0, 1), _message_b)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], old_repr[0] * 3)

    # Different edges: both messages are reduced together.
    g.ndata['a'] = old_repr
    g.send((0, 1), _message_a)
    g.send((2, 1), _message_b)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], th.stack([old_repr[0], old_repr[2] * 3], 0).max(0)[0])


def test_send_multigraph():
    """send/recv on a multigraph, addressing edges by id or by (u, v) pairs."""
    g = DGLGraph(multigraph=True)
    g.add_nodes(3)
    # Edge ids 0, 1, 2 are parallel edges 0 -> 1; edge id 3 is 2 -> 1.
    g.add_edge(0, 1)
    g.add_edge(0, 1)
    g.add_edge(0, 1)
    g.add_edge(2, 1)

    def _message_a(edges):
        return {'a': edges.data['a']}
    def _message_b(edges):
        return {'a': edges.data['a'] * 3}
    def _reduce(nodes):
        return {'a': nodes.mailbox['a'].max(1)[0]}

    def answer(*args):
        # Element-wise max over the given rows, mirroring _reduce.
        return th.stack(args, 0).max(0)[0]

    # send by eid
    old_repr = th.randn(4, 5)
    g.ndata['a'] = th.zeros(3, 5)
    g.edata['a'] = old_repr
    g.send([0, 2], message_func=_message_a)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], answer(old_repr[0], old_repr[2]))

    g.ndata['a'] = th.zeros(3, 5)
    g.edata['a'] = old_repr
    g.send([0, 2, 3], message_func=_message_a)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3]))

    # send on multigraph: the (u, v) pairs cover all four edges.
    g.ndata['a'] = th.zeros(3, 5)
    g.edata['a'] = old_repr
    g.send(([0, 2], [1, 1]), _message_a)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], old_repr.max(0)[0])

    # consecutive send and send_on
    g.ndata['a'] = th.zeros(3, 5)
    g.edata['a'] = old_repr
    g.send((2, 1), _message_a)
    g.send([0, 1], message_func=_message_b)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], answer(old_repr[0] * 3, old_repr[1] * 3, old_repr[3]))

    # consecutive send_on
    g.ndata['a'] = th.zeros(3, 5)
    g.edata['a'] = old_repr
    g.send(0, message_func=_message_a)
    g.send(1, message_func=_message_b)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], answer(old_repr[0], old_repr[1] * 3))

    # send_and_recv_on
    g.ndata['a'] = th.zeros(3, 5)
    g.edata['a'] = old_repr
    g.send_and_recv([0, 2, 3], message_func=_message_a, reduce_func=_reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3]))
    # Nodes 0 and 2 receive nothing and keep their zero features.
    assert U.allclose(new_repr[[0, 2]], th.zeros(2, 5))


def test_dynamic_addition():
    """Existing feature columns grow when nodes/edges are added afterwards."""
    N = 3
    D = 1  # local feature width; shadows the module-level D on purpose

    g = DGLGraph()

    # Test node addition
    g.add_nodes(N)
    g.ndata.update({'h1': th.randn(N, D),
                    'h2': th.randn(N, D)})
    g.add_nodes(3)
    assert g.ndata['h1'].shape[0] == g.ndata['h2'].shape[0] == N + 3

    # Test edge addition
    g.add_edge(0, 1)
    g.add_edge(1, 0)
    g.edata.update({'h1': th.randn(2, D),
                    'h2': th.randn(2, D)})
    assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 2

    g.add_edges([0, 2], [2, 0])
    g.edata['h1'] = th.randn(4, D)
    assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 4

    g.add_edge(1, 2)
    g.edges[4].data['h1'] = th.randn(1, D)
    assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 5


if __name__ == '__main__':
    # Run every enabled test directly (pytest also collects the test_* functions).
    test_batch_setter_getter()
    test_batch_setter_autograd()
    test_batch_send()
    test_batch_recv()
    test_apply_edges()
    test_update_routines()
    test_reduce_0deg()
    test_pull_0deg()
    test_send_multigraph()
    test_dynamic_addition()