# test_specialization.py
import numpy as np
import scipy.sparse as sp
3
4
import dgl
import dgl.function as fn
5
import backend as F
6
from test_utils import parametrize_dtype
7

# Feature width of the 2-D node/edge tensors created by the tests below.
D = 5


def generate_graph(idtype):
    """Build the 10-node test graph shared by most tests in this file.

    Node 0 fans out to nodes 1..8, each of which feeds node 9, and a single
    back edge 9 -> 0 closes the loop, so every node has at least one in-edge.

    Node features: 'f1' of shape (10,) and 'f2' of shape (10, D).
    Edge features: 'e1' of shape (E,) and 'e2' of shape (E, 1), holding the
    same random weights; E is the number of edges added below (17).

    Args:
        idtype: ID dtype the graph is converted to via ``astype``.

    Returns:
        The populated graph, moved to ``F.ctx()``.
    """
    g = dgl.DGLGraph()
    g = g.astype(idtype).to(F.ctx())
    num_nodes = 10
    g.add_nodes(num_nodes)
    # create a graph where 0 is the source and 9 is the sink
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    # add a back flow from 9 to 0
    g.add_edge(9, 0)
    g.ndata.update({'f1' : F.randn((num_nodes,)), 'f2' : F.randn((num_nodes, D))})
    # Size the weights from the graph itself rather than a hard-coded 17, so
    # they stay in sync if the topology above is edited.
    weights = F.randn((g.number_of_edges(),))
    g.edata.update({'e1': weights, 'e2': F.unsqueeze(weights, 1)})
    return g


@parametrize_dtype
def test_v2v_update_all(idtype):
    """Builtin update_all must match the equivalent hand-written UDF pipeline.

    Both the 1-D ('f1') and 2-D ('f2') node features are checked, first with
    a plain copy message and then with edge-weighted messages.
    """
    def _check_field(feat):
        def msg_copy(edges):
            return {'m' : edges.src[feat]}

        def msg_weighted(edges):
            src = edges.src[feat]
            # 1-D features pair with the flat 'e1' weights; higher-rank
            # features use the column-shaped 'e2' so the product broadcasts.
            weight_key = 'e1' if len(src.shape) == 1 else 'e2'
            return {'m' : src * edges.data[weight_key]}

        def reduce_sum(nodes):
            return {feat : F.sum(nodes.mailbox['m'], 1)}

        def apply_double(nodes):
            return {feat : 2 * nodes.data[feat]}

        graph = generate_graph(idtype)

        # update all: builtin copy vs. UDF copy
        saved = graph.ndata[feat]
        graph.update_all(fn.copy_src(src=feat, out='m'), fn.sum(msg='m', out=feat), apply_double)
        builtin_out = graph.ndata[feat]
        graph.ndata.update({feat : saved})
        graph.update_all(msg_copy, reduce_sum, apply_double)
        assert F.allclose(builtin_out, graph.ndata[feat])

        # update all with edge weights: builtin src_mul_edge vs. UDF
        saved = graph.ndata[feat]
        graph.update_all(fn.src_mul_edge(src=feat, edge='e1', out='m'),
                         fn.sum(msg='m', out=feat), apply_double)
        builtin_out = graph.ndata[feat]
        graph.ndata.update({feat : saved})
        graph.update_all(msg_weighted, reduce_sum, apply_double)
        assert F.allclose(builtin_out, graph.ndata[feat])

    # 1-D node features first, then 2-D node features
    for field in ('f1', 'f2'):
        _check_field(field)


@parametrize_dtype
def test_v2v_snr(idtype):
    """Builtin send_and_recv must match the equivalent UDF pipeline.

    Messages travel only along the six listed (src, dst) edges; both the 1-D
    ('f1') and 2-D ('f2') node features are exercised.
    """
    src_ids = F.tensor([0, 0, 0, 3, 4, 9], idtype)
    dst_ids = F.tensor([1, 2, 3, 9, 9, 0], idtype)

    def _check_field(feat):
        def msg_copy(edges):
            return {'m' : edges.src[feat]}

        def msg_weighted(edges):
            src = edges.src[feat]
            # 1-D features pair with the flat 'e1' weights; higher-rank
            # features use the column-shaped 'e2' so the product broadcasts.
            weight_key = 'e1' if len(src.shape) == 1 else 'e2'
            return {'m' : src * edges.data[weight_key]}

        def reduce_sum(nodes):
            return {feat : F.sum(nodes.mailbox['m'], 1)}

        def apply_double(nodes):
            return {feat : 2 * nodes.data[feat]}

        graph = generate_graph(idtype)
        edge_pairs = (src_ids, dst_ids)

        # send and recv: builtin copy vs. UDF copy
        saved = graph.ndata[feat]
        graph.send_and_recv(edge_pairs, fn.copy_src(src=feat, out='m'),
                            fn.sum(msg='m', out=feat), apply_double)
        builtin_out = graph.ndata[feat]
        graph.ndata.update({feat : saved})
        graph.send_and_recv(edge_pairs, msg_copy, reduce_sum, apply_double)
        assert F.allclose(builtin_out, graph.ndata[feat])

        # send and recv with edge weights: builtin src_mul_edge vs. UDF
        saved = graph.ndata[feat]
        graph.send_and_recv(edge_pairs, fn.src_mul_edge(src=feat, edge='e1', out='m'),
                            fn.sum(msg='m', out=feat), apply_double)
        builtin_out = graph.ndata[feat]
        graph.ndata.update({feat : saved})
        graph.send_and_recv(edge_pairs, msg_weighted, reduce_sum, apply_double)
        assert F.allclose(builtin_out, graph.ndata[feat])

    # 1-D node features first, then 2-D node features
    for field in ('f1', 'f2'):
        _check_field(field)


@parametrize_dtype
def test_v2v_pull(idtype):
    """Builtin pull must match the equivalent UDF pipeline.

    Only nodes 1, 2, 3 and 9 pull messages; both the 1-D ('f1') and 2-D
    ('f2') node features are exercised.
    """
    dst_nodes = F.tensor([1, 2, 3, 9], idtype)

    def _check_field(feat):
        def msg_copy(edges):
            return {'m' : edges.src[feat]}

        def msg_weighted(edges):
            src = edges.src[feat]
            # 1-D features pair with the flat 'e1' weights; higher-rank
            # features use the column-shaped 'e2' so the product broadcasts.
            weight_key = 'e1' if len(src.shape) == 1 else 'e2'
            return {'m' : src * edges.data[weight_key]}

        def reduce_sum(nodes):
            return {feat : F.sum(nodes.mailbox['m'], 1)}

        def apply_double(nodes):
            return {feat : 2 * nodes.data[feat]}

        graph = generate_graph(idtype)

        # pull: builtin copy vs. UDF copy
        saved = graph.ndata[feat]
        graph.pull(dst_nodes, fn.copy_src(src=feat, out='m'), fn.sum(msg='m', out=feat), apply_double)
        builtin_out = graph.ndata[feat]
        graph.ndata[feat] = saved
        graph.pull(dst_nodes, msg_copy, reduce_sum, apply_double)
        assert F.allclose(builtin_out, graph.ndata[feat])

        # pull with edge weights: builtin src_mul_edge vs. UDF
        saved = graph.ndata[feat]
        graph.pull(dst_nodes, fn.src_mul_edge(src=feat, edge='e1', out='m'),
                   fn.sum(msg='m', out=feat), apply_double)
        builtin_out = graph.ndata[feat]
        graph.ndata[feat] = saved
        graph.pull(dst_nodes, msg_weighted, reduce_sum, apply_double)
        assert F.allclose(builtin_out, graph.ndata[feat])

    # 1-D node features first, then 2-D node features
    for field in ('f1', 'f2'):
        _check_field(field)


@parametrize_dtype
def test_v2v_update_all_multi_fn(idtype):
    """update_all with lists of builtin message/reduce functions must agree
    with the single-UDF reference computation on the 2-D feature 'f2'.
    """
    def message_func(edges):
        return {'m2': edges.src['f2']}

    def message_func_edge(edges):
        return {'m2': edges.src['f2'] * edges.data['e2']}

    def reduce_func(nodes):
        return {'v1': F.sum(nodes.mailbox['m2'], 1)}

    g = generate_graph(idtype)
    # Pre-create every output field with the (10, D) shape that reducing
    # 'f2' produces, matching test_v2v_snr_multi_fn.
    g.ndata.update({'v1' : F.zeros((10, D)), 'v2' : F.zeros((10, D)),
                    'v3' : F.zeros((10, D))})
    fld = 'f2'

    # UDF reference: copy 'f2', sum into 'v1'
    g.update_all(message_func, reduce_func)
    v1 = g.ndata['v1']

    # 1 message, 2 reduces
    g.update_all(fn.copy_src(src=fld, out='m'), [fn.sum(msg='m', out='v2'), fn.sum(msg='m', out='v3')])
    v2 = g.ndata['v2']
    v3 = g.ndata['v3']
    assert F.allclose(v1, v2)
    assert F.allclose(v1, v3)

    # update all with edge weights, 2 message, 3 reduces
    g.update_all([fn.src_mul_edge(src=fld, edge='e1', out='m1'), fn.src_mul_edge(src=fld, edge='e2', out='m2')],
                 [fn.sum(msg='m1', out='v1'), fn.sum(msg='m2', out='v2'), fn.sum(msg='m1', out='v3')],
                 None)
    v1 = g.ndata['v1']
    v2 = g.ndata['v2']
    v3 = g.ndata['v3']
    assert F.allclose(v1, v2)
    assert F.allclose(v1, v3)

    # run UDF with single message and reduce.  reduce_func writes 'v1', so the
    # UDF result must be read back from 'v1'; 'v2' still holds the builtin
    # result from the step above and would make this check vacuous.
    g.update_all(message_func_edge, reduce_func, None)
    v_udf = g.ndata['v1']
    assert F.allclose(v1, v_udf)


@parametrize_dtype
def test_v2v_snr_multi_fn(idtype):
    """send_and_recv with lists of builtin message/reduce functions must agree
    with the single-UDF reference computation on the 2-D feature 'f2'.
    """
    u = F.tensor([0, 0, 0, 3, 4, 9], idtype)
    v = F.tensor([1, 2, 3, 9, 9, 0], idtype)

    def message_func(edges):
        return {'m2': edges.src['f2']}

    def message_func_edge(edges):
        return {'m2': edges.src['f2'] * edges.data['e2']}

    def reduce_func(nodes):
        return {'v1' : F.sum(nodes.mailbox['m2'], 1)}

    g = generate_graph(idtype)
    # Pre-create the output fields with the (10, D) shape of the 'f2'
    # reductions (send_and_recv presumably only writes the receiving nodes).
    g.ndata.update({'v1' : F.zeros((10, D)), 'v2' : F.zeros((10, D)),
                    'v3' : F.zeros((10, D))})
    fld = 'f2'

    # UDF reference: copy 'f2', sum into 'v1'
    g.send_and_recv((u, v), message_func, reduce_func)
    v1 = g.ndata['v1']

    # 1 message, 2 reduces
    g.send_and_recv((u, v),
                    fn.copy_src(src=fld, out='m'),
                    [fn.sum(msg='m', out='v2'), fn.sum(msg='m', out='v3')],
                    None)
    v2 = g.ndata['v2']
    v3 = g.ndata['v3']
    assert F.allclose(v1, v2)
    assert F.allclose(v1, v3)

    # send and recv with edge weights, 2 message, 3 reduces
    g.send_and_recv((u, v),
                    [fn.src_mul_edge(src=fld, edge='e1', out='m1'), fn.src_mul_edge(src=fld, edge='e2', out='m2')],
                    [fn.sum(msg='m1', out='v1'), fn.sum(msg='m2', out='v2'), fn.sum(msg='m1', out='v3')],
                    None)
    v1 = g.ndata['v1']
    v2 = g.ndata['v2']
    v3 = g.ndata['v3']
    assert F.allclose(v1, v2)
    assert F.allclose(v1, v3)

    # run UDF with single message and reduce.  reduce_func writes 'v1', so the
    # UDF result must be read back from 'v1'; 'v2' still holds the builtin
    # result from the step above and would make this check vacuous.
    g.send_and_recv((u, v), message_func_edge,
                    reduce_func, None)
    v_udf = g.ndata['v1']
    assert F.allclose(v1, v_udf)


@parametrize_dtype
def test_e2v_update_all_multi_fn(idtype):
    """Splitting a UDF reduce into two builtin sums must be equivalent.

    The message UDF emits 'm1' = src + dst and 'm2' = src * dst.  Reducing
    their sum with a UDF must match summing each with fn.sum into 'r1'/'r2'
    and combining them in the apply function.  Runs for 'f1' and 'f2'.
    """
    def _check_field(feat):
        def emit_two(edges):
            return {'m1' : edges.src[feat] + edges.dst[feat],
                    'm2' : edges.src[feat] * edges.dst[feat]}

        def reduce_combined(nodes):
            box = nodes.mailbox
            return {feat : F.sum(box['m1'] + box['m2'], 1)}

        def apply_double(nodes):
            return {feat : 2 * nodes.data[feat]}

        def apply_double_partials(nodes):
            return {feat : 2 * nodes.data['r1'] + 2 * nodes.data['r2']}

        graph = generate_graph(idtype)
        saved = graph.ndata[feat]

        # no specialization: message, reduce and apply are all UDFs
        graph.update_all(emit_two, reduce_combined, apply_double)
        expected = graph.ndata[feat]

        # user breaks the reduce func into 2 builtins
        graph.ndata.update({feat : saved})
        graph.update_all(emit_two,
                         [fn.sum(msg='m1', out='r1'), fn.sum(msg='m2', out='r2')],
                         apply_double_partials)
        assert F.allclose(expected, graph.ndata[feat])

    # 1-D node features first, then 2-D node features
    for field in ('f1', 'f2'):
        _check_field(field)


@parametrize_dtype
def test_e2v_snr_multi_fn(idtype):
    """send_and_recv variant of test_e2v_update_all_multi_fn.

    Checks that a UDF reduce over 'm1' + 'm2' equals two builtin sums combined
    in the apply function, restricted to the six listed edges.
    """
    src_ids = F.tensor([0, 0, 0, 3, 4, 9], idtype)
    dst_ids = F.tensor([1, 2, 3, 9, 9, 0], idtype)

    def _check_field(feat):
        def emit_two(edges):
            return {'m1' : edges.src[feat] + edges.dst[feat],
                    'm2' : edges.src[feat] * edges.dst[feat]}

        def reduce_combined(nodes):
            box = nodes.mailbox
            return {feat : F.sum(box['m1'] + box['m2'], 1)}

        def apply_double(nodes):
            return {feat : 2 * nodes.data[feat]}

        def apply_double_partials(nodes):
            return {feat : 2 * nodes.data['r1'] + 2 * nodes.data['r2']}

        graph = generate_graph(idtype)
        edge_pairs = (src_ids, dst_ids)
        saved = graph.ndata[feat]

        # no specialization: message, reduce and apply are all UDFs
        graph.send_and_recv(edge_pairs, emit_two, reduce_combined, apply_double)
        expected = graph.ndata[feat]

        # user breaks the reduce func into 2 builtins
        graph.ndata.update({feat : saved})
        graph.send_and_recv(edge_pairs, emit_two,
                            [fn.sum(msg='m1', out='r1'), fn.sum(msg='m2', out='r2')],
                            apply_double_partials)
        assert F.allclose(expected, graph.ndata[feat])

    # 1-D node features first, then 2-D node features
    for field in ('f1', 'f2'):
        _check_field(field)


@parametrize_dtype
def test_update_all_multi_fallback(idtype):
    """Builtin update_all paths must match UDF ground truth on a graph where
    node 0 has no incoming edges.

    Covers a per-edge scalar weight 'w1' of shape (16,), a per-edge vector
    weight 'w2' of shape (16, D), and lists mixing the two.
    """
    # create a graph with zero in degree nodes (no back edge into node 0)
    g = dgl.DGLGraph()
    g = g.astype(idtype).to(F.ctx())
    g.add_nodes(10)
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    g.ndata['h'] = F.randn((10, D))
    g.edata['w1'] = F.randn((16,))
    g.edata['w2'] = F.randn((16, D))

    def _msg_h_w1(edges):
        return {'m1' : edges.src['h'] * F.unsqueeze(edges.data['w1'], 1)}

    def _msg_h_w2(edges):
        return {'m2' : edges.src['h'] * edges.data['w2']}

    def _sum_m1(nodes):
        return {'o1' : F.sum(nodes.mailbox['m1'], 1)}

    def _sum_m2(nodes):
        return {'o2' : F.sum(nodes.mailbox['m2'], 1)}

    def _max_m1(nodes):
        return {'o3' : F.max(nodes.mailbox['m1'], 1)}

    def _double_outputs(nodes):
        # double every output field (keys starting with 'o'); other data untouched
        return {k: 2 * v for k, v in nodes.data.items() if k.startswith('o')}

    def _reference(msg, red, key):
        # run a full-UDF update and pop its output so later steps start clean
        g.update_all(msg, red, _double_outputs)
        return g.ndata.pop(key)

    # compute ground truth
    o1 = _reference(_msg_h_w1, _sum_m1, 'o1')
    o2 = _reference(_msg_h_w2, _sum_m2, 'o2')
    # NOTE(review): the max ground truth is computed (exercising the UDF max
    # path) but is never compared against a builtin reduce.
    _reference(_msg_h_w1, _max_m1, 'o3')

    # v2v spmv
    g.update_all(fn.src_mul_edge(src='h', edge='w1', out='m1'),
                 fn.sum(msg='m1', out='o1'),
                 _double_outputs)
    assert F.allclose(o1, g.ndata.pop('o1'))

    # v2v fallback to e2v
    g.update_all(fn.src_mul_edge(src='h', edge='w2', out='m2'),
                 fn.sum(msg='m2', out='o2'),
                 _double_outputs)
    assert F.allclose(o2, g.ndata.pop('o2'))

    # multi builtins, both v2v spmv
    g.update_all([fn.src_mul_edge(src='h', edge='w1', out='m1'), fn.src_mul_edge(src='h', edge='w1', out='m2')],
                 [fn.sum(msg='m1', out='o1'), fn.sum(msg='m2', out='o2')],
                 _double_outputs)
    assert F.allclose(o1, g.ndata.pop('o1'))
    assert F.allclose(o1, g.ndata.pop('o2'))

    # multi builtins, one v2v spmv, one fallback to e2v
    g.update_all([fn.src_mul_edge(src='h', edge='w1', out='m1'), fn.src_mul_edge(src='h', edge='w2', out='m2')],
                 [fn.sum(msg='m1', out='o1'), fn.sum(msg='m2', out='o2')],
                 _double_outputs)
    assert F.allclose(o1, g.ndata.pop('o1'))
    assert F.allclose(o2, g.ndata.pop('o2'))


@parametrize_dtype
def test_pull_multi_fallback(idtype):
    """Builtin pull paths must match UDF ground truth, including pulls that
    target node 0, which has no incoming edges.

    Covers a per-edge scalar weight 'w1' of shape (16,), a per-edge vector
    weight 'w2' of shape (16, D), and lists mixing the two.
    """
    # create a graph with zero in degree nodes (no back edge into node 0)
    g = dgl.DGLGraph()
    g = g.astype(idtype).to(F.ctx())
    g.add_nodes(10)
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    g.ndata['h'] = F.randn((10, D))
    g.edata['w1'] = F.randn((16,))
    g.edata['w2'] = F.randn((16, D))

    def _msg_h_w1(edges):
        return {'m1' : edges.src['h'] * F.unsqueeze(edges.data['w1'], 1)}

    def _msg_h_w2(edges):
        return {'m2' : edges.src['h'] * edges.data['w2']}

    def _sum_m1(nodes):
        return {'o1' : F.sum(nodes.mailbox['m1'], 1)}

    def _sum_m2(nodes):
        return {'o2' : F.sum(nodes.mailbox['m2'], 1)}

    def _max_m1(nodes):
        return {'o3' : F.max(nodes.mailbox['m1'], 1)}

    def _double_outputs(nodes):
        # double every output field (keys starting with 'o'); other data untouched
        return {k: 2 * v for k, v in nodes.data.items() if k.startswith('o')}

    def _check_pull(targets):
        """Compare builtin pulls on ``targets`` against full-UDF pulls."""
        def _reference(msg, red, key):
            g.pull(targets, msg, red, _double_outputs)
            return g.ndata.pop(key)

        # compute ground truth
        o1 = _reference(_msg_h_w1, _sum_m1, 'o1')
        o2 = _reference(_msg_h_w2, _sum_m2, 'o2')
        # NOTE(review): the max ground truth is computed (exercising the UDF
        # max path) but is never compared against a builtin reduce.
        _reference(_msg_h_w1, _max_m1, 'o3')

        # v2v spmv
        g.pull(targets, fn.src_mul_edge(src='h', edge='w1', out='m1'),
               fn.sum(msg='m1', out='o1'),
               _double_outputs)
        assert F.allclose(o1, g.ndata.pop('o1'))

        # v2v fallback to e2v
        g.pull(targets, fn.src_mul_edge(src='h', edge='w2', out='m2'),
               fn.sum(msg='m2', out='o2'),
               _double_outputs)
        assert F.allclose(o2, g.ndata.pop('o2'))

        # multi builtins, both v2v spmv
        g.pull(targets,
               [fn.src_mul_edge(src='h', edge='w1', out='m1'), fn.src_mul_edge(src='h', edge='w1', out='m2')],
               [fn.sum(msg='m1', out='o1'), fn.sum(msg='m2', out='o2')],
               _double_outputs)
        assert F.allclose(o1, g.ndata.pop('o1'))
        assert F.allclose(o1, g.ndata.pop('o2'))

        # multi builtins, one v2v spmv, one fallback to e2v
        g.pull(targets,
               [fn.src_mul_edge(src='h', edge='w1', out='m1'), fn.src_mul_edge(src='h', edge='w2', out='m2')],
               [fn.sum(msg='m1', out='o1'), fn.sum(msg='m2', out='o2')],
               _double_outputs)
        assert F.allclose(o1, g.ndata.pop('o1'))
        assert F.allclose(o2, g.ndata.pop('o2'))

    # test#1: only nodes with incoming edges
    # test#2: the zero-in-degree node 0 mixed with non-zero-in-degree nodes
    for targets in ([1, 2, 9], [0, 1, 2, 9]):
        _check_pull(targets)


@parametrize_dtype
def test_spmv_3d_feat(idtype):
    """Builtin src_mul_edge + sum must handle 3-D node features.

    On a random 100-node graph, each case compares three pipelines:
      1. builtin message + builtin reduce (the reference),
      2. UDF message + builtin reduce,
      3. UDF message + UDF reduce.
    Case #1 uses one scalar weight per edge (shape (m,)); case #2 uses
    per-edge weights of shape (m, 5, 5) that match the node features.
    """
    def sum_udf(nodes):
        return {'h': F.sum(nodes.mailbox['sum'], 1)}

    num_nodes = 100
    density = 0.1
    adj = sp.random(num_nodes, num_nodes, density, data_rvs=lambda k: np.ones(k))
    g = dgl.DGLGraph(adj)
    g = g.astype(idtype).to(F.ctx())
    num_edges = g.number_of_edges()

    def _compare_pipelines(msg_udf, node_feat, edge_feat):
        # 1: reference result from builtin message and builtin reduce
        g.ndata['h'] = node_feat
        g.edata['h'] = edge_feat
        g.update_all(message_func=fn.src_mul_edge('h', 'h', 'sum'), reduce_func=fn.sum('sum', 'h'))
        expected = g.ndata['h']

        # 2: UDF message, builtin reduce
        g.ndata['h'] = node_feat
        g.edata['h'] = edge_feat
        g.update_all(message_func=msg_udf, reduce_func=fn.sum('sum', 'h'))
        assert F.allclose(g.ndata['h'], expected)

        # 3: UDF message, UDF reduce
        g.ndata['h'] = node_feat
        g.edata['h'] = edge_feat
        g.update_all(message_func=msg_udf, reduce_func=sum_udf)
        assert F.allclose(g.ndata['h'], expected)

    # test#1: v2v with adj data -- scalar edge weight broadcast over (5, 5)
    def scalar_edge_udf(edges):
        return {'sum': edges.src['h'] * F.unsqueeze(F.unsqueeze(edges.data['h'], 1), 1)}

    node_feat = F.randn((num_nodes, 5, 5))
    edge_feat = F.randn((num_edges,))
    _compare_pipelines(scalar_edge_udf, node_feat, edge_feat)

    # test#2: e2v -- edge weight with the same (5, 5) trailing shape
    def full_edge_udf(edges):
        return {'sum': edges.src['h'] * edges.data['h']}

    node_feat = F.randn((num_nodes, 5, 5))
    edge_feat = F.randn((num_edges, 5, 5))
    _compare_pipelines(full_edge_udf, node_feat, edge_feat)


if __name__ == '__main__':
    # Every test takes an ``idtype`` argument, presumably supplied by the
    # @parametrize_dtype pytest mark; when run as a script it must be passed
    # explicitly, otherwise each call fails with a TypeError.
    # NOTE(review): assumes the backend exposes F.int32 / F.int64 as the ID
    # dtypes @parametrize_dtype covers -- confirm against test_utils.
    # test_e2v_recv_multi_fn is not defined in this file, so it is not called.
    for idtype in (F.int32, F.int64):
        test_v2v_update_all(idtype)
        test_v2v_snr(idtype)
        test_v2v_pull(idtype)
        test_v2v_update_all_multi_fn(idtype)
        test_v2v_snr_multi_fn(idtype)
        test_e2v_update_all_multi_fn(idtype)
        test_e2v_snr_multi_fn(idtype)
        test_update_all_multi_fallback(idtype)
        test_pull_multi_fallback(idtype)
        test_spmv_3d_feat(idtype)