import numpy as np
import torch as th

import dgl
import dgl.function as fn

# Width of the 2-D node feature 'f2' built by generate_graph().
D = 5

def generate_graph():
    """Build the small directed test graph shared by the tests in this module.

    Topology (10 nodes, 17 edges): node 0 fans out to nodes 1..8, each of
    which feeds node 9, plus a single back edge 9 -> 0.

    Node features (random):
        'f1': shape (10,)   -- one scalar per node.
        'f2': shape (10, D) -- a D-dimensional vector per node.
    Edge features (random, same underlying values):
        'e1': shape (17,)   -- one scalar weight per edge.
        'e2': shape (17, 1) -- the same weights with a trailing unit
              dimension, matching the rank of 'f2'.

    Returns:
        dgl.DGLGraph: the populated graph.
    """
    num_nodes = 10
    source, sink = 0, num_nodes - 1

    g = dgl.DGLGraph()
    for i in range(num_nodes):
        g.add_node(i)

    # Count edges while adding them so the edge-feature length below always
    # matches the topology instead of relying on a hand-computed constant.
    num_edges = 0
    # create a graph where `source` fans out and `sink` collects
    for i in range(source + 1, sink):
        g.add_edge(source, i)
        g.add_edge(i, sink)
        num_edges += 2
    # add a back flow from sink to source
    g.add_edge(sink, source)
    num_edges += 1

    g.set_n_repr({'f1': th.randn(num_nodes), 'f2': th.randn(num_nodes, D)})
    # 'e1' and 'e2' share one weight tensor so weighted results can be
    # compared across the two edge fields.
    weights = th.randn(num_edges)
    g.set_e_repr({'e1': weights, 'e2': th.unsqueeze(weights, 1)})
    return g


def test_update_all():
    """Builtin ``update_all`` functions must agree with equivalent Python UDFs.

    The comparison is run once for the 1-D node feature 'f1' and once for
    the 2-D node feature 'f2'.
    """
    def _run_case(feat):
        # -- user-defined counterparts of the builtin functions -------------
        def copy_udf(hu, edge):
            return hu[feat]

        def mul_edge_udf(hu, edge):
            # 1-D node features are scaled by 'e1', others by 'e2'.
            is_vector = len(hu[feat].shape) == 1
            return hu[feat] * (edge['e1'] if is_vector else edge['e2'])

        def sum_udf(hv, msgs):
            return {feat: th.sum(msgs, 1)}

        def double_udf(hu):
            return {feat: 2 * hu[feat]}

        # -- small accessors for the feature under test ---------------------
        def read():
            return g.get_n_repr()[feat]

        def restore(value):
            g.set_n_repr({feat: value})

        g = generate_graph()

        # Plain message passing: builtin copy/sum vs. UDF copy/sum.
        start = read()
        g.update_all(fn.copy_src(src=feat), fn.sum(out=feat),
                     double_udf, batchable=True)
        via_builtin = read()
        restore(start)
        g.update_all(copy_udf, sum_udf, double_udf, batchable=True)
        via_udf = read()
        assert th.allclose(via_builtin, via_udf)

        # Edge-weighted message passing: builtin with 'e1', builtin with
        # 'e2', and the UDF must all produce the same result.
        start = read()
        g.update_all(fn.src_mul_edge(src=feat, edge='e1'),
                     fn.sum(out=feat), double_udf, batchable=True)
        via_e1 = read()
        restore(start)
        g.update_all(fn.src_mul_edge(src=feat, edge='e2'),
                     fn.sum(out=feat), double_udf, batchable=True)
        via_e2 = read()
        restore(start)
        g.update_all(mul_edge_udf, sum_udf, double_udf, batchable=True)
        via_udf = read()
        assert th.allclose(via_e1, via_e2)
        assert th.allclose(via_e2, via_udf)

    # 1-D node features first, then 2-D node features.
    for name in ('f1', 'f2'):
        _run_case(name)

def test_send_and_recv():
    """Builtin functions must agree with equivalent UDFs under ``send_and_recv``.

    Messages are sent only along the edge subset ``u[i] -> v[i]``. The check
    is run for the 1-D node feature 'f1' and for the 2-D node feature 'f2'.
    """
    # Edges to send along; every pair is an edge added by generate_graph().
    u = th.tensor([0, 0, 0, 3, 4, 9])
    v = th.tensor([1, 2, 3, 9, 9, 0])

    def _test(fld):
        def message_func(hu, edge):
            # UDF counterpart of fn.copy_src: forward the source feature.
            return hu[fld]

        def message_func_edge(hu, edge):
            # UDF counterpart of fn.src_mul_edge: 1-D node features are
            # scaled by 'e1', others by 'e2' (same weights, extra unit dim).
            if len(hu[fld].shape) == 1:
                return hu[fld] * edge['e1']
            else:
                return hu[fld] * edge['e2']

        def reduce_func(hv, msgs):
            # UDF counterpart of fn.sum: sums msgs along dim 1.
            return {fld: th.sum(msgs, 1)}

        def apply_func(hu):
            return {fld: 2 * hu[fld]}

        g = generate_graph()

        # send and recv: builtin copy/sum vs. UDF copy/sum
        v1 = g.get_n_repr()[fld]
        g.send_and_recv(u, v, fn.copy_src(src=fld),
                fn.sum(out=fld), apply_func, batchable=True)
        v2 = g.get_n_repr()[fld]
        g.set_n_repr({fld: v1})
        g.send_and_recv(u, v, message_func,
                reduce_func, apply_func, batchable=True)
        v3 = g.get_n_repr()[fld]
        assert th.allclose(v2, v3)

        # send and recv with edge weights: builtin 'e1', builtin 'e2', UDF
        v1 = g.get_n_repr()[fld]
        g.send_and_recv(u, v, fn.src_mul_edge(src=fld, edge='e1'),
                fn.sum(out=fld), apply_func, batchable=True)
        v2 = g.get_n_repr()[fld]
        g.set_n_repr({fld: v1})
        g.send_and_recv(u, v, fn.src_mul_edge(src=fld, edge='e2'),
                fn.sum(out=fld), apply_func, batchable=True)
        v3 = g.get_n_repr()[fld]
        g.set_n_repr({fld: v1})
        g.send_and_recv(u, v, message_func_edge,
                reduce_func, apply_func, batchable=True)
        v4 = g.get_n_repr()[fld]
        assert th.allclose(v2, v3)
        assert th.allclose(v3, v4)

    # test 1d node features
    _test('f1')
    # test 2d node features
    _test('f2')


if __name__ == '__main__':
    # Run every test when the module is executed directly.
    test_update_all()
    test_send_and_recv()