# test_kernel.py
1
2
3
import dgl
import dgl.function as fn
import networkx as nx
4
import numpy as np
5
6
7
import backend as F
from itertools import product

8
# Fix NumPy's global RNG so the generated features are reproducible.
np.random.seed(seed=42)


def udf_copy_src(edges):
    """Message UDF: send the source-node feature 'u' as message 'm'."""
    src_feat = edges.src['u']
    return dict(m=src_feat)


def udf_copy_edge(edges):
    """Message UDF: send the edge feature 'e' as message 'm'."""
    edge_feat = edges.data['e']
    return dict(m=edge_feat)


def udf_mean(nodes):
    """Reduce UDF: average the mailbox messages 'm' along axis 1 into 'r2'."""
    msgs = nodes.mailbox['m']
    return dict(r2=msgs.mean(1))


def udf_sum(nodes):
    """Reduce UDF: sum the mailbox messages 'm' along axis 1 into 'r2'."""
    msgs = nodes.mailbox['m']
    return dict(r2=msgs.sum(1))


def udf_max(nodes):
    """Reduce UDF: take the backend max of mailbox messages 'm' over axis 1."""
    msgs = nodes.mailbox['m']
    return dict(r2=F.max(msgs, 1))


# Feature dimensions used by generate_feature.
D1, D2, D3 = 5, 3, 4

# Reducer name -> DGL builtin reduce function.
builtin = dict(sum=fn.sum, max=fn.max, mean=fn.mean)
# Reducer name -> hand-written UDF reducer expected to match the builtin.
udf_reduce = dict(sum=udf_sum, max=udf_max, mean=udf_mean)
# NOTE(review): not referenced in the visible file and has no 'mean' entry.
fill_value = {'sum': 0, 'max': float("-inf")}


def generate_feature(g, broadcast='none'):
    """Create random src-node (u), dst-node (v) and edge (e) features.

    Parameters
    ----------
    g : graph
        Graph whose ``number_of_nodes()`` / ``number_of_edges()`` give the
        leading (batch) dimension of the node / edge features.
    broadcast : str
        Which feature gets the small broadcastable shape ``(D2, 1)``:
        ``'u'``, ``'e'`` or ``'v'``. Any other value (e.g. ``'none'``) gives
        every feature the full shape ``(D1, D2, D3)``.

    Returns
    -------
    tuple
        ``(u, v, e)`` backend tensors. ``u`` is drawn from ``randn + 1`` and
        ``e`` from ``randn - 1``; ``v`` is plain ``randn``.
    """
    nv = g.number_of_nodes()
    ne = g.number_of_edges()
    full_shape = (D1, D2, D3)
    bcast_shape = (D2, 1)

    def _feat_shape(name):
        return bcast_shape if broadcast == name else full_shape

    # Keep the draw order u, e, v so a fixed NumPy seed reproduces the same
    # values the original branch-per-case layout produced.
    u = F.tensor(np.random.randn(nv, *_feat_shape('u')) + 1)
    e = F.tensor(np.random.randn(ne, *_feat_shape('e')) - 1)
    v = F.tensor(np.random.randn(nv, *_feat_shape('v')))
    return u, v, e


def test_copy_src_reduce():
    """Check builtin ``copy_src`` + reduce against the equivalent UDFs.

    For each reducer in ``('sum', 'max', 'mean')``, run both a full
    ``update_all`` and a partial ``pull`` on every other node. The reduced
    node feature and the gradient w.r.t. the source feature ``u`` must agree
    between the builtin and the UDF path.
    """
    def _test(red, partial):
        g = dgl.DGLGraph(nx.erdos_renyi_graph(100, 0.1))
        hu, hv, he = generate_feature(g, 'none')
        # Destination nodes for the partial pull (unused when partial=False).
        nid = F.tensor(list(range(0, 100, 2)))

        def _reset_features():
            # Fresh gradient-tracked copies so each pass has its own grads.
            g.ndata['u'] = F.attach_grad(F.clone(hu))
            g.ndata['v'] = F.attach_grad(F.clone(hv))
            g.edata['e'] = F.attach_grad(F.clone(he))

        def _run(msg_func, reduce_func):
            if partial:
                g.pull(nid, msg_func, reduce_func)
            else:
                g.update_all(msg_func, reduce_func)

        _reset_features()
        with F.record_grad():
            _run(fn.copy_src(src='u', out='m'),
                 builtin[red](msg='m', out='r1'))
            r1 = g.ndata['r1']
            F.backward(r1.sum())
            n_grad1 = F.grad(g.ndata['u'])

        _reset_features()
        with F.record_grad():
            _run(udf_copy_src, udf_reduce[red])
            r2 = g.ndata['r2']
            F.backward(r2.sum())
            n_grad2 = F.grad(g.ndata['u'])

        assert F.allclose(r1, r2)
        assert F.allclose(n_grad1, n_grad2)

    # Same order as before: all reducers without partial, then with partial.
    for partial, red in product((False, True), ('sum', 'max', 'mean')):
        _test(red, partial)


def test_copy_edge_reduce():
    """Check builtin ``copy_edge`` + reduce against the equivalent UDFs.

    For each reducer in ``('sum', 'max', 'mean')``, run both a full
    ``update_all`` and a partial ``pull`` on every other node. The reduced
    node feature and the gradient w.r.t. the edge feature ``e`` must agree
    between the builtin and the UDF path.
    """
    def _test(red, partial):
        g = dgl.DGLGraph(nx.erdos_renyi_graph(100, 0.1))
        hu, hv, he = generate_feature(g, 'none')
        # Destination nodes for the partial pull (unused when partial=False).
        nid = F.tensor(list(range(0, 100, 2)))

        def _reset_features():
            # Fresh gradient-tracked copies so each pass has its own grads.
            g.ndata['u'] = F.attach_grad(F.clone(hu))
            g.ndata['v'] = F.attach_grad(F.clone(hv))
            g.edata['e'] = F.attach_grad(F.clone(he))

        def _run(msg_func, reduce_func):
            if partial:
                g.pull(nid, msg_func, reduce_func)
            else:
                g.update_all(msg_func, reduce_func)

        _reset_features()
        with F.record_grad():
            _run(fn.copy_edge(edge='e', out='m'),
                 builtin[red](msg='m', out='r1'))
            r1 = g.ndata['r1']
            F.backward(r1.sum())
            e_grad1 = F.grad(g.edata['e'])

        _reset_features()
        with F.record_grad():
            _run(udf_copy_edge, udf_reduce[red])
            r2 = g.ndata['r2']
            F.backward(r2.sum())
            e_grad2 = F.grad(g.edata['e'])

        assert F.allclose(r1, r2)
        assert F.allclose(e_grad1, e_grad2)

    # Same order as before: all reducers without partial, then with partial.
    for partial, red in product((False, True), ('sum', 'max', 'mean')):
        _test(red, partial)


def test_all_binary_builtins():
    """Check builtin ``<lhs>_<op>_<rhs>`` messages + reducers against UDFs.

    Covers every ordered pair of distinct targets among ``u`` (source node),
    ``v`` (destination node) and ``e`` (edge); the ops add/sub/mul/div; the
    reducers sum/max/min/prod/mean; no broadcasting or broadcasting of either
    operand; and both a full ``update_all`` and a partial ``pull``. The
    reduced result and the gradients w.r.t. both operands must match.
    """
    def _test(g, lhs, rhs, binary_op, reducer, partial, nid, broadcast='none'):
        hu, hv, he = generate_feature(g, broadcast)

        def _reset_features():
            # Fresh gradient-tracked copies so each pass has its own grads.
            g.ndata['u'] = F.attach_grad(F.clone(hu))
            g.ndata['v'] = F.attach_grad(F.clone(hv))
            g.edata['e'] = F.attach_grad(F.clone(he))

        def _run(msg_func, reduce_func):
            if partial:
                g.pull(nid, msg_func, reduce_func)
            else:
                g.update_all(msg_func, reduce_func)

        def target_feature_switch(g, target):
            if target == "u":
                return g.ndata["u"]
            elif target == "v":
                return g.ndata["v"]
            else:
                return g.edata["e"]

        builtin_msg_name = "{}_{}_{}".format(lhs, binary_op, rhs)
        builtin_msg = getattr(fn, builtin_msg_name)
        builtin_red = getattr(fn, reducer)

        # Pass 1: builtin message + builtin reducer.
        _reset_features()
        with F.record_grad():
            _run(builtin_msg(lhs, rhs, 'm'), builtin_red('m', 'r1'))
            r1 = g.ndata.pop('r1')
            F.backward(r1.sum())
            lhs_grad_1 = F.grad(target_feature_switch(g, lhs))
            rhs_grad_1 = F.grad(target_feature_switch(g, rhs))

        def target_switch(edges, target):
            if target == "u":
                return edges.src
            elif target == "v":
                return edges.dst
            elif target == "e":
                return edges.data
            raise ValueError("Unknown target {}".format(target))

        def _batch_broadcast(x, ndim):
            # Pad size-1 axes right after the batch axis until x has ndim
            # dims, e.g. (E, D2, 1) -> (E, 1, D2, 1). This makes the UDF
            # broadcast over feature dims only, which is the semantics assumed
            # for the builtin kernels. Plain tensor ops align shapes from the
            # right, including the batch axis.
            shape = tuple(x.shape)
            missing = ndim - len(shape)
            if missing <= 0:
                return x
            return x.reshape(shape[:1] + (1,) * missing + shape[1:])

        def mfunc(edges):
            op = getattr(F, binary_op)
            lhs_data = target_switch(edges, lhs)[lhs]
            rhs_data = target_switch(edges, rhs)[rhs]
            ndim = max(len(lhs_data.shape), len(rhs_data.shape))
            return {"m": op(_batch_broadcast(lhs_data, ndim),
                            _batch_broadcast(rhs_data, ndim))}

        def rfunc(nodes):
            op = getattr(F, reducer)
            return {"r2": op(nodes.mailbox['m'], 1)}

        # Pass 2: equivalent UDF message + UDF reducer.
        _reset_features()
        with F.record_grad():
            _run(mfunc, rfunc)
            r2 = g.ndata.pop('r2')
            F.backward(r2.sum(), F.tensor([1.]))
            lhs_grad_2 = F.grad(target_feature_switch(g, lhs))
            rhs_grad_2 = F.grad(target_feature_switch(g, rhs))

        # Products of many random factors accumulate more rounding error.
        if reducer == 'prod':
            rtol = 1e-2
            atol = 1e-2
        else:
            rtol = 1e-4
            atol = 1e-4

        def _print_error(a, b):
            print("ERROR: Test {}_{}_{}_{} broadcast={} partial={}".
                  format(lhs, binary_op, rhs, reducer, broadcast, partial))
            print(a, b)
            for i, (x, y) in enumerate(zip(F.asnumpy(F.cpu(a)).flatten(), F.asnumpy(F.cpu(b)).flatten())):
                if not np.allclose(x, y, rtol, atol):
                    print('@{} {} v.s. {}'.format(i, x, y))

        if not F.allclose(r1, r2, rtol, atol):
            _print_error(r1, r2)
        assert F.allclose(r1, r2, rtol, atol)

        if not F.allclose(lhs_grad_1, lhs_grad_2, rtol, atol):
            print("left grad")
            _print_error(lhs_grad_1, lhs_grad_2)
        assert F.allclose(lhs_grad_1, lhs_grad_2, rtol, atol)

        if not F.allclose(rhs_grad_1, rhs_grad_2, rtol, atol):
            print("right grad")
            _print_error(rhs_grad_1, rhs_grad_2)
        assert F.allclose(rhs_grad_1, rhs_grad_2, rtol, atol)

    # Nodes 0/1 feed 2..17, which feed 18/19, which feed back into 0/1.
    g = dgl.DGLGraph()
    g.add_nodes(20)
    for i in range(2, 18):
        g.add_edge(0, i)
        g.add_edge(1, i)
        g.add_edge(i, 18)
        g.add_edge(i, 19)
    g.add_edge(18, 0)
    g.add_edge(18, 1)
    g.add_edge(19, 0)
    g.add_edge(19, 1)
    nid = F.tensor([1, 3, 4, 5, 7, 10, 13, 17, 19])

    target = ["u", "v", "e"]
    for lhs, rhs in product(target, target):
        if lhs == rhs:
            continue
        for binary_op, reducer, broadcast, partial in product(
                ["add", "sub", "mul", "div"],
                ["sum", "max", "min", "prod", "mean"],
                ["none", lhs, rhs],
                [False, True]):
            # broadcast must be forwarded, otherwise every case silently
            # runs with broadcast='none'.
            _test(g, lhs, rhs, binary_op, reducer, partial, nid,
                  broadcast=broadcast)


if __name__ == '__main__':
    # Run every test in this module when executed as a script.
    for _run_test in (test_copy_src_reduce,
                      test_copy_edge_reduce,
                      test_all_binary_builtins):
        _run_test()