from itertools import product

import networkx as nx
import numpy as np
import pytest

import dgl
import dgl.function as fn

import backend as F
from test_utils import parametrize_idtype, get_cases

def udf_copy_src(edges):
    return {'m': edges.src['u']}

def udf_copy_edge(edges):
    return {'m': edges.data['e']}

def udf_mean(nodes):
    return {'r2': F.mean(nodes.mailbox['m'], 1)}

def udf_sum(nodes):
    return {'r2': F.sum(nodes.mailbox['m'], 1)}

def udf_max(nodes):
    return {'r2': F.max(nodes.mailbox['m'], 1)}


D1 = 5
D2 = 3
D3 = 4
D4 = 10  # NOTE(xiang): extra trailing dimension used for the dot-product features
builtin = {'sum': fn.sum, 'max': fn.max, 'mean': fn.mean}
udf_reduce = {'sum': udf_sum, 'max': udf_max, 'mean': udf_mean}
fill_value = {'sum': 0, 'max': float("-inf")}
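# Each reducer name maps to a builtin kernel (``builtin``) and an equivalent
# Python UDF (``udf_reduce``); the tests below run both and check that results
# and gradients agree.  ``fill_value`` records the identity element of the
# corresponding reduction.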


def generate_feature(g, broadcast='none', binary_op='none'):
    """Create src, dst, and edge features for graph ``g``.

    ``broadcast`` can be 'u', 'e', 'v', or 'none'.
    """
    np.random.seed(31)
    nv = g.number_of_nodes()
    ne = g.number_of_edges()
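    # The broadcast side gets a lower-rank shape (D2, 1), with a trailing D4
    # for 'dot', so the kernels must broadcast it against the full
    # (D1, D2, D3[, D4]) features on the other side.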
    if binary_op == 'dot':
        if broadcast == 'e':
            u = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3, D4)))
            e = F.tensor(np.random.uniform(-1, 1, (ne, D2, 1, D4)))
            v = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3, D4)))
        elif broadcast == 'u':
            u = F.tensor(np.random.uniform(-1, 1, (nv, D2, 1, D4)))
            e = F.tensor(np.random.uniform(-1, 1, (ne, D1, D2, D3, D4)))
            v = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3, D4)))
        elif broadcast == 'v':
            u = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3, D4)))
            e = F.tensor(np.random.uniform(-1, 1, (ne, D1, D2, D3, D4)))
            v = F.tensor(np.random.uniform(-1, 1, (nv, D2, 1, D4)))
        else:
            u = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3, D4)))
            e = F.tensor(np.random.uniform(-1, 1, (ne, D1, D2, D3, D4)))
            v = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3, D4)))
    else:
        if broadcast == 'e':
            u = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3)))
            e = F.tensor(np.random.uniform(-1, 1, (ne, D2, 1)))
            v = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3)))
        elif broadcast == 'u':
            u = F.tensor(np.random.uniform(-1, 1, (nv, D2, 1)))
            e = F.tensor(np.random.uniform(-1, 1, (ne, D1, D2, D3)))
            v = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3)))
        elif broadcast == 'v':
            u = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3)))
            e = F.tensor(np.random.uniform(-1, 1, (ne, D1, D2, D3)))
            v = F.tensor(np.random.uniform(-1, 1, (nv, D2, 1)))
        else:
            u = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3)))
            e = F.tensor(np.random.uniform(-1, 1, (ne, D1, D2, D3)))
            v = F.tensor(np.random.uniform(-1, 1, (nv, D1, D2, D3)))
    return F.astype(u, F.float32), F.astype(v, F.float32), F.astype(e, F.float32)


def test_copy_src_reduce():
    def _test(red, partial):
        g = dgl.DGLGraph(nx.erdos_renyi_graph(100, 0.1))
        # NOTE(zihao): add self-loop to avoid zero-degree nodes.
        # https://github.com/dmlc/dgl/issues/761
        g.add_edges(g.nodes(), g.nodes())
        g = g.to(F.ctx())
        hu, hv, he = generate_feature(g, 'none', 'none')
        if partial:
            nid = F.tensor(list(range(0, 100, 2)), g.idtype)

        g.ndata['u'] = F.attach_grad(F.clone(hu))
        g.ndata['v'] = F.attach_grad(F.clone(hv))
        g.edata['e'] = F.attach_grad(F.clone(he))

        with F.record_grad():
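            # Builtin kernel path: copy_u as message, builtin reducer for aggregation.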
            if partial:
                g.pull(nid, fn.copy_u(u='u', out='m'),
                       builtin[red](msg='m', out='r1'))
            else:
                g.update_all(fn.copy_u(u='u', out='m'),
                             builtin[red](msg='m', out='r1'))
            r1 = g.ndata['r1']
            F.backward(F.reduce_sum(r1))
            n_grad1 = F.grad(g.ndata['u'])

        # reset grad
        g.ndata['u'] = F.attach_grad(F.clone(hu))
        g.ndata['v'] = F.attach_grad(F.clone(hv))
        g.edata['e'] = F.attach_grad(F.clone(he))

        with F.record_grad():
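            # UDF path: same computation with Python message/reduce functions;
            # results and gradients must match the builtin path above.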
            if partial:
                g.pull(nid, udf_copy_src, udf_reduce[red])
            else:
                g.update_all(udf_copy_src, udf_reduce[red])
            r2 = g.ndata['r2']
            F.backward(F.reduce_sum(r2))
            n_grad2 = F.grad(g.ndata['u'])

        def _print_error(a, b):
            print("ERROR: Test copy_src_{} partial: {}".
                  format(red, partial))
            for i, (x, y) in enumerate(zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())):
                if not np.allclose(x, y):
                    print('@{} {} v.s. {}'.format(i, x, y))

        if not F.allclose(r1, r2):
            _print_error(r1, r2)
        assert F.allclose(r1, r2)
        if not F.allclose(n_grad1, n_grad2):
            print('node grad')
            _print_error(n_grad1, n_grad2)
        assert F.allclose(n_grad1, n_grad2)

    _test('sum', False)
    _test('max', False)
    _test('mean', False)
    _test('sum', True)
    _test('max', True)
    _test('mean', True)


def test_copy_edge_reduce():
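    # Mirrors test_copy_src_reduce, but messages are copied from edge features
    # and gradients are checked on g.edata['e'].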
    def _test(red, partial):
        g = dgl.DGLGraph(nx.erdos_renyi_graph(100, 0.1))
        # NOTE(zihao): add self-loop to avoid zero-degree nodes.
        g.add_edges(g.nodes(), g.nodes())
        g = g.to(F.ctx())
        hu, hv, he = generate_feature(g, 'none', 'none')
        if partial:
            nid = F.tensor(list(range(0, 100, 2)), g.idtype)

        g.ndata['u'] = F.attach_grad(F.clone(hu))
        g.ndata['v'] = F.attach_grad(F.clone(hv))
        g.edata['e'] = F.attach_grad(F.clone(he))

        with F.record_grad():
            if partial:
                g.pull(nid, fn.copy_e(e='e', out='m'),
                       builtin[red](msg='m', out='r1'))
            else:
                g.update_all(fn.copy_e(e='e', out='m'),
                             builtin[red](msg='m', out='r1'))
            r1 = g.ndata['r1']
            F.backward(F.reduce_sum(r1))
            e_grad1 = F.grad(g.edata['e'])

        # reset grad
        g.ndata['u'] = F.attach_grad(F.clone(hu))
        g.ndata['v'] = F.attach_grad(F.clone(hv))
        g.edata['e'] = F.attach_grad(F.clone(he))

        with F.record_grad():
            if partial:
                g.pull(nid, udf_copy_edge, udf_reduce[red])
            else:
                g.update_all(udf_copy_edge, udf_reduce[red])
            r2 = g.ndata['r2']
            F.backward(F.reduce_sum(r2))
            e_grad2 = F.grad(g.edata['e'])

        def _print_error(a, b):
            print("ERROR: Test copy_edge_{} partial: {}".
                  format(red, partial))
            return  # skip the element-wise diff below; remove this line to enable it
            for i, (x, y) in enumerate(zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())):
                if not np.allclose(x, y):
                    print('@{} {} v.s. {}'.format(i, x, y))

        if not F.allclose(r1, r2):
            _print_error(r1, r2)
        assert F.allclose(r1, r2)
        if not F.allclose(e_grad1, e_grad2):
            print('edge gradient')
            _print_error(e_grad1, e_grad2)
        assert F.allclose(e_grad1, e_grad2)

    _test('sum', False)
    _test('max', False)
    _test('mean', False)
    _test('sum', True)
    _test('max', True)
    _test('mean', True)


def test_all_binary_builtins():
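    # Cross-check every lhs_op_rhs builtin message function (u_add_v, u_mul_e,
    # e_div_u, ...) against a UDF reference, over all reducers, broadcasting
    # modes, and partial (pull) vs. full (update_all) updates.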
    def _test(g, lhs, rhs, binary_op, reducer, partial, nid, broadcast='none'):
        # initialize node/edge features with uniform(-1, 1)
        hu, hv, he = generate_feature(g, broadcast, binary_op)
        if binary_op == 'div':
            # op = div
            # lhs range: [-1, 1]
            # rhs range: [1, 2]
            # result range: [-1, 1]
            if rhs == 'u':
                hu = (hu + 3) / 2
            elif rhs == 'v':
                hv = (hv + 3) / 2
            elif rhs == 'e':
                he = (he + 3) / 2

        if binary_op == 'add' or binary_op == 'sub':
            # op = add, sub
            # lhs range: [-1/2, 1/2]
            # rhs range: [-1/2, 1/2]
            # result range: [-1, 1]
            hu = hu / 2
            hv = hv / 2
            he = he / 2

        g.ndata['u'] = F.attach_grad(F.clone(hu))
        g.ndata['v'] = F.attach_grad(F.clone(hv))
        g.edata['e'] = F.attach_grad(F.clone(he))

        builtin_msg_name = "{}_{}_{}".format(lhs, binary_op, rhs)
        builtin_msg = getattr(fn, builtin_msg_name)
        builtin_red = getattr(fn, reducer)

        def target_feature_switch(g, target):
            if target == "u":
                return g.ndata["u"]
            elif target == "v":
                return g.ndata["v"]
            else:
                return g.edata["e"]

        with F.record_grad():
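            # Builtin kernel path: fn.<lhs>_<op>_<rhs> message + builtin reducer.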
            if partial:
                g.pull(nid, builtin_msg(lhs, rhs, 'm'), builtin_red('m', 'r1'))
            else:
                g.update_all(builtin_msg(lhs, rhs, 'm'), builtin_red('m', 'r1'))
            r1 = g.ndata.pop('r1')
            F.backward(F.reduce_sum(r1))
            lhs_grad_1 = F.grad(target_feature_switch(g, lhs))
            rhs_grad_1 = F.grad(target_feature_switch(g, rhs))

        # reset grad
        g.ndata['u'] = F.attach_grad(F.clone(hu))
        g.ndata['v'] = F.attach_grad(F.clone(hv))
        g.edata['e'] = F.attach_grad(F.clone(he))

        def target_switch(edges, target):
            if target == "u":
                return edges.src
            elif target == "v":
                return edges.dst
            elif target == "e":
                return edges.data
            else:
                assert False, "Unknown target {}".format(target)

        def mfunc(edges):
            op = getattr(F, binary_op)
            lhs_data = target_switch(edges, lhs)[lhs]
            rhs_data = target_switch(edges, rhs)[rhs]
            # NOTE(zihao): we need to do batched broadcast
            # e.g. (68, 3, 1) op (68, 5, 3, 4)
            while F.ndim(lhs_data) < F.ndim(rhs_data):
                lhs_data = F.unsqueeze(lhs_data, 1)
            while F.ndim(rhs_data) < F.ndim(lhs_data):
                rhs_data = F.unsqueeze(rhs_data, 1)
            return {"m": op(lhs_data, rhs_data)}

        def rfunc(nodes):
            op = getattr(F, reducer)
            return {"r2": op(nodes.mailbox['m'], 1)}

        with F.record_grad():
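            # UDF reference path: explicit broadcasting in mfunc plus a backend reduce op.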
            if partial:
                g.pull(nid, mfunc, rfunc)
            else:
                g.update_all(mfunc, rfunc)
            r2 = g.ndata.pop('r2')
            F.backward(F.reduce_sum(r2), F.tensor([1.]))
            lhs_grad_2 = F.grad(target_feature_switch(g, lhs))
            rhs_grad_2 = F.grad(target_feature_switch(g, rhs))

        rtol = 1e-4
        atol = 1e-4

        def _print_error(a, b):
            print("ERROR: Test {}_{}_{}_{} broadcast: {} partial: {}".
                  format(lhs, binary_op, rhs, reducer, broadcast, partial))
            return  # skip the verbose tensor dump and element-wise diff below; remove this line to enable it
            if lhs == 'u':
                lhs_data = hu
            elif lhs == 'v':
                lhs_data = hv
            elif lhs == 'e':
                lhs_data = he

            if rhs == 'u':
                rhs_data = hu
            elif rhs == 'v':
                rhs_data = hv
            elif rhs == 'e':
                rhs_data = he
            print("lhs", F.asnumpy(lhs_data).tolist())
            print("rhs", F.asnumpy(rhs_data).tolist())
            for i, (x, y) in enumerate(zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())):
                if not np.allclose(x, y, rtol, atol):
                    print('@{} {} v.s. {}'.format(i, x, y))

        if not F.allclose(r1, r2, rtol, atol):
            _print_error(r1, r2)
        assert F.allclose(r1, r2, rtol, atol)

        if not F.allclose(lhs_grad_1, lhs_grad_2, rtol, atol):
            print("left grad")
            _print_error(lhs_grad_1, lhs_grad_2)
        assert F.allclose(lhs_grad_1, lhs_grad_2, rtol, atol)

        if not F.allclose(rhs_grad_1, rhs_grad_2, rtol, atol):
            print("right grad")
            _print_error(rhs_grad_1, rhs_grad_2)
        assert F.allclose(rhs_grad_1, rhs_grad_2, rtol, atol)

    g = dgl.DGLGraph()
    g.add_nodes(20)
    # NOTE(zihao): add self-loop to avoid zero-degree nodes.
    g.add_edges(g.nodes(), g.nodes())
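    # Nodes 0 and 1 fan out to 2..17, which in turn feed 18 and 19;
    # 18 and 19 connect back to 0 and 1.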
    for i in range(2, 18):
        g.add_edges(0, i)
        g.add_edges(1, i)
        g.add_edges(i, 18)
        g.add_edges(i, 19)
    g.add_edges(18, 0)
    g.add_edges(18, 1)
    g.add_edges(19, 0)
    g.add_edges(19, 1)
    g = g.to(F.ctx())
    nid = F.tensor([0, 1, 4, 5, 7, 12, 14, 15, 18, 19], g.idtype)
    target = ["u", "v", "e"]

    for lhs, rhs in product(target, target):
        if lhs == rhs:
            continue
        for binary_op in ["add", "sub", "mul", "div"]:
            for reducer in ["sum", "max", "min", "mean"]:
                for broadcast in ["none", lhs, rhs]:
                    for partial in [False, True]:
                        print(lhs, rhs, binary_op, reducer, broadcast, partial)
                        _test(g, lhs, rhs, binary_op, reducer, partial, nid,
                              broadcast=broadcast)

@parametrize_idtype
@pytest.mark.parametrize('g', get_cases(['homo-zero-degree']))
def test_mean_zero_degree(g, idtype):
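    """Mean-reducing onto zero-in-degree nodes should produce all-zero outputs."""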
    g = g.astype(idtype).to(F.ctx())
    g.ndata['h'] = F.ones((g.number_of_nodes(), 3))
    g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'x'))
    deg = F.asnumpy(g.in_degrees())
    v = F.tensor(np.where(deg == 0)[0])
    assert F.allclose(F.gather_row(g.ndata['x'], v), F.zeros((len(v), 3)))

if __name__ == '__main__':
    test_copy_src_reduce()
    test_copy_edge_reduce()
    test_all_binary_builtins()