# test_basics.py: basic DGLGraph tests run against the MXNet backend.
import os
os.environ['DGLBACKEND'] = 'mxnet'
import mxnet as mx
import numpy as np
from dgl.graph import DGLGraph

# Feature width of the node/edge tensors built by these tests.
D = 5
# Mailbox shapes observed by `reduce_func`. Each test clears this set, runs an
# update routine, and compares it with the expected set of 3-D shape tuples.
reduce_msg_shapes = set()

def check_eq(a, b):
    """Assert that two MXNet NDArrays have the same shape and identical elements.

    Parameters
    ----------
    a, b : mx.nd.NDArray
        Arrays compared element-wise.
    """
    assert a.shape == b.shape
    # The number of positions where a == b must equal the total element count,
    # i.e. every element matches.
    assert mx.nd.sum(a == b).asnumpy() == int(np.prod(list(a.shape)))

def message_func(edges):
    """Message function: forward each edge's source feature 'h' as message 'm'.

    Asserts that the batched source features are 2-D with width D.
    """
    assert len(edges.src['h'].shape) == 2
    assert edges.src['h'].shape[1] == D
    return {'m' : edges.src['h']}

def reduce_func(nodes):
    """Reduce function: sum the incoming messages 'm' along axis 1.

    Also records the mailbox shape in the module-level `reduce_msg_shapes`,
    so tests can check how messages were batched.
    """
    msgs = nodes.mailbox['m']
    reduce_msg_shapes.add(tuple(msgs.shape))
    # The mailbox is expected to be 3-D, with feature width D in the last axis.
    assert len(msgs.shape) == 3
    assert msgs.shape[2] == D
    return {'m' : mx.nd.sum(msgs, 1)}

def apply_node_func(nodes):
    """Apply function: add the reduced message 'm' to the node feature 'h'."""
    return {'h' : nodes.data['h'] + nodes.data['m']}

def generate_graph(grad=False):
    """Build the 10-node test graph with random node features.

    Edges are added in this order: 0->i and i->9 for i in 1..8, then 9->0,
    for 17 edges in total. Node feature 'h' is a (10, D) random-normal
    NDArray.

    Parameters
    ----------
    grad : bool
        If True, call ``attach_grad()`` on the feature tensor before it is
        stored in the graph.
    """
    g = DGLGraph()
    g.add_nodes(10) # 10 nodes.
    # create a graph where 0 is the source and 9 is the sink
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    # add a back flow from 9 to 0
    g.add_edge(9, 0)
    ncol = mx.nd.random.normal(shape=(10, D))
    if grad:
        ncol.attach_grad()
    g.ndata['h'] = ncol
    return g

def test_batch_setter_getter():
    """Test full and partial setting, getting and popping of node/edge features."""
    def _pfc(x):
        # First feature column as a plain Python list, for easy comparison.
        return list(x.asnumpy()[:,0])
    g = generate_graph()
    # set all nodes
    g.set_n_repr({'h' : mx.nd.zeros((10, D))})
    assert _pfc(g.ndata['h']) == [0.] * 10
    # pop nodes
    assert _pfc(g.pop_n_repr('h')) == [0.] * 10
    assert len(g.ndata) == 0
    g.set_n_repr({'h' : mx.nd.zeros((10, D))})
    # set partial nodes
    u = mx.nd.array([1, 3, 5], dtype='int64')
    g.set_n_repr({'h' : mx.nd.ones((3, D))}, u)
    assert _pfc(g.ndata['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
    # get partial nodes
    u = mx.nd.array([1, 2, 3], dtype='int64')
    assert _pfc(g.get_n_repr(u)['h']) == [1., 0., 1.]

    # Edge IDs of the graph from generate_graph (insertion order):
    #   s, d, eid
    #   0, 1, 0
    #   1, 9, 1
    #   0, 2, 2
    #   2, 9, 3
    #   0, 3, 4
    #   3, 9, 5
    #   0, 4, 6
    #   4, 9, 7
    #   0, 5, 8
    #   5, 9, 9
    #   0, 6, 10
    #   6, 9, 11
    #   0, 7, 12
    #   7, 9, 13
    #   0, 8, 14
    #   8, 9, 15
    #   9, 0, 16
    # set all edges
    g.edata['l'] = mx.nd.zeros((17, D))
    assert _pfc(g.edata['l']) == [0.] * 17
    # pop edges
    assert _pfc(g.pop_e_repr('l')) == [0.] * 17
    assert len(g.edata) == 0
    g.edata['l'] = mx.nd.zeros((17, D))
    # set partial edges (many-many)
    u = mx.nd.array([0, 0, 2, 5, 9], dtype='int64')
    v = mx.nd.array([1, 3, 9, 9, 0], dtype='int64')
    g.edges[u, v].data['l'] = mx.nd.ones((5, D))
    truth = [0.] * 17
    truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
    assert _pfc(g.edata['l']) == truth
    # set partial edges (many-one)
    u = mx.nd.array([3, 4, 6], dtype='int64')
    v = mx.nd.array([9], dtype='int64')
    g.edges[u, v].data['l'] = mx.nd.ones((3, D))
    truth[5] = truth[7] = truth[11] = 1.
    assert _pfc(g.edata['l']) == truth
    # set partial edges (one-many)
    u = mx.nd.array([0], dtype='int64')
    v = mx.nd.array([4, 5, 6], dtype='int64')
    g.edges[u, v].data['l'] = mx.nd.ones((3, D))
    truth[6] = truth[8] = truth[10] = 1.
    assert _pfc(g.edata['l']) == truth
    # get partial edges (many-many)
    u = mx.nd.array([0, 6, 0], dtype='int64')
    v = mx.nd.array([6, 9, 7], dtype='int64')
    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
    # get partial edges (many-one)
    u = mx.nd.array([5, 6, 7], dtype='int64')
    v = mx.nd.array([9], dtype='int64')
    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
    # get partial edges (one-many)
    u = mx.nd.array([0], dtype='int64')
    v = mx.nd.array([3, 4, 5], dtype='int64')
    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 1.]

def test_batch_setter_autograd():
    """Test gradient flow through a partial node-feature write.

    Nodes 1, 2 and 8 are overwritten with `hh`. The asserts expect that the
    original features receive gradient only at the untouched rows, and that
    `hh` receives the full upstream gradient.
    """
    with mx.autograd.record():
        g = generate_graph(grad=True)
        h1 = g.ndata['h']
        h1.attach_grad()
        # partial set
        v = mx.nd.array([1, 2, 8], dtype='int64')
        hh = mx.nd.zeros((len(v), D))
        hh.attach_grad()
        g.set_n_repr({'h' : hh}, v)
        h2 = g.ndata['h']
    # Upstream gradient of 2 everywhere: overwritten rows (1, 2, 8) get 0 in
    # h1.grad, and every row of hh gets 2.
    h2.backward(mx.nd.ones((10, D)) * 2)
    check_eq(h1.grad[:,0], mx.nd.array([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
    check_eq(hh.grad[:,0], mx.nd.array([2., 2., 2.]))

def test_batch_send():
    """Test send() on many-many, one-many and many-one edge specifications.

    Each specification selects 5 edges, so the message function asserts a
    (5, D) source batch.
    """
    g = generate_graph()
    def _fmsg(edges):
        assert edges.src['h'].shape == (5, D)
        return {'m' : edges.src['h']}
    g.register_message_func(_fmsg)
    # many-many send
    u = mx.nd.array([0, 0, 0, 0, 0], dtype='int64')
    v = mx.nd.array([1, 2, 3, 4, 5], dtype='int64')
    g.send((u, v))
    # one-many send
    u = mx.nd.array([0], dtype='int64')
    v = mx.nd.array([1, 2, 3, 4, 5], dtype='int64')
    g.send((u, v))
    # many-one send
    u = mx.nd.array([1, 2, 3, 4, 5], dtype='int64')
    v = mx.nd.array([9], dtype='int64')
    g.send((u, v))

def test_batch_recv():
    """Basic recv test: send messages along a subset of edges.

    Only the send half currently runs; the recv/assert half is commented out.
    """
    # basic recv test
    g = generate_graph()
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_node_func)
    u = mx.nd.array([0, 0, 0, 4, 5, 6], dtype='int64')
    v = mx.nd.array([1, 2, 3, 9, 9, 9], dtype='int64')
    reduce_msg_shapes.clear()
    g.send((u, v))
    # NOTE(review): the recv check below is disabled; the reason is not
    # visible here. It still refers to `th.unique` (PyTorch), which this file
    # does not import. An MXNet/NumPy equivalent would be built from
    # np.unique(v.asnumpy()).
    #g.recv(th.unique(v))
    #assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
    #reduce_msg_shapes.clear()

def test_update_routines():
    """Test send_and_recv, pull, push and update_all.

    After each routine, the test checks the set of mailbox shapes recorded by
    `reduce_func`. Each expected tuple is (number of receiving nodes,
    messages per node, D).
    """
    g = generate_graph()
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_node_func)

    # send_and_recv: nodes 1-3 each get 1 message; node 9 gets 3.
    reduce_msg_shapes.clear()
    u = mx.nd.array([0, 0, 0, 4, 5, 6], dtype='int64')
    v = mx.nd.array([1, 2, 3, 9, 9, 9], dtype='int64')
    g.send_and_recv((u, v))
    assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
    reduce_msg_shapes.clear()

    # pull: nodes 1-3 have 1 in-edge each; node 9 has 8.
    v = mx.nd.array([1, 2, 3, 9], dtype='int64')
    reduce_msg_shapes.clear()
    g.pull(v)
    assert(reduce_msg_shapes == {(1, 8, D), (3, 1, D)})
    reduce_msg_shapes.clear()

    # push from 0-3: nodes 1-8 get 1 message each (from 0); node 9 gets 3.
    v = mx.nd.array([0, 1, 2, 3], dtype='int64')
    reduce_msg_shapes.clear()
    g.push(v)
    assert(reduce_msg_shapes == {(1, 3, D), (8, 1, D)})
    reduce_msg_shapes.clear()

    # update_all: nodes 0-8 get 1 message each; node 9 gets 8.
    reduce_msg_shapes.clear()
    g.update_all()
    assert(reduce_msg_shapes == {(1, 8, D), (9, 1, D)})
    reduce_msg_shapes.clear()

def test_reduce_0deg():
    """Test update_all on a star graph with zero-in-degree nodes.

    Edges are 1->0, 2->0, 3->0 and 4->0. Nodes 1-4 receive nothing and must
    keep their features. Node 0 ends up with its own feature plus the sum of
    its neighbours', i.e. the column sum of all five rows.
    """
    g = DGLGraph()
    g.add_nodes(5)
    g.add_edge(1, 0)
    g.add_edge(2, 0)
    g.add_edge(3, 0)
    g.add_edge(4, 0)
    def _message(edges):
        return {'m' : edges.src['h']}
    def _reduce(nodes):
        return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}
    old_repr = mx.nd.random.normal(shape=(5, D))
    g.set_n_repr({'h': old_repr})
    g.update_all(_message, _reduce)
    new_repr = g.ndata['h']

    assert np.allclose(new_repr[1:].asnumpy(), old_repr[1:].asnumpy())
    assert np.allclose(new_repr[0].asnumpy(), old_repr.sum(0).asnumpy())

def test_pull_0deg():
    """Test pull() on a single edge 0->1, including the zero-in-degree node 0.

    Pulling node 0 must leave every feature unchanged. Pulling node 1 must
    replace its feature with node 0's feature.
    """
    g = DGLGraph()
    g.add_nodes(2)
    g.add_edge(0, 1)
    def _message(edges):
        return {'m' : edges.src['h']}
    def _reduce(nodes):
        return {'h' : nodes.mailbox['m'].sum(1)}

    old_repr = mx.nd.random.normal(shape=(2, D))
    g.set_n_repr({'h' : old_repr})
    # Pull on a node with no in-edges: nothing should change.
    g.pull(0, _message, _reduce)
    new_repr = g.ndata['h']
    assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy())
    assert np.allclose(new_repr[1].asnumpy(), old_repr[1].asnumpy())
    g.pull(1, _message, _reduce)
    new_repr = g.ndata['h']
    assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy())

    # Same check with both nodes pulled in one call.
    old_repr = mx.nd.random.normal(shape=(2, D))
    g.set_n_repr({'h' : old_repr})
    g.pull([0, 1], _message, _reduce)
    new_repr = g.ndata['h']
    assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy())
    assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy())

if __name__ == '__main__':
    # Run every test directly, without a test runner.
    test_batch_setter_getter()
    test_batch_setter_autograd()
    test_batch_send()
    test_batch_recv()
    test_update_routines()
    test_reduce_0deg()
    test_pull_0deg()