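# test_specialization.py
# Checks that DGL builtin message/reduce functions (dgl.function) produce the
# same results as equivalent user-defined functions (UDFs).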
import numpy as np
import scipy.sparse as sp
3
4
import dgl
import dgl.function as fn
5
import backend as F
6

# Width of the 2-D feature tensors used throughout these tests
# (e.g. node features 'f2' and 'h', edge feature 'w2').
D = 5

def generate_graph():
    """Build the 10-node graph shared by the specialization tests.

    Node 0 fans out to nodes 1..8, each of which feeds node 9; a single back
    edge 9 -> 0 closes the loop (17 edges in total).

    Node features:
        'f1' -- shape (num_nodes,)
        'f2' -- shape (num_nodes, D)
    Edge features:
        'e1' -- random weights, shape (num_edges,)
        'e2' -- the same weights as 'e1', shape (num_edges, 1)
    """
    g = dgl.DGLGraph()
    g.add_nodes(10)
    # create a graph where 0 is the source and 9 is the sink
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    # add a back flow from 9 to 0
    g.add_edge(9, 0)

    # Size the feature tensors from the graph itself instead of hard-coding
    # 10 nodes / 17 edges, so they stay in sync with the topology above.
    num_nodes = g.number_of_nodes()
    num_edges = g.number_of_edges()
    g.ndata['f1'] = F.randn((num_nodes,))
    g.ndata['f2'] = F.randn((num_nodes, D))
    weights = F.randn((num_edges,))
    g.edata['e1'] = weights
    # Same weights with a trailing singleton axis, so UDFs can multiply them
    # against 2-D node features.
    g.edata['e2'] = F.unsqueeze(weights, 1)
    return g

def test_v2v_update_all():
    """update_all with builtin copy_src / src_mul_edge + sum must match the
    equivalent user-defined message and reduce functions."""
    def _test(fld):
        def message_func(edges):
            return {'m' : edges.src[fld]}

        def message_func_edge(edges):
            # 1-D features use the flat weights 'e1'; 2-D features use the
            # unsqueezed copy 'e2' so the product broadcasts per feature.
            if len(edges.src[fld].shape) == 1:
                return {'m' : edges.src[fld] * edges.data['e1']}
            else:
                return {'m' : edges.src[fld] * edges.data['e2']}

        def reduce_func(nodes):
            return {fld : F.sum(nodes.mailbox['m'], 1)}

        def apply_func(nodes):
            return {fld : 2 * nodes.data[fld]}

        g = generate_graph()
        # update all
        v1 = g.ndata[fld]
        g.update_all(fn.copy_src(src=fld, out='m'), fn.sum(msg='m', out=fld), apply_func)
        v2 = g.ndata[fld]
        g.ndata[fld] = v1
        g.update_all(message_func, reduce_func, apply_func)
        v3 = g.ndata[fld]
        assert F.allclose(v2, v3)
        # update all with edge weights: builtin with 'e1', builtin with 'e2'
        # and the UDF must all agree, each started from the same input v1
        v1 = g.ndata[fld]
        g.update_all(fn.src_mul_edge(src=fld, edge='e1', out='m'),
                     fn.sum(msg='m', out=fld), apply_func)
        v2 = g.ndata[fld]
        g.ndata[fld] = v1
        g.update_all(fn.src_mul_edge(src=fld, edge='e2', out='m'),
                     fn.sum(msg='m', out=fld), apply_func)
        v3 = g.ndata[fld]
        g.ndata[fld] = v1
        g.update_all(message_func_edge, reduce_func, apply_func)
        v4 = g.ndata[fld]
        assert F.allclose(v2, v3)
        assert F.allclose(v3, v4)

    # test 1d node features
    _test('f1')
    # test 2d node features
    _test('f2')

def test_v2v_snr():
    """send_and_recv with builtin copy_src / src_mul_edge + sum must match
    the equivalent user-defined message and reduce functions."""
    # Edges to send along: (0,1), (0,2), (0,3), (3,9), (4,9), (9,0).
    u = F.tensor([0, 0, 0, 3, 4, 9])
    v = F.tensor([1, 2, 3, 9, 9, 0])

    def _test(fld):
        def message_func(edges):
            return {'m' : edges.src[fld]}

        def message_func_edge(edges):
            # 1-D features use the flat weights 'e1'; 2-D features use the
            # unsqueezed copy 'e2' so the product broadcasts per feature.
            if len(edges.src[fld].shape) == 1:
                return {'m' : edges.src[fld] * edges.data['e1']}
            else:
                return {'m' : edges.src[fld] * edges.data['e2']}

        def reduce_func(nodes):
            return {fld : F.sum(nodes.mailbox['m'], 1)}

        def apply_func(nodes):
            return {fld : 2 * nodes.data[fld]}

        g = generate_graph()
        # send and recv
        v1 = g.ndata[fld]
        g.send_and_recv((u, v), fn.copy_src(src=fld, out='m'),
                        fn.sum(msg='m', out=fld), apply_func)
        v2 = g.ndata[fld]
        g.ndata[fld] = v1
        g.send_and_recv((u, v), message_func, reduce_func, apply_func)
        v3 = g.ndata[fld]
        assert F.allclose(v2, v3)
        # send and recv with edge weights: builtin with 'e1', builtin with
        # 'e2' and the UDF must all agree, each started from the same input v1
        v1 = g.ndata[fld]
        g.send_and_recv((u, v), fn.src_mul_edge(src=fld, edge='e1', out='m'),
                        fn.sum(msg='m', out=fld), apply_func)
        v2 = g.ndata[fld]
        g.ndata[fld] = v1
        g.send_and_recv((u, v), fn.src_mul_edge(src=fld, edge='e2', out='m'),
                        fn.sum(msg='m', out=fld), apply_func)
        v3 = g.ndata[fld]
        g.ndata[fld] = v1
        g.send_and_recv((u, v), message_func_edge, reduce_func, apply_func)
        v4 = g.ndata[fld]
        assert F.allclose(v2, v3)
        assert F.allclose(v3, v4)

    # test 1d node features
    _test('f1')
    # test 2d node features
    _test('f2')

def test_v2v_pull():
    """pull with builtin copy_src / src_mul_edge + sum must match the
    equivalent user-defined message and reduce functions."""
    # Destination nodes to pull messages into. Named distinctly so it is not
    # confused with the `nodes` batch parameter of the reduce/apply UDFs.
    pull_nodes = F.tensor([1, 2, 3, 9])

    def _test(fld):
        def message_func(edges):
            return {'m' : edges.src[fld]}

        def message_func_edge(edges):
            # 1-D features use the flat weights 'e1'; 2-D features use the
            # unsqueezed copy 'e2' so the product broadcasts per feature.
            if len(edges.src[fld].shape) == 1:
                return {'m' : edges.src[fld] * edges.data['e1']}
            else:
                return {'m' : edges.src[fld] * edges.data['e2']}

        def reduce_func(nodes):
            return {fld : F.sum(nodes.mailbox['m'], 1)}

        def apply_func(nodes):
            return {fld : 2 * nodes.data[fld]}

        g = generate_graph()
        # pull
        v1 = g.ndata[fld]
        g.pull(pull_nodes, fn.copy_src(src=fld, out='m'), fn.sum(msg='m', out=fld), apply_func)
        v2 = g.ndata[fld]
        g.ndata[fld] = v1
        g.pull(pull_nodes, message_func, reduce_func, apply_func)
        v3 = g.ndata[fld]
        assert F.allclose(v2, v3)
        # pull with edge weights: builtin with 'e1', builtin with 'e2' and
        # the UDF must all agree, each started from the same input v1
        v1 = g.ndata[fld]
        g.pull(pull_nodes, fn.src_mul_edge(src=fld, edge='e1', out='m'),
               fn.sum(msg='m', out=fld), apply_func)
        v2 = g.ndata[fld]
        g.ndata[fld] = v1
        g.pull(pull_nodes, fn.src_mul_edge(src=fld, edge='e2', out='m'),
               fn.sum(msg='m', out=fld), apply_func)
        v3 = g.ndata[fld]
        g.ndata[fld] = v1
        g.pull(pull_nodes, message_func_edge, reduce_func, apply_func)
        v4 = g.ndata[fld]
        assert F.allclose(v2, v3)
        assert F.allclose(v3, v4)

    # test 1d node features
    _test('f1')
    # test 2d node features
    _test('f2')

def test_v2v_update_all_multi_fn():
    """update_all with lists of builtin message/reduce functions must match
    the single-function (builtin and UDF) results."""
    def message_func(edges):
        return {'m2': edges.src['f2']}

    def message_func_edge(edges):
        return {'m2': edges.src['f2'] * edges.data['e2']}

    def reduce_func(nodes):
        # Always writes its result to 'v1'.
        return {'v1': F.sum(nodes.mailbox['m2'], 1)}

    g = generate_graph()
    num_nodes = g.number_of_nodes()
    # The reductions below aggregate the (N, D) feature 'f2', so allocate the
    # output slots with that same shape.
    g.ndata['v1'] = F.zeros((num_nodes, D))
    g.ndata['v2'] = F.zeros((num_nodes, D))
    fld = 'f2'

    g.update_all(message_func, reduce_func)
    v1 = g.ndata['v1']

    # 1 message, 2 reduces
    g.update_all(fn.copy_src(src=fld, out='m'), [fn.sum(msg='m', out='v2'), fn.sum(msg='m', out='v3')])
    v2 = g.ndata['v2']
    v3 = g.ndata['v3']
    assert F.allclose(v1, v2)
    assert F.allclose(v1, v3)

    # update all with edge weights, 2 message, 3 reduces
    g.update_all([fn.src_mul_edge(src=fld, edge='e1', out='m1'), fn.src_mul_edge(src=fld, edge='e2', out='m2')],
                 [fn.sum(msg='m1', out='v1'), fn.sum(msg='m2', out='v2'), fn.sum(msg='m1', out='v3')],
                 None)
    v1 = g.ndata['v1']
    v2 = g.ndata['v2']
    v3 = g.ndata['v3']
    assert F.allclose(v1, v2)
    assert F.allclose(v1, v3)

    # run UDF with single message and reduce. reduce_func writes 'v1', so
    # that is the field to compare against the builtin result captured above
    # (reading 'v2' here would only re-check the stale builtin output).
    g.update_all(message_func_edge, reduce_func, None)
    v4 = g.ndata['v1']
    assert F.allclose(v1, v4)

def test_v2v_snr_multi_fn():
    """send_and_recv with lists of builtin message/reduce functions must
    match the single-function (builtin and UDF) results."""
    # Edges to send along: (0,1), (0,2), (0,3), (3,9), (4,9), (9,0).
    u = F.tensor([0, 0, 0, 3, 4, 9])
    v = F.tensor([1, 2, 3, 9, 9, 0])

    def message_func(edges):
        return {'m2': edges.src['f2']}

    def message_func_edge(edges):
        return {'m2': edges.src['f2'] * edges.data['e2']}

    def reduce_func(nodes):
        # Always writes its result to 'v1'.
        return {'v1' : F.sum(nodes.mailbox['m2'], 1)}

    g = generate_graph()
    num_nodes = g.number_of_nodes()
    # Only the receivers of (u, v) are written, so every output slot must be
    # pre-allocated with the (N, D) shape of the aggregated feature 'f2'.
    g.ndata['v1'] = F.zeros((num_nodes, D))
    g.ndata['v2'] = F.zeros((num_nodes, D))
    g.ndata['v3'] = F.zeros((num_nodes, D))
    fld = 'f2'

    g.send_and_recv((u, v), message_func, reduce_func)
    v1 = g.ndata['v1']

    # 1 message, 2 reduces
    g.send_and_recv((u, v),
                    fn.copy_src(src=fld, out='m'),
                    [fn.sum(msg='m', out='v2'), fn.sum(msg='m', out='v3')],
                    None)
    v2 = g.ndata['v2']
    v3 = g.ndata['v3']
    assert F.allclose(v1, v2)
    assert F.allclose(v1, v3)

    # send and recv with edge weights, 2 message, 3 reduces
    g.send_and_recv((u, v),
                    [fn.src_mul_edge(src=fld, edge='e1', out='m1'), fn.src_mul_edge(src=fld, edge='e2', out='m2')],
                    [fn.sum(msg='m1', out='v1'), fn.sum(msg='m2', out='v2'), fn.sum(msg='m1', out='v3')],
                    None)
    v1 = g.ndata['v1']
    v2 = g.ndata['v2']
    v3 = g.ndata['v3']
    assert F.allclose(v1, v2)
    assert F.allclose(v1, v3)

    # run UDF with single message and reduce. reduce_func writes 'v1', so
    # that is the field to compare against the builtin result captured above
    # (reading 'v2' here would only re-check the stale builtin output).
    g.send_and_recv((u, v), message_func_edge,
                    reduce_func, None)
    v4 = g.ndata['v1']
    assert F.allclose(v1, v4)

def test_e2v_update_all_multi_fn():
    """Splitting a UDF reduce into two builtin sums (recombined in the apply
    function) must give the same update_all result."""
    def _test(fld):
        def message_func(edges):
            return {'m1' : edges.src[fld] + edges.dst[fld],
                    'm2' : edges.src[fld] * edges.dst[fld]}

        def reduce_func(nodes):
            return {fld : F.sum(nodes.mailbox['m1'] + nodes.mailbox['m2'], 1)}

        def apply_func(nodes):
            return {fld : 2 * nodes.data[fld]}

        def apply_func_2(nodes):
            # sum(m1 + m2) == sum(m1) + sum(m2), so 2*r1 + 2*r2 reproduces
            # apply_func applied to reduce_func's output.
            return {fld : 2 * nodes.data['r1'] + 2 * nodes.data['r2']}

        g = generate_graph()
        # update all
        v1 = g.ndata[fld]
        # no specialization
        g.update_all(message_func, reduce_func, apply_func)
        v2 = g.ndata[fld]

        # user break reduce func into 2 builtin
        g.ndata[fld] = v1
        g.update_all(message_func,
                     [fn.sum(msg='m1', out='r1'), fn.sum(msg='m2', out='r2')],
                     apply_func_2)
        v3 = g.ndata[fld]

        assert F.allclose(v2, v3)

    # test 1d node features
    _test('f1')
    # test 2d node features
    _test('f2')

def test_e2v_snr_multi_fn():
    """Splitting a UDF reduce into two builtin sums (recombined in the apply
    function) must give the same send_and_recv result."""
    # Edges to send along: (0,1), (0,2), (0,3), (3,9), (4,9), (9,0).
    u = F.tensor([0, 0, 0, 3, 4, 9])
    v = F.tensor([1, 2, 3, 9, 9, 0])

    def _test(fld):
        def message_func(edges):
            return {'m1' : edges.src[fld] + edges.dst[fld],
                    'm2' : edges.src[fld] * edges.dst[fld]}

        def reduce_func(nodes):
            return {fld : F.sum(nodes.mailbox['m1'] + nodes.mailbox['m2'], 1)}

        def apply_func(nodes):
            return {fld : 2 * nodes.data[fld]}

        def apply_func_2(nodes):
            # sum(m1 + m2) == sum(m1) + sum(m2), so 2*r1 + 2*r2 reproduces
            # apply_func applied to reduce_func's output.
            return {fld : 2 * nodes.data['r1'] + 2 * nodes.data['r2']}

        g = generate_graph()
        # send_and_recv
        v1 = g.ndata[fld]
        # no specialization
        g.send_and_recv((u, v), message_func, reduce_func, apply_func)
        v2 = g.ndata[fld]

        # user break reduce func into 2 builtin
        g.ndata[fld] = v1
        g.send_and_recv((u, v), message_func,
                        [fn.sum(msg='m1', out='r1'), fn.sum(msg='m2', out='r2')],
                        apply_func_2)
        v3 = g.ndata[fld]

        assert F.allclose(v2, v3)

    # test 1d node features
    _test('f1')
    # test 2d node features
    _test('f2')

def test_e2v_recv_multi_fn():
    """Splitting a UDF reduce into two builtin sums (recombined in the apply
    function) must give the same result for a separate send + recv."""
    # Edges to send along: (0,1), (0,2), (0,3), (3,9), (4,9), (9,0).
    u = F.tensor([0, 0, 0, 3, 4, 9])
    v = F.tensor([1, 2, 3, 9, 9, 0])
    # The distinct destinations in v, i.e. every node holding a message.
    recv_nodes = [0, 1, 2, 3, 9]

    def _test(fld):
        def message_func(edges):
            return {'m1' : edges.src[fld] + edges.dst[fld],
                    'm2' : edges.src[fld] * edges.dst[fld]}

        def reduce_func(nodes):
            return {fld : F.sum(nodes.mailbox['m1'] + nodes.mailbox['m2'], 1)}

        def apply_func(nodes):
            return {fld : 2 * nodes.data[fld]}

        def apply_func_2(nodes):
            # sum(m1 + m2) == sum(m1) + sum(m2), so 2*r1 + 2*r2 reproduces
            # apply_func applied to reduce_func's output.
            return {fld : 2 * nodes.data['r1'] + 2 * nodes.data['r2']}

        g = generate_graph()
        # recv
        v1 = g.ndata[fld]
        # no specialization
        g.send((u, v), message_func)
        g.recv(recv_nodes, reduce_func, apply_func)
        v2 = g.ndata[fld]

        # user break reduce func into 2 builtin
        g.ndata[fld] = v1
        g.send((u, v), message_func)
        g.recv(recv_nodes,
               [fn.sum(msg='m1', out='r1'), fn.sum(msg='m2', out='r2')],
               apply_func_2)
        v3 = g.ndata[fld]

        assert F.allclose(v2, v3)

    # test 1d node features
    _test('f1')
    # test 2d node features
    _test('f2')

def test_update_all_multi_fallback():
    """update_all with builtins must match UDF ground truth on every path
    exercised below (v2v spmv, fallback to e2v, fallback to degree
    bucketing), both alone and mixed in one multi-function call.

    The graph has no back edge, so node 0 has zero in-degree.
    """
    # create a graph with zero in degree nodes
    g = dgl.DGLGraph()
    g.add_nodes(10)
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    num_nodes = g.number_of_nodes()
    num_edges = g.number_of_edges()
    g.ndata['h'] = F.randn((num_nodes, D))
    g.edata['w1'] = F.randn((num_edges,))
    g.edata['w2'] = F.randn((num_edges, D))

    def _mfunc_hxw1(edges):
        return {'m1' : edges.src['h'] * F.unsqueeze(edges.data['w1'], 1)}

    def _mfunc_hxw2(edges):
        return {'m2' : edges.src['h'] * edges.data['w2']}

    def _rfunc_m1(nodes):
        return {'o1' : F.sum(nodes.mailbox['m1'], 1)}

    def _rfunc_m2(nodes):
        return {'o2' : F.sum(nodes.mailbox['m2'], 1)}

    def _rfunc_m1max(nodes):
        return {'o3' : F.max(nodes.mailbox['m1'], 1)}

    def _afunc(nodes):
        # Double every output field (names starting with 'o'); other node
        # data such as the input 'h' is left untouched.
        return {k : 2 * val for k, val in nodes.data.items() if k.startswith('o')}

    # compute ground truth
    g.update_all(_mfunc_hxw1, _rfunc_m1, _afunc)
    o1 = g.ndata.pop('o1')
    g.update_all(_mfunc_hxw2, _rfunc_m2, _afunc)
    o2 = g.ndata.pop('o2')
    g.update_all(_mfunc_hxw1, _rfunc_m1max, _afunc)
    o3 = g.ndata.pop('o3')
    # v2v spmv
    g.update_all(fn.src_mul_edge(src='h', edge='w1', out='m1'),
                 fn.sum(msg='m1', out='o1'),
                 _afunc)
    assert F.allclose(o1, g.ndata.pop('o1'))
    # v2v fallback to e2v
    g.update_all(fn.src_mul_edge(src='h', edge='w2', out='m2'),
                 fn.sum(msg='m2', out='o2'),
                 _afunc)
    assert F.allclose(o2, g.ndata.pop('o2'))
    # v2v fallback to degree bucketing
    g.update_all(fn.src_mul_edge(src='h', edge='w1', out='m1'),
                 fn.max(msg='m1', out='o3'),
                 _afunc)
    assert F.allclose(o3, g.ndata.pop('o3'))
    # multi builtins, both v2v spmv
    g.update_all([fn.src_mul_edge(src='h', edge='w1', out='m1'), fn.src_mul_edge(src='h', edge='w1', out='m2')],
                 [fn.sum(msg='m1', out='o1'), fn.sum(msg='m2', out='o2')],
                 _afunc)
    assert F.allclose(o1, g.ndata.pop('o1'))
    assert F.allclose(o1, g.ndata.pop('o2'))
    # multi builtins, one v2v spmv, one fallback to e2v
    g.update_all([fn.src_mul_edge(src='h', edge='w1', out='m1'), fn.src_mul_edge(src='h', edge='w2', out='m2')],
                 [fn.sum(msg='m1', out='o1'), fn.sum(msg='m2', out='o2')],
                 _afunc)
    assert F.allclose(o1, g.ndata.pop('o1'))
    assert F.allclose(o2, g.ndata.pop('o2'))
    # multi builtins, one v2v spmv, one fallback to e2v, one fallback to degree-bucketing
    g.update_all([fn.src_mul_edge(src='h', edge='w1', out='m1'),
                  fn.src_mul_edge(src='h', edge='w2', out='m2'),
                  fn.src_mul_edge(src='h', edge='w1', out='m3')],
                 [fn.sum(msg='m1', out='o1'),
                  fn.sum(msg='m2', out='o2'),
                  fn.max(msg='m3', out='o3')],
                 _afunc)
    assert F.allclose(o1, g.ndata.pop('o1'))
    assert F.allclose(o2, g.ndata.pop('o2'))
    assert F.allclose(o3, g.ndata.pop('o3'))

def test_pull_multi_fallback():
    """pull with builtins must match UDF ground truth on every path
    exercised below (v2v spmv, fallback to e2v, fallback to degree
    bucketing), both alone and mixed in one multi-function call.

    The graph has no back edge, so node 0 has zero in-degree; it is pulled
    in the second run to cover that case.
    """
    # create a graph with zero in degree nodes
    g = dgl.DGLGraph()
    g.add_nodes(10)
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    num_nodes = g.number_of_nodes()
    num_edges = g.number_of_edges()
    g.ndata['h'] = F.randn((num_nodes, D))
    g.edata['w1'] = F.randn((num_edges,))
    g.edata['w2'] = F.randn((num_edges, D))

    def _mfunc_hxw1(edges):
        return {'m1' : edges.src['h'] * F.unsqueeze(edges.data['w1'], 1)}

    def _mfunc_hxw2(edges):
        return {'m2' : edges.src['h'] * edges.data['w2']}

    def _rfunc_m1(nodes):
        return {'o1' : F.sum(nodes.mailbox['m1'], 1)}

    def _rfunc_m2(nodes):
        return {'o2' : F.sum(nodes.mailbox['m2'], 1)}

    def _rfunc_m1max(nodes):
        return {'o3' : F.max(nodes.mailbox['m1'], 1)}

    def _afunc(nodes):
        # Double every output field (names starting with 'o'); other node
        # data such as the input 'h' is left untouched.
        return {k : 2 * val for k, val in nodes.data.items() if k.startswith('o')}

    def _pull_nodes(nodes):
        """Run the UDF ground truth and every builtin variant on `nodes`."""
        # compute ground truth
        g.pull(nodes, _mfunc_hxw1, _rfunc_m1, _afunc)
        o1 = g.ndata.pop('o1')
        g.pull(nodes, _mfunc_hxw2, _rfunc_m2, _afunc)
        o2 = g.ndata.pop('o2')
        g.pull(nodes, _mfunc_hxw1, _rfunc_m1max, _afunc)
        o3 = g.ndata.pop('o3')
        # v2v spmv
        g.pull(nodes, fn.src_mul_edge(src='h', edge='w1', out='m1'),
               fn.sum(msg='m1', out='o1'),
               _afunc)
        assert F.allclose(o1, g.ndata.pop('o1'))
        # v2v fallback to e2v
        g.pull(nodes, fn.src_mul_edge(src='h', edge='w2', out='m2'),
               fn.sum(msg='m2', out='o2'),
               _afunc)
        assert F.allclose(o2, g.ndata.pop('o2'))
        # v2v fallback to degree bucketing
        g.pull(nodes, fn.src_mul_edge(src='h', edge='w1', out='m1'),
               fn.max(msg='m1', out='o3'),
               _afunc)
        assert F.allclose(o3, g.ndata.pop('o3'))
        # multi builtins, both v2v spmv
        g.pull(nodes,
               [fn.src_mul_edge(src='h', edge='w1', out='m1'), fn.src_mul_edge(src='h', edge='w1', out='m2')],
               [fn.sum(msg='m1', out='o1'), fn.sum(msg='m2', out='o2')],
               _afunc)
        assert F.allclose(o1, g.ndata.pop('o1'))
        assert F.allclose(o1, g.ndata.pop('o2'))
        # multi builtins, one v2v spmv, one fallback to e2v
        g.pull(nodes,
               [fn.src_mul_edge(src='h', edge='w1', out='m1'), fn.src_mul_edge(src='h', edge='w2', out='m2')],
               [fn.sum(msg='m1', out='o1'), fn.sum(msg='m2', out='o2')],
               _afunc)
        assert F.allclose(o1, g.ndata.pop('o1'))
        assert F.allclose(o2, g.ndata.pop('o2'))
        # multi builtins, one v2v spmv, one fallback to e2v, one fallback to degree-bucketing
        g.pull(nodes,
               [fn.src_mul_edge(src='h', edge='w1', out='m1'),
                fn.src_mul_edge(src='h', edge='w2', out='m2'),
                fn.src_mul_edge(src='h', edge='w1', out='m3')],
               [fn.sum(msg='m1', out='o1'),
                fn.sum(msg='m2', out='o2'),
                fn.max(msg='m3', out='o3')],
               _afunc)
        assert F.allclose(o1, g.ndata.pop('o1'))
        assert F.allclose(o2, g.ndata.pop('o2'))
        assert F.allclose(o3, g.ndata.pop('o3'))

    # test#1: non-0deg nodes
    _pull_nodes([1, 2, 9])
    # test#2: 0deg nodes + non-0deg nodes
    _pull_nodes([0, 1, 2, 9])

def test_spmv_3d_feat():
    """Builtin src_mul_edge + sum on (n, 5, 5) node features must match the
    equivalent UDFs, for two edge-data shapes: (m,) (test#1, "v2v with adj
    data") and (m, 5, 5) (test#2, "e2v")."""
    def src_mul_scalar_edge_udf(edges):
        # Edge data is (m,); add two trailing axes so it broadcasts against
        # the (m, 5, 5) source features.
        return {'sum': edges.src['h'] * F.unsqueeze(F.unsqueeze(edges.data['h'], 1), 1)}

    def src_mul_edge_udf(edges):
        # Edge data already has the same per-edge shape as the source feature.
        return {'sum': edges.src['h'] * edges.data['h']}

    def sum_udf(nodes):
        return {'h': F.sum(nodes.mailbox['sum'], 1)}

    def _check(g, h, e, msg_udf):
        """Compare (1) builtin msg + builtin reduce, (2) UDF msg + builtin
        reduce and (3) UDF msg + UDF reduce, each started from (h, e)."""
        g.ndata['h'] = h
        g.edata['h'] = e
        g.update_all(message_func=fn.src_mul_edge('h', 'h', 'sum'), reduce_func=fn.sum('sum', 'h')) # 1
        ans = g.ndata['h']

        g.ndata['h'] = h
        g.edata['h'] = e
        g.update_all(message_func=msg_udf, reduce_func=fn.sum('sum', 'h')) # 2
        assert F.allclose(g.ndata['h'], ans)

        g.ndata['h'] = h
        g.edata['h'] = e
        g.update_all(message_func=msg_udf, reduce_func=sum_udf) # 3
        assert F.allclose(g.ndata['h'], ans)

    n = 100
    p = 0.1
    # Random sparse adjacency with density p and all stored values set to 1.
    a = sp.random(n, n, p, data_rvs=lambda size: np.ones(size))
    g = dgl.DGLGraph(a)
    m = g.number_of_edges()

    # test#1: v2v with adj data (one scalar weight per edge)
    _check(g, F.randn((n, 5, 5)), F.randn((m,)), src_mul_scalar_edge_udf)

    # test#2: e2v (edge weight shaped like the node feature)
    _check(g, F.randn((n, 5, 5)), F.randn((m, 5, 5)), src_mul_edge_udf)

if __name__ == '__main__':
    # Run every test in this module directly, without a test runner.
    test_v2v_update_all()
    test_v2v_snr()
    test_v2v_pull()
    test_v2v_update_all_multi_fn()
    test_v2v_snr_multi_fn()
    test_e2v_update_all_multi_fn()
    test_e2v_snr_multi_fn()
    test_e2v_recv_multi_fn()
    test_update_all_multi_fallback()
    test_pull_multi_fallback()
    test_spmv_3d_feat()