# test_basics.py
import torch as th
2
3
from torch.autograd import Variable
import numpy as np
4
5
6
7
8
from dgl.graph import DGLGraph

# Feature width used for every node/edge tensor built in these tests.
D = 5
# Shapes of the message tensors observed by ``reduce_func``.  Tests clear it,
# trigger message passing, then compare it against the expected shapes.
reduce_msg_shapes = set()


def check_eq(a, b):
    """Assert that tensors ``a`` and ``b`` have identical shape and elements.

    Raises:
        AssertionError: if the shapes differ or any element differs.
    """
    assert a.shape == b.shape
    # Shapes already match, so no broadcasting happens here; every position
    # must compare equal (NaN never does).
    assert bool((a == b).all())


def message_func(src, edge):
    """Send each source node's ``'h'`` feature unchanged as message ``'m'``.

    The batched source features must be 2-D with ``D`` columns; ``edge`` is
    not used.
    """
    feat = src['h']
    assert len(feat.shape) == 2
    assert feat.shape[1] == D
    return {'m' : feat}

def reduce_func(node, msgs):
    """Sum each node's incoming messages and record the mailbox shape.

    ``msgs['m']`` must be 3-D with ``D`` as its last dimension.  Its shape is
    added to the module-level ``reduce_msg_shapes`` so tests can check which
    mailbox shapes the reduce step saw.  Messages are summed over dimension 1.
    ``node`` is not used.
    """
    mailbox = msgs['m']
    reduce_msg_shapes.add(tuple(mailbox.shape))
    assert len(mailbox.shape) == 3
    assert mailbox.shape[2] == D
    return {'m' : th.sum(mailbox, 1)}


def apply_node_func(node):
    """Return the new ``'h'`` feature: old ``'h'`` plus the reduced ``'m'``."""
    return {'h' : node['h'] + node['m']}


def generate_graph(grad=False):
    """Build the shared 10-node, 17-edge test graph with random ``'h'`` features.

    Node 0 fans out to nodes 1-8, each of which feeds node 9, and a back edge
    9 -> 0 closes the loop.  Edges are added in the order
    (0,1), (1,9), (0,2), (2,9), ..., (0,8), (8,9), (9,0).

    Args:
        grad: whether the (10, D) node feature tensor requires gradients.

    Returns:
        The populated ``DGLGraph``.
    """
    g = DGLGraph()
    g.add_nodes(10) # 10 nodes.
    # create a graph where 0 is the source and 9 is the sink
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    # add a back flow from 9 to 0
    g.add_edge(9, 0)
    # ``Variable`` is a deprecated no-op wrapper; build the leaf tensor with
    # ``requires_grad`` directly (same values, same RNG consumption).
    ncol = th.randn(10, D, requires_grad=grad)
    g.set_n_repr({'h' : ncol})
    return g

def test_batch_setter_getter():
    """Check batched set/get/pop of node and edge features.

    Covers the following operations:

    - setting and reading all nodes;
    - popping a node feature;
    - partial node updates and reads by id;
    - the same for edges addressed by (src, dst) pairs, in many-many,
      many-one and one-many form.
    """
    def _pfc(x):
        # First feature column of ``x`` as a Python list, for easy comparison.
        return list(x.numpy()[:,0])
    g = generate_graph()
    # set all nodes
    g.set_n_repr({'h' : th.zeros((10, D))})
    assert _pfc(g.get_n_repr()['h']) == [0.] * 10
    # pop nodes
    assert _pfc(g.pop_n_repr('h')) == [0.] * 10
    assert len(g.get_n_repr()) == 0
    g.set_n_repr({'h' : th.zeros((10, D))})
    # set partial nodes
    u = th.tensor([1, 3, 5])
    g.set_n_repr({'h' : th.ones((3, D))}, u)
    assert _pfc(g.get_n_repr()['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
    # get partial nodes
    u = th.tensor([1, 2, 3])
    assert _pfc(g.get_n_repr(u)['h']) == [1., 0., 1.]

    # Expected edge ids of the graph built by generate_graph (insertion
    # order); the ``truth`` indices below refer to this table.
    #   s, d, eid
    #   0, 1, 0
    #   1, 9, 1
    #   0, 2, 2
    #   2, 9, 3
    #   0, 3, 4
    #   3, 9, 5
    #   0, 4, 6
    #   4, 9, 7
    #   0, 5, 8
    #   5, 9, 9
    #   0, 6, 10
    #   6, 9, 11
    #   0, 7, 12
    #   7, 9, 13
    #   0, 8, 14
    #   8, 9, 15
    #   9, 0, 16
    # set all edges
    g.set_e_repr({'l' : th.zeros((17, D))})
    assert _pfc(g.get_e_repr()['l']) == [0.] * 17
    # pop edges
    assert _pfc(g.pop_e_repr('l')) == [0.] * 17
    assert len(g.get_e_repr()) == 0
    g.set_e_repr({'l' : th.zeros((17, D))})
    # set partial edges (many-many)
    u = th.tensor([0, 0, 2, 5, 9])
    v = th.tensor([1, 3, 9, 9, 0])
    g.set_e_repr({'l' : th.ones((5, D))}, u, v)
    truth = [0.] * 17
    truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
    assert _pfc(g.get_e_repr()['l']) == truth
    # set partial edges (many-one)
    u = th.tensor([3, 4, 6])
    v = th.tensor([9])
    g.set_e_repr({'l' : th.ones((3, D))}, u, v)
    truth[5] = truth[7] = truth[11] = 1.
    assert _pfc(g.get_e_repr()['l']) == truth
    # set partial edges (one-many)
    u = th.tensor([0])
    v = th.tensor([4, 5, 6])
    g.set_e_repr({'l' : th.ones((3, D))}, u, v)
    truth[6] = truth[8] = truth[10] = 1.
    assert _pfc(g.get_e_repr()['l']) == truth
    # get partial edges (many-many)
    u = th.tensor([0, 6, 0])
    v = th.tensor([6, 9, 7])
    assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 0.]
    # get partial edges (many-one)
    u = th.tensor([5, 6, 7])
    v = th.tensor([9])
    assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 0.]
    # get partial edges (one-many)
    u = th.tensor([0])
    v = th.tensor([3, 4, 5])
    assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 1.]


def test_batch_setter_autograd():
    """Check that gradients flow correctly through a partial node-feature update.

    After rows 1, 2 and 8 are overwritten with ``hh``, ``hh`` receives the
    gradient for those rows.  The original features receive it only for the
    remaining rows.
    """
    g = generate_graph(grad=True)
    h1 = g.get_n_repr()['h']
    # partial set
    v = th.tensor([1, 2, 8])
    # ``Variable`` is a deprecated no-op wrapper; request gradients directly.
    hh = th.zeros((len(v), D), requires_grad=True)
    g.set_n_repr({'h' : hh}, v)
    h2 = g.get_n_repr()['h']
    h2.backward(th.ones((10, D)) * 2)
    check_eq(h1.grad[:,0], th.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
    check_eq(hh.grad[:,0], th.tensor([2., 2., 2.]))


def test_batch_send():
    """Exercise many-many, one-many and many-one ``send`` over 5 edges each."""
    g = generate_graph()
    def _fmsg(src, edge):
        # Every send below covers exactly 5 edges.
        assert src['h'].shape == (5, D)
        return {'m' : src['h']}
    g.register_message_func(_fmsg)
    # many-many send
    u = th.tensor([0, 0, 0, 0, 0])
    v = th.tensor([1, 2, 3, 4, 5])
    g.send(u, v)
    # one-many send
    u = th.tensor([0])
    v = th.tensor([1, 2, 3, 4, 5])
    g.send(u, v)
    # many-one send
    u = th.tensor([1, 2, 3, 4, 5])
    v = th.tensor([9])
    g.send(u, v)


def test_batch_recv():
    """Check the mailbox shapes seen by ``reduce_func`` after send + recv.

    Nodes 1-3 each receive one message and node 9 receives three.  The
    reducer is therefore expected to see a (3, 1, D) and a (1, 3, D) mailbox.
    """
    # basic recv test
    g = generate_graph()
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_node_func)
    u = th.tensor([0, 0, 0, 4, 5, 6])
    v = th.tensor([1, 2, 3, 9, 9, 9])
    reduce_msg_shapes.clear()
    g.send(u, v)
    g.recv(th.unique(v))
    assert reduce_msg_shapes == {(1, 3, D), (3, 1, D)}
    reduce_msg_shapes.clear()


def test_update_routines():
    """Check the mailbox shapes ``reduce_func`` sees for each update routine.

    Each routine runs with the registered message, reduce and apply functions.
    The shapes recorded in ``reduce_msg_shapes`` are then compared with the
    expected ones.
    """
    g = generate_graph()
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_node_func)

    # send_and_recv: nodes 1-3 get one message each, node 9 gets three
    reduce_msg_shapes.clear()
    u = th.tensor([0, 0, 0, 4, 5, 6])
    v = th.tensor([1, 2, 3, 9, 9, 9])
    g.send_and_recv(u, v)
    assert reduce_msg_shapes == {(1, 3, D), (3, 1, D)}
    reduce_msg_shapes.clear()

    # pull: nodes 1-3 have one in-edge each, node 9 has eight
    v = th.tensor([1, 2, 3, 9])
    reduce_msg_shapes.clear()
    g.pull(v)
    assert reduce_msg_shapes == {(1, 8, D), (3, 1, D)}
    reduce_msg_shapes.clear()

    # push: node 0 sends to nodes 1-8, nodes 1-3 send to node 9
    v = th.tensor([0, 1, 2, 3])
    reduce_msg_shapes.clear()
    g.push(v)
    assert reduce_msg_shapes == {(1, 3, D), (8, 1, D)}
    reduce_msg_shapes.clear()

    # update_all: nodes 0-8 have one in-edge each, node 9 has eight
    reduce_msg_shapes.clear()
    g.update_all()
    assert reduce_msg_shapes == {(1, 8, D), (9, 1, D)}
    reduce_msg_shapes.clear()


def test_reduce_0deg():
    """Check that ``update_all`` leaves zero in-degree nodes unchanged.

    Nodes 1-4 each send to node 0.  Node 0 becomes its old value plus the sum
    of the others' values.  Nodes 1-4 have no in-edges and must keep their
    old features.
    """
    g = DGLGraph()
    g.add_nodes(5)
    for src in range(1, 5):
        g.add_edge(src, 0)
    def _message(src, edge):
        return src
    def _reduce(node, msgs):
        assert msgs is not None
        return node + msgs.sum(1)
    old_repr = th.randn(5, 5)
    g.set_n_repr(old_repr)
    g.update_all(_message, _reduce)
    new_repr = g.get_n_repr()

    assert th.allclose(new_repr[1:], old_repr[1:])
    assert th.allclose(new_repr[0], old_repr.sum(0))

def test_pull_0deg():
    """Check that ``pull`` on a zero in-degree node leaves its feature unchanged.

    The graph has a single edge 0 -> 1.  Pulling node 0 changes nothing.
    Pulling node 1 replaces its feature with node 0's (the sum of its single
    message).
    """
    g = DGLGraph()
    g.add_nodes(2)
    g.add_edge(0, 1)
    def _message(src, edge):
        return src
    def _reduce(node, msgs):
        assert msgs is not None
        return msgs.sum(1)

    old_repr = th.randn(2, 5)
    g.set_n_repr(old_repr)
    g.pull(0, _message, _reduce)
    new_repr = g.get_n_repr()
    assert th.allclose(new_repr[0], old_repr[0])
    assert th.allclose(new_repr[1], old_repr[1])
    g.pull(1, _message, _reduce)
    new_repr = g.get_n_repr()
    assert th.allclose(new_repr[1], old_repr[0])

    # Pull a zero-degree and a non-zero-degree node in the same call.
    old_repr = th.randn(2, 5)
    g.set_n_repr(old_repr)
    g.pull([0, 1], _message, _reduce)
    new_repr = g.get_n_repr()
    assert th.allclose(new_repr[0], old_repr[0])
    assert th.allclose(new_repr[1], old_repr[0])


if __name__ == '__main__':
    # Run every test directly when executed as a script (pytest also collects them).
    test_batch_setter_getter()
    test_batch_setter_autograd()
    test_batch_send()
    test_batch_recv()
    test_update_routines()
    test_reduce_0deg()
    test_pull_0deg()