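# test_basics.py: basic DGLGraph tests (feature getters/setters, autograd,
# send/recv, update routines, zero in-degree handling, multigraph sends).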
import torch as th
2
3
from torch.autograd import Variable
import numpy as np
4
5
6
7
8
from dgl.graph import DGLGraph

# Feature dimension of every node/edge representation used in these tests.
D = 5
# Shapes of the message batches seen by ``reduce_func``. Tests clear it,
# trigger message passing, then compare it with the expected batch shapes.
reduce_msg_shapes = set()

def check_eq(a, b):
    """Assert that tensors ``a`` and ``b`` match in shape and in every element.

    The shapes are compared first. Element equality then requires the count
    of positions where ``a == b`` holds to equal the total element count.
    """
    assert a.shape == b.shape
    same = (a == b)
    assert int(same.sum()) == same.numel()

def message_func(src, edge):
    """Send each source node's ``'h'`` feature along the edge as ``'m'``.

    Checks that the batched source feature is 2-D with ``D`` columns.
    ``edge`` is not used.
    """
    feat = src['h']
    assert len(feat.shape) == 2
    assert feat.shape[1] == D
    return {'m': feat}

def reduce_func(node, msgs):
    """Sum the incoming ``'m'`` messages into ``'accum'``.

    The shape of every message batch is recorded in ``reduce_msg_shapes``,
    which lets tests check how receivers were batched. The expected shapes in
    the tests use axis 0 for receiving nodes and axis 1 for the messages per
    node. Each batch must be 3-D with a trailing dimension of ``D``.
    ``node`` is not used.
    """
    msgs = msgs['m']
    reduce_msg_shapes.add(tuple(msgs.shape))
    assert len(msgs.shape) == 3
    assert msgs.shape[2] == D
    # Sum over the per-node message axis.
    return {'accum': th.sum(msgs, 1)}

def apply_node_func(node):
    """Return the new ``'h'``: the old ``'h'`` plus the reduced ``'accum'``."""
    return {'h': node['h'] + node['accum']}

def generate_graph(grad=False):
    """Build the 10-node, 17-edge graph shared by most tests in this file.

    Node 0 is the source and fans out to nodes 1..8. Each of those feeds
    node 9, the sink, and a single back edge 9 -> 0 closes the loop.
    Edges are added in the order 0->1, 1->9, 0->2, 2->9, ..., 8->9, 9->0.

    Args:
        grad: whether the random node feature ``'h'`` requires gradients.

    Returns:
        A DGLGraph whose node feature ``'h'`` is a random (10, D) tensor. The
        graph also has a node initializer that returns zero tensors
        (presumably used for node fields that were never set; confirm
        against DGLGraph).
    """
    g = DGLGraph()
    g.add_nodes(10)
    # 8 * 2 forward edges...
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    # ...plus a back flow from the sink to the source: 17 edges in total.
    g.add_edge(9, 0)
    # Tensors carry requires_grad directly; the Variable wrapper is obsolete.
    ncol = th.randn(10, D, requires_grad=grad)
    g.set_n_repr({'h': ncol})
    g.set_n_initializer(lambda shape, dtype: th.zeros(shape))
    return g

def test_batch_setter_getter():
    """Check set/get/pop of node and edge features, on all items and on subsets.

    Partial edge access is exercised with many-many, many-one and one-many
    (src, dst) specifications.
    """
    def _pfc(x):
        # First feature column as a plain Python list, for easy comparison.
        return list(x.numpy()[:, 0])

    g = generate_graph()

    # set all nodes
    g.set_n_repr({'h': th.zeros((10, D))})
    assert _pfc(g.get_n_repr()['h']) == [0.] * 10
    # pop nodes
    old_len = len(g.get_n_repr())
    assert _pfc(g.pop_n_repr('h')) == [0.] * 10
    assert len(g.get_n_repr()) == old_len - 1
    g.set_n_repr({'h': th.zeros((10, D))})
    # set partial nodes
    u = th.tensor([1, 3, 5])
    g.set_n_repr({'h': th.ones((3, D))}, u)
    assert _pfc(g.get_n_repr()['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
    # get partial nodes
    u = th.tensor([1, 2, 3])
    assert _pfc(g.get_n_repr(u)['h']) == [1., 0., 1.]

    # Edge ids of the graph built by generate_graph:
    #   s, d, eid
    #   0, 1, 0
    #   1, 9, 1
    #   0, 2, 2
    #   2, 9, 3
    #   0, 3, 4
    #   3, 9, 5
    #   0, 4, 6
    #   4, 9, 7
    #   0, 5, 8
    #   5, 9, 9
    #   0, 6, 10
    #   6, 9, 11
    #   0, 7, 12
    #   7, 9, 13
    #   0, 8, 14
    #   8, 9, 15
    #   9, 0, 16

    # set all edges
    g.set_e_repr({'l': th.zeros((17, D))})
    assert _pfc(g.get_e_repr()['l']) == [0.] * 17
    # pop edges
    old_len = len(g.get_e_repr())
    assert _pfc(g.pop_e_repr('l')) == [0.] * 17
    assert len(g.get_e_repr()) == old_len - 1
    g.set_e_repr({'l': th.zeros((17, D))})
    # set partial edges (many-many)
    u = th.tensor([0, 0, 2, 5, 9])
    v = th.tensor([1, 3, 9, 9, 0])
    g.set_e_repr({'l': th.ones((5, D))}, u, v)
    truth = [0.] * 17
    truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
    assert _pfc(g.get_e_repr()['l']) == truth
    # set partial edges (many-one)
    u = th.tensor([3, 4, 6])
    v = th.tensor([9])
    g.set_e_repr({'l': th.ones((3, D))}, u, v)
    truth[5] = truth[7] = truth[11] = 1.
    assert _pfc(g.get_e_repr()['l']) == truth
    # set partial edges (one-many)
    u = th.tensor([0])
    v = th.tensor([4, 5, 6])
    g.set_e_repr({'l': th.ones((3, D))}, u, v)
    truth[6] = truth[8] = truth[10] = 1.
    assert _pfc(g.get_e_repr()['l']) == truth
    # get partial edges (many-many)
    u = th.tensor([0, 6, 0])
    v = th.tensor([6, 9, 7])
    assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 0.]
    # get partial edges (many-one)
    u = th.tensor([5, 6, 7])
    v = th.tensor([9])
    assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 0.]
    # get partial edges (one-many)
    u = th.tensor([0])
    v = th.tensor([3, 4, 5])
    assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 1.]

def test_batch_setter_autograd():
    """After a partial ``set_n_repr``, gradients reach the new values on the
    overwritten nodes and the original values on all other nodes."""
    g = generate_graph(grad=True)
    h_before = g.get_n_repr()['h']

    targets = th.tensor([1, 2, 8])
    replacement = th.zeros((len(targets), D), requires_grad=True)
    g.set_n_repr({'h': replacement}, targets)

    h_after = g.get_n_repr()['h']
    # Upstream gradient of 2 on every element.
    h_after.backward(th.ones((10, D)) * 2)

    # Nodes 1, 2 and 8 no longer depend on the original features.
    expected_before = th.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.])
    check_eq(h_before.grad[:, 0], expected_before)
    check_eq(replacement.grad[:, 0], th.tensor([2., 2., 2.]))

def test_batch_send():
    """Batched ``send`` with many-many, one-many and many-one edge specs.

    Each call selects 5 edges, so the message function always sees 5 rows of
    source features.
    """
    g = generate_graph()

    def _fmsg(src, edge):
        assert src['h'].shape == (5, D)
        return {'m': src['h']}

    g.register_message_func(_fmsg)
    # many-many send
    u = th.tensor([0, 0, 0, 0, 0])
    v = th.tensor([1, 2, 3, 4, 5])
    g.send(u, v)
    # one-many send
    u = th.tensor([0])
    v = th.tensor([1, 2, 3, 4, 5])
    g.send(u, v)
    # many-one send
    u = th.tensor([1, 2, 3, 4, 5])
    v = th.tensor([9])
    g.send(u, v)

def test_batch_recv():
    """``send`` followed by ``recv`` on the unique receivers.

    Nodes 1, 2 and 3 each receive one message from node 0. Node 9 receives
    three, from 4, 5 and 6. The reducer is expected to see one (3, 1, D)
    batch and one (1, 3, D) batch.
    """
    g = generate_graph()
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_node_func)

    u = th.tensor([0, 0, 0, 4, 5, 6])
    v = th.tensor([1, 2, 3, 9, 9, 9])
    reduce_msg_shapes.clear()
    g.send(u, v)
    g.recv(th.unique(v))
    assert reduce_msg_shapes == {(1, 3, D), (3, 1, D)}
    reduce_msg_shapes.clear()

def test_update_routines():
    """send_and_recv, pull, push and update_all each produce the expected
    set of message-batch shapes in the reducer."""
    g = generate_graph()
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_node_func)

    # send_and_recv: nodes 1-3 get one message each, node 9 gets three
    reduce_msg_shapes.clear()
    u = th.tensor([0, 0, 0, 4, 5, 6])
    v = th.tensor([1, 2, 3, 9, 9, 9])
    g.send_and_recv(u, v)
    assert reduce_msg_shapes == {(1, 3, D), (3, 1, D)}
    reduce_msg_shapes.clear()

    # pull: nodes 1-3 have one in-edge each, node 9 has eight
    v = th.tensor([1, 2, 3, 9])
    reduce_msg_shapes.clear()
    g.pull(v)
    assert reduce_msg_shapes == {(1, 8, D), (3, 1, D)}
    reduce_msg_shapes.clear()

    # push: node 0 reaches 1-8 once each, nodes 1-3 all reach 9
    v = th.tensor([0, 1, 2, 3])
    reduce_msg_shapes.clear()
    g.push(v)
    assert reduce_msg_shapes == {(1, 3, D), (8, 1, D)}
    reduce_msg_shapes.clear()

    # update_all: nodes 0-8 have in-degree 1, node 9 has in-degree 8
    reduce_msg_shapes.clear()
    g.update_all()
    assert reduce_msg_shapes == {(1, 8, D), (9, 1, D)}
    reduce_msg_shapes.clear()

def test_reduce_0deg():
    """``update_all`` leaves nodes with no in-edges unchanged.

    Nodes 1-4 each send to node 0, so node 0 becomes the sum of all rows,
    while nodes 1-4 keep their original features.
    """
    g = DGLGraph()
    g.add_nodes(5)
    for src in range(1, 5):
        g.add_edge(src, 0)

    def _message(src, edge):
        return {'m': src['h']}

    def _reduce(node, msgs):
        return {'h': node['h'] + msgs['m'].sum(1)}

    old_repr = th.randn(5, 5)
    g.set_n_repr({'h': old_repr})
    g.update_all(_message, _reduce)
    new_repr = g.get_n_repr()['h']

    assert th.allclose(new_repr[1:], old_repr[1:])
    assert th.allclose(new_repr[0], old_repr.sum(0))

def test_pull_0deg():
    """``pull`` on a node with no in-edges leaves its feature unchanged.

    The graph is a single edge 0 -> 1. The reducer replaces ``'h'`` with the
    sum of the incoming messages.
    """
    g = DGLGraph()
    g.add_nodes(2)
    g.add_edge(0, 1)

    def _message(src, edge):
        return {'m': src['h']}

    def _reduce(node, msgs):
        return {'h': msgs['m'].sum(1)}

    old_repr = th.randn(2, 5)
    g.set_n_repr({'h': old_repr})

    # Node 0 has no in-edges, so nothing changes.
    g.pull(0, _message, _reduce)
    new_repr = g.get_n_repr()['h']
    assert th.allclose(new_repr[0], old_repr[0])
    assert th.allclose(new_repr[1], old_repr[1])

    # Node 1 takes node 0's feature.
    g.pull(1, _message, _reduce)
    new_repr = g.get_n_repr()['h']
    assert th.allclose(new_repr[1], old_repr[0])

    # Mixed batch: node 0 unchanged, node 1 reduced.
    old_repr = th.randn(2, 5)
    g.set_n_repr({'h': old_repr})
    g.pull([0, 1], _message, _reduce)
    new_repr = g.get_n_repr()['h']
    assert th.allclose(new_repr[0], old_repr[0])
    assert th.allclose(new_repr[1], old_repr[0])

def _disabled_test_send_twice():
    """Repeated ``send`` calls before a single ``recv``.

    Expectations: a second send on the same edge replaces the first message,
    and sends on different edges accumulate for the receiver. Disabled
    (leading underscore, not a ``test_`` name) until the send issue below is
    fixed.
    """
    # TODO(minjie): please re-enable this unittest after the send code problem is fixed.
    g = DGLGraph()
    g.add_nodes(3)
    g.add_edge(0, 1)
    g.add_edge(2, 1)

    def _message_a(src, edge):
        return {'a': src['a']}

    def _message_b(src, edge):
        return {'a': src['a'] * 3}

    def _reduce(node, msgs):
        assert msgs is not None
        return {'a': msgs['a'].max(1)[0]}

    # Same edge twice: only the second message is expected to survive.
    old_repr = th.randn(3, 5)
    g.set_n_repr({'a': old_repr})
    g.send(0, 1, _message_a)
    g.send(0, 1, _message_b)
    g.recv([1], _reduce)
    new_repr = g.get_n_repr()['a']
    assert th.allclose(new_repr[1], old_repr[0] * 3)

    # Two different edges into node 1: both messages are reduced.
    g.set_n_repr({'a': old_repr})
    g.send(0, 1, _message_a)
    g.send(2, 1, _message_b)
    g.recv([1], _reduce)
    new_repr = g.get_n_repr()['a']
    assert th.allclose(new_repr[1], th.stack([old_repr[0], old_repr[2] * 3], 0).max(0)[0])

def test_send_multigraph():
    """Message passing on a multigraph.

    Node 1 has three parallel in-edges from node 0 (eids 0, 1, 2) and one
    from node 2 (eid 3). The test covers sending by edge id, sending by
    (src, dst) pairs (which covers every matching parallel edge), mixing the
    two, consecutive sends by id, and ``send_and_recv`` by edge id.
    """
    g = DGLGraph(multigraph=True)
    g.add_nodes(3)
    for u in (0, 0, 0, 2):
        g.add_edge(u, 1)

    def _message_a(src, edge):
        return {'a': edge['a']}

    def _message_b(src, edge):
        return {'a': edge['a'] * 3}

    def _reduce(node, msgs):
        assert msgs is not None
        return {'a': msgs['a'].max(1)[0]}

    def answer(*args):
        # Elementwise max over the given rows.
        return th.stack(args, 0).max(0)[0]

    old_repr = th.randn(4, 5)

    def _reset():
        # Zero the node features and restore the same edge features.
        g.set_n_repr({'a': th.zeros(3, 5)})
        g.set_e_repr({'a': old_repr})

    def _recv_node1():
        # Reduce pending messages at node 1 and return all node features.
        g.recv([1], _reduce)
        return g.get_n_repr()['a']

    # send by eid
    _reset()
    g.send(eid=[0, 2], message_func=_message_a)
    assert th.allclose(_recv_node1()[1], answer(old_repr[0], old_repr[2]))

    _reset()
    g.send(eid=[0, 2, 3], message_func=_message_a)
    assert th.allclose(_recv_node1()[1],
                       answer(old_repr[0], old_repr[2], old_repr[3]))

    # send on multigraph: (src, dst) pairs hit every parallel edge
    _reset()
    g.send([0, 2], [1, 1], _message_a)
    assert th.allclose(_recv_node1()[1], old_repr.max(0)[0])

    # consecutive send and send_on
    _reset()
    g.send(2, 1, _message_a)
    g.send(eid=[0, 1], message_func=_message_b)
    assert th.allclose(_recv_node1()[1],
                       answer(old_repr[0] * 3, old_repr[1] * 3, old_repr[3]))

    # consecutive send_on
    _reset()
    g.send(eid=0, message_func=_message_a)
    g.send(eid=1, message_func=_message_b)
    assert th.allclose(_recv_node1()[1], answer(old_repr[0], old_repr[1] * 3))

    # send_and_recv_on
    _reset()
    g.send_and_recv(eid=[0, 2, 3], message_func=_message_a, reduce_func=_reduce)
    new_repr = g.get_n_repr()['a']
    assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3]))
    assert th.allclose(new_repr[[0, 2]], th.zeros(2, 5))


if __name__ == '__main__':
    # Run every enabled test. _disabled_test_send_twice is intentionally
    # skipped (see its TODO).
    test_batch_setter_getter()
    test_batch_setter_autograd()
    test_batch_send()
    test_batch_recv()
    test_update_routines()
    test_reduce_0deg()
    test_pull_0deg()
    test_send_multigraph()