# test_basics.py — basic DGLGraph API tests, run against the MXNet backend.
import os
os.environ['DGLBACKEND'] = 'mxnet'
import mxnet as mx
import numpy as np
from dgl.graph import DGLGraph
6
import scipy.sparse as spsp
# Feature dimension used for every node/edge feature tensor in these tests.
D = 5
# Shapes of the mailbox tensors seen by reduce_func.  The update-routine
# checks clear this set, run a routine, and compare it against the expected
# (nodes in bucket, in-degree, D) shapes.
reduce_msg_shapes = set()

def check_eq(a, b):
    """Assert that two MXNet NDArrays have the same shape and equal elements.

    The comparison is exact (element-wise ``==``).  It is meant for values
    that are represented exactly, such as the small gradient values checked
    in this file.
    """
    assert a.shape == b.shape
    # ``a == b`` yields a 0/1 array; its sum must equal the element count.
    assert mx.nd.sum(a == b).asnumpy() == int(np.prod(a.shape))


def message_func(edges):
    """Message UDF: forward each source node's 'h' feature as message 'm'.

    Also asserts that the batched source features form a 2-D tensor whose
    second dimension is ``D``.
    """
    assert len(edges.src['h'].shape) == 2
    assert edges.src['h'].shape[1] == D
    return {'m' : edges.src['h']}


def reduce_func(nodes):
    """Reduce UDF: sum each node's incoming messages 'm' over dimension 1.

    The mailbox shape is recorded in the module-level ``reduce_msg_shapes``
    before validation, so callers can check how nodes were batched.  The
    mailbox must be 3-D with last dimension ``D``.
    """
    msgs = nodes.mailbox['m']
    reduce_msg_shapes.add(tuple(msgs.shape))
    assert len(msgs.shape) == 3
    assert msgs.shape[2] == D
    return {'m' : mx.nd.sum(msgs, 1)}


def apply_node_func(nodes):
    """Apply UDF: new 'h' is the old 'h' plus the reduced message 'm'."""
    return {'h' : nodes.data['h'] + nodes.data['m']}


def generate_graph(grad=False, readonly=False):
    """Build the 10-node, 17-edge graph shared by the tests in this file.

    Node 0 fans out to nodes 1..8, each of which feeds node 9, and a back
    edge 9 -> 0 closes the loop.  Node feature 'h' is set to a random-normal
    NDArray of shape ``(10, D)``.

    Parameters
    ----------
    grad : bool
        If True, call ``attach_grad()`` on the feature tensor before it is
        stored on the graph.
    readonly : bool
        If True, build a readonly graph from a SciPy CSR adjacency matrix.
        Otherwise build a mutable graph with ``add_nodes``/``add_edge``.

    Returns
    -------
    DGLGraph
    """
    # Edges in insertion order: (0, i), (i, 9) for i = 1..8, then (9, 0).
    edge_list = []
    for i in range(1, 9):
        edge_list.append((0, i))
        edge_list.append((i, 9))
    edge_list.append((9, 0))

    if readonly:
        row_idx = [src for src, _ in edge_list]
        col_idx = [dst for _, dst in edge_list]
        ones = np.ones(shape=(len(row_idx)))
        csr = spsp.csr_matrix((ones, (row_idx, col_idx)), shape=(10, 10))
        g = DGLGraph(csr, readonly=True)
    else:
        g = DGLGraph()
        g.add_nodes(10)
        for src, dst in edge_list:
            g.add_edge(src, dst)

    # Features are drawn after the structure is built, in both branches.
    ncol = mx.nd.random.normal(shape=(10, D))
    if grad:
        ncol.attach_grad()
    g.ndata['h'] = ncol
    return g


def test_batch_setter_getter():
    """Exercise full and partial set/get/pop of node and edge features."""
    def _pfc(x):
        # First feature column of an (N, D) NDArray, as a Python list.
        return list(x.asnumpy()[:,0])
    g = generate_graph()
    # set all nodes
    g.set_n_repr({'h' : mx.nd.zeros((10, D))})
    assert _pfc(g.ndata['h']) == [0.] * 10
    # pop nodes
    assert _pfc(g.pop_n_repr('h')) == [0.] * 10
    assert len(g.ndata) == 0
    g.set_n_repr({'h' : mx.nd.zeros((10, D))})
    # set partial nodes
    u = mx.nd.array([1, 3, 5], dtype='int64')
    g.set_n_repr({'h' : mx.nd.ones((3, D))}, u)
    assert _pfc(g.ndata['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
    # get partial nodes
    u = mx.nd.array([1, 2, 3], dtype='int64')
    assert _pfc(g.get_n_repr(u)['h']) == [1., 0., 1.]

    # Edge IDs of the graph returned by generate_graph():
    #   s, d, eid
    #   0, 1, 0
    #   1, 9, 1
    #   0, 2, 2
    #   2, 9, 3
    #   0, 3, 4
    #   3, 9, 5
    #   0, 4, 6
    #   4, 9, 7
    #   0, 5, 8
    #   5, 9, 9
    #   0, 6, 10
    #   6, 9, 11
    #   0, 7, 12
    #   7, 9, 13
    #   0, 8, 14
    #   8, 9, 15
    #   9, 0, 16
    # set all edges
    g.edata['l'] = mx.nd.zeros((17, D))
    assert _pfc(g.edata['l']) == [0.] * 17
    # pop edges
    assert _pfc(g.pop_e_repr('l')) == [0.] * 17
    assert len(g.edata) == 0
    g.edata['l'] = mx.nd.zeros((17, D))
    # set partial edges (many-many)
    u = mx.nd.array([0, 0, 2, 5, 9], dtype='int64')
    v = mx.nd.array([1, 3, 9, 9, 0], dtype='int64')
    g.edges[u, v].data['l'] = mx.nd.ones((5, D))
    truth = [0.] * 17
    truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
    assert _pfc(g.edata['l']) == truth
    # set partial edges (many-one)
    u = mx.nd.array([3, 4, 6], dtype='int64')
    v = mx.nd.array([9], dtype='int64')
    g.edges[u, v].data['l'] = mx.nd.ones((3, D))
    truth[5] = truth[7] = truth[11] = 1.
    assert _pfc(g.edata['l']) == truth
    # set partial edges (one-many)
    u = mx.nd.array([0], dtype='int64')
    v = mx.nd.array([4, 5, 6], dtype='int64')
    g.edges[u, v].data['l'] = mx.nd.ones((3, D))
    truth[6] = truth[8] = truth[10] = 1.
    assert _pfc(g.edata['l']) == truth
    # get partial edges (many-many)
    u = mx.nd.array([0, 6, 0], dtype='int64')
    v = mx.nd.array([6, 9, 7], dtype='int64')
    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
    # get partial edges (many-one)
    u = mx.nd.array([5, 6, 7], dtype='int64')
    v = mx.nd.array([9], dtype='int64')
    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
    # get partial edges (one-many)
    u = mx.nd.array([0], dtype='int64')
    v = mx.nd.array([3, 4, 5], dtype='int64')
    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 1.]


def test_batch_setter_autograd():
    """Check gradients through a partial node-feature update.

    Rows 1, 2 and 8 of 'h' are overwritten with ``hh``.  After backpropagating
    a gradient of 2, those rows of the original tensor get no gradient, and
    every row of ``hh`` gets the full gradient.
    """
    with mx.autograd.record():
        g = generate_graph(grad=True, readonly=True)
        h1 = g.ndata['h']
        h1.attach_grad()
        # partial set
        v = mx.nd.array([1, 2, 8], dtype='int64')
        hh = mx.nd.zeros((len(v), D))
        hh.attach_grad()
        g.set_n_repr({'h' : hh}, v)
        h2 = g.ndata['h']
    h2.backward(mx.nd.ones((10, D)) * 2)
    check_eq(h1.grad[:,0], mx.nd.array([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
    check_eq(hh.grad[:,0], mx.nd.array([2., 2., 2.]))

def test_batch_send():
    """Send messages in many-many, one-many and many-one form.

    Each call covers five edges; the message UDF asserts it receives a
    batch of exactly five source features.
    """
    g = generate_graph()
    def _fmsg(edges):
        assert edges.src['h'].shape == (5, D)
        return {'m' : edges.src['h']}
    g.register_message_func(_fmsg)
    # many-many send
    u = mx.nd.array([0, 0, 0, 0, 0], dtype='int64')
    v = mx.nd.array([1, 2, 3, 4, 5], dtype='int64')
    g.send((u, v))
    # one-many send
    u = mx.nd.array([0], dtype='int64')
    v = mx.nd.array([1, 2, 3, 4, 5], dtype='int64')
    g.send((u, v))
    # many-one send
    u = mx.nd.array([1, 2, 3, 4, 5], dtype='int64')
    v = mx.nd.array([9], dtype='int64')
    g.send((u, v))

def check_batch_recv(readonly):
    """Send along six edges of the test graph (readonly or mutable).

    Only the send half currently runs; the recv half and its reduce-shape
    assertion are disabled (see the TODO below).
    """
    # basic recv test
    g = generate_graph(readonly=readonly)
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_node_func)
    u = mx.nd.array([0, 0, 0, 4, 5, 6], dtype='int64')
    v = mx.nd.array([1, 2, 3, 9, 9, 9], dtype='int64')
    reduce_msg_shapes.clear()
    g.send((u, v))
    # TODO: the receive half of this test is disabled.  It used to be:
    #     g.recv(th.unique(v))
    #     assert reduce_msg_shapes == {(1, 3, D), (3, 1, D)}
    #     reduce_msg_shapes.clear()
    # ``th`` (PyTorch) is not imported in this file.  Re-enable it with an
    # MXNet way of taking the unique destination nodes (1, 2, 3, 9).

def test_batch_recv():
    """Run the batched send/recv check on a readonly, then a mutable, graph."""
    for is_readonly in (True, False):
        check_batch_recv(is_readonly)

def check_update_routines(readonly):
    """Check how send_and_recv, pull, push and update_all batch reductions.

    Each routine runs with the registered UDFs.  The mailbox shapes recorded
    by ``reduce_func`` are compared against the expected
    (nodes in bucket, in-degree, D) shapes for the test graph.
    """
    g = generate_graph(readonly=readonly)
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_node_func)

    # send_and_recv: nodes 1, 2, 3 get one message each; node 9 gets three.
    reduce_msg_shapes.clear()
    u = mx.nd.array([0, 0, 0, 4, 5, 6], dtype='int64')
    v = mx.nd.array([1, 2, 3, 9, 9, 9], dtype='int64')
    g.send_and_recv((u, v))
    assert reduce_msg_shapes == {(1, 3, D), (3, 1, D)}, reduce_msg_shapes
    reduce_msg_shapes.clear()

    # pull: node 9 has in-degree 8; nodes 1, 2, 3 have in-degree 1.
    v = mx.nd.array([1, 2, 3, 9], dtype='int64')
    reduce_msg_shapes.clear()
    g.pull(v)
    assert reduce_msg_shapes == {(1, 8, D), (3, 1, D)}, reduce_msg_shapes
    reduce_msg_shapes.clear()

    # push: node 0 reaches nodes 1..8; nodes 1, 2, 3 all reach node 9.
    v = mx.nd.array([0, 1, 2, 3], dtype='int64')
    reduce_msg_shapes.clear()
    g.push(v)
    assert reduce_msg_shapes == {(1, 3, D), (8, 1, D)}, reduce_msg_shapes
    reduce_msg_shapes.clear()

    # update_all: node 9 has in-degree 8; nodes 0..8 have in-degree 1.
    reduce_msg_shapes.clear()
    g.update_all()
    assert reduce_msg_shapes == {(1, 8, D), (9, 1, D)}, reduce_msg_shapes
    reduce_msg_shapes.clear()


def test_update_routines():
    """Run the update-routine checks on a readonly, then a mutable, graph."""
    for is_readonly in (True, False):
        check_update_routines(is_readonly)

def check_reduce_0deg(readonly):
    """Check update_all on a graph with zero-in-degree nodes.

    Edges run 1->0, 2->0, 3->0 and 4->0, so nodes 1..4 receive no messages.
    After update_all, node 0 holds the sum of all five original features,
    and nodes 1..4 hold the value of the registered node initializer (2).
    """
    sources = list(range(1, 5))
    if readonly:
        row_idx = sources
        col_idx = [0] * len(sources)
        ones = np.ones(shape=(len(row_idx)))
        csr = spsp.csr_matrix((ones, (row_idx, col_idx)), shape=(5, 5))
        g = DGLGraph(csr, readonly=True)
    else:
        g = DGLGraph()
        g.add_nodes(5)
        for src in sources:
            g.add_edge(src, 0)

    def _message(edges):
        return {'m' : edges.src['h']}

    def _reduce(nodes):
        return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}

    def _init2(shape, dtype, ctx, ids):
        return 2 + mx.nd.zeros(shape, dtype=dtype, ctx=ctx)

    g.set_n_initializer(_init2, 'h')
    old_repr = mx.nd.random.normal(shape=(5, 5))
    g.set_n_repr({'h': old_repr})
    g.update_all(_message, _reduce)
    new_repr = g.ndata['h']

    assert np.allclose(new_repr[1:].asnumpy(), np.full((4, 5), 2.))
    assert np.allclose(new_repr[0].asnumpy(), old_repr.sum(0).asnumpy())


def test_reduce_0deg():
    """Run the zero-in-degree reduce check on a readonly, then a mutable, graph."""
    for is_readonly in (True, False):
        check_reduce_0deg(is_readonly)

def check_pull_0deg(readonly):
    """Check pull on a two-node graph with a single edge 0 -> 1.

    Pulling node 0 (no in-edges) currently leaves features unchanged (see
    the TODO below).  Pulling node 1 copies node 0's feature into it.
    """
    if readonly:
        row_idx = [0]
        col_idx = [1]
        ones = np.ones(shape=(len(row_idx)))
        csr = spsp.csr_matrix((ones, (row_idx, col_idx)), shape=(2, 2))
        g = DGLGraph(csr, readonly=True)
    else:
        g = DGLGraph()
        g.add_nodes(2)
        g.add_edge(0, 1)

    def _message(edges):
        return {'m' : edges.src['h']}

    def _reduce(nodes):
        return {'h' : nodes.mailbox['m'].sum(1)}

    old_repr = mx.nd.random.normal(shape=(2, 5))
    g.set_n_repr({'h' : old_repr})
    g.pull(0, _message, _reduce)
    new_repr = g.ndata['h']
    # TODO(minjie): this is not the intended behavior. Pull node#0
    #   should reset node#0 to the initial value. The bug is because
    #   current pull is implemented using send_and_recv. Since there
    #   is no edge to node#0 so the send_and_recv is skipped. Fix this
    #   behavior when optimizing the pull scheduler.
    assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy())
    assert np.allclose(new_repr[1].asnumpy(), old_repr[1].asnumpy())
    g.pull(1, _message, _reduce)
    new_repr = g.ndata['h']
    assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy())

    old_repr = mx.nd.random.normal(shape=(2, 5))
    g.set_n_repr({'h' : old_repr})
    g.pull([0, 1], _message, _reduce)
    new_repr = g.ndata['h']
    assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy())
    assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy())


def test_pull_0deg():
    """Run the zero-in-degree pull check on a readonly, then a mutable, graph."""
    check_pull_0deg(True)
    check_pull_0deg(False)


if __name__ == '__main__':
    # Run every test in file order when executed directly (outside pytest).
    test_batch_setter_getter()
    test_batch_setter_autograd()
    test_batch_send()
    test_batch_recv()
    test_update_routines()
    test_reduce_0deg()
    test_pull_0deg()