test_specialization.py
import torch as th
import numpy as np
from dgl.graph import DGLGraph

D = 5  # feature dimension of each node representation

def check_eq(a, b):
    # The two tensors must match elementwise; fail the test otherwise.
    assert np.allclose(a.numpy(), b.numpy()), '{} != {}'.format(a, b)

def message_func(hu, edge):
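    # Forward the source node representation unchanged along each edge.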
    return hu

def reduce_func(hv, msgs):
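    # Sum incoming messages over the neighbor dimension of the batched message tensor.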
    return th.sum(msgs, 1)

def update_func(hv, accum):
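    # Combine each node's current representation with its aggregated messages.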
    assert hv.shape == accum.shape
    return hv + accum

def generate_graph():
    g = DGLGraph()
    for i in range(10):
        g.add_node(i) # 10 nodes.
    # create a graph where 0 is the source and 9 is the sink
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    # add a back flow from 9 to 0
    g.add_edge(9, 0)
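    # Initialize node representations with random features of dimension D.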
    col = th.randn(10, D)
    g.set_n_repr(col)
    return g

def test_spmv_specialize():
    g = generate_graph()
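    # Builtin message/reduce functions; DGL may specialize this pattern into an SPMV kernel.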
    g.register_message_func('from_src', batchable=True)
    g.register_reduce_func('sum', batchable=True)
    g.register_update_func(update_func, batchable=True)
    v1 = g.get_n_repr()
    g.update_all()
    v2 = g.get_n_repr()
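    # Restore the original features and recompute with the user-defined functions.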
    g.set_n_repr(v1)
    g.register_message_func(message_func, batchable=True)
    g.register_reduce_func(reduce_func, batchable=True)
    g.update_all()
    v3 = g.get_n_repr()
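    # Both execution paths should produce identical node representations.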
    check_eq(v2, v3)

if __name__ == '__main__':
    test_spmv_specialize()