import torch as th
import numpy as np
import dgl
import dgl.function as fn

D = 5

def generate_graph():
    """Build the small directed graph shared by every test in this file.

    Topology: node 0 has an edge to each of nodes 1..8, each of those has an
    edge to node 9, and one back edge 9 -> 0 closes the loop
    (8 + 8 + 1 = 17 edges).

    Node features:
        'f1': shape (10,)   -- 1-D per-node feature
        'f2': shape (10, D) -- 2-D per-node feature
    Edge features (the same random weights in two layouts):
        'e1': shape (E,)
        'e2': shape (E, 1) -- the layout used when multiplying 2-D features

    Returns:
        The populated ``dgl.DGLGraph``.
    """
    num_nodes = 10
    g = dgl.DGLGraph()
    g.add_nodes(num_nodes)
    # create a graph where 0 is the source and 9 is the sink
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    # add a back flow from 9 to 0
    g.add_edge(9, 0)

    g.set_n_repr({'f1' : th.randn(num_nodes), 'f2' : th.randn(num_nodes, D)})
    # Size the edge weights from the graph itself rather than hard-coding the
    # edge count, so they stay in sync if the topology above changes.
    weights = th.randn(g.number_of_edges())
    g.set_e_repr({'e1': weights, 'e2': th.unsqueeze(weights, 1)})
    return g

def test_update_all():
    """Builtin message/reduce functions must match equivalent UDFs under
    ``update_all``, for both the 1-D ('f1') and 2-D ('f2') node features.
    """
    def _test(fld):
        def message_func(edges):
            # UDF counterpart of fn.copy_src(src=fld, out='m')
            return {'m' : edges.src[fld]}

        def message_func_edge(edges):
            # UDF counterpart of fn.src_mul_edge: use the edge weight whose
            # layout matches the feature rank ('e1' is (E,), 'e2' is (E, 1)).
            if edges.src[fld].dim() == 1:
                return {'m' : edges.src[fld] * edges.data['e1']}
            else:
                return {'m' : edges.src[fld] * edges.data['e2']}

        def reduce_func(nodes):
            # UDF counterpart of fn.sum(msg='m', out=fld): sums the mailbox
            # over dim 1 (presumably the per-node message axis).
            return {fld : th.sum(nodes.mailbox['m'], 1)}

        def apply_func(nodes):
            return {fld : 2 * nodes.data[fld]}

        g = generate_graph()

        # update all: builtin copy/sum vs. UDF copy/sum
        v1 = g.ndata[fld]
        g.update_all(fn.copy_src(src=fld, out='m'), fn.sum(msg='m', out=fld), apply_func)
        v2 = g.ndata[fld]
        # restore the input feature so each variant starts from the same state
        g.set_n_repr({fld : v1})
        g.update_all(message_func, reduce_func, apply_func)
        v3 = g.ndata[fld]
        assert th.allclose(v2, v3)

        # update all with edge weights: builtin with 'e1', builtin with 'e2',
        # and the UDF must all agree
        v1 = g.ndata[fld]
        g.update_all(fn.src_mul_edge(src=fld, edge='e1', out='m'),
                fn.sum(msg='m', out=fld), apply_func)
        v2 = g.ndata[fld]
        g.set_n_repr({fld : v1})
        g.update_all(fn.src_mul_edge(src=fld, edge='e2', out='m'),
                fn.sum(msg='m', out=fld), apply_func)
        v3 = g.ndata[fld]
        g.set_n_repr({fld : v1})
        g.update_all(message_func_edge, reduce_func, apply_func)
        v4 = g.ndata[fld]
        assert th.allclose(v2, v3)
        assert th.allclose(v3, v4)

    # test 1d node features
    _test('f1')
    # test 2d node features
    _test('f2')

def test_send_and_recv():
    """Builtin message/reduce functions must match equivalent UDFs under
    ``send_and_recv`` on a subset of edges, for 1-D ('f1') and 2-D ('f2')
    node features.
    """
    # Edges (u[i] -> v[i]) that participate in message passing; all of them
    # exist in generate_graph()'s topology.
    u = th.tensor([0, 0, 0, 3, 4, 9])
    v = th.tensor([1, 2, 3, 9, 9, 0])

    def _test(fld):
        def message_func(edges):
            # UDF counterpart of fn.copy_src(src=fld, out='m')
            return {'m' : edges.src[fld]}

        def message_func_edge(edges):
            # UDF counterpart of fn.src_mul_edge: use the edge weight whose
            # layout matches the feature rank ('e1' is (E,), 'e2' is (E, 1)).
            if edges.src[fld].dim() == 1:
                return {'m' : edges.src[fld] * edges.data['e1']}
            else:
                return {'m' : edges.src[fld] * edges.data['e2']}

        def reduce_func(nodes):
            # UDF counterpart of fn.sum(msg='m', out=fld): sums the mailbox
            # over dim 1 (presumably the per-node message axis).
            return {fld : th.sum(nodes.mailbox['m'], 1)}

        def apply_func(nodes):
            return {fld : 2 * nodes.data[fld]}

        g = generate_graph()

        # send and recv: builtin copy/sum vs. UDF copy/sum
        v1 = g.ndata[fld]
        g.send_and_recv((u, v), fn.copy_src(src=fld, out='m'),
                fn.sum(msg='m', out=fld), apply_func)
        v2 = g.ndata[fld]
        # restore the input feature so each variant starts from the same state
        g.set_n_repr({fld : v1})
        g.send_and_recv((u, v), message_func, reduce_func, apply_func)
        v3 = g.ndata[fld]
        assert th.allclose(v2, v3)

        # send and recv with edge weights: builtin with 'e1', builtin with
        # 'e2', and the UDF must all agree
        v1 = g.ndata[fld]
        g.send_and_recv((u, v), fn.src_mul_edge(src=fld, edge='e1', out='m'),
                fn.sum(msg='m', out=fld), apply_func)
        v2 = g.ndata[fld]
        g.set_n_repr({fld : v1})
        g.send_and_recv((u, v), fn.src_mul_edge(src=fld, edge='e2', out='m'),
                fn.sum(msg='m', out=fld), apply_func)
        v3 = g.ndata[fld]
        g.set_n_repr({fld : v1})
        g.send_and_recv((u, v), message_func_edge, reduce_func, apply_func)
        v4 = g.ndata[fld]
        assert th.allclose(v2, v3)
        assert th.allclose(v3, v4)

    # test 1d node features
    _test('f1')
    # test 2d node features
    _test('f2')

def test_update_all_multi_fn():
    """``update_all`` accepting lists of message/reduce functions (builtins,
    UDFs, or a mix) must produce the same results as single-function runs,
    using the 2-D feature 'f2'.
    """
    def message_func(edges):
        # UDF counterpart of fn.copy_src(src='f2', out='m2')
        return {'m2': edges.src['f2']}

    def message_func_edge(edges):
        # UDF counterpart of fn.src_mul_edge(src='f2', edge='e2', out='m2')
        return {'m2': edges.src['f2'] * edges.data['e2']}

    def reduce_func(nodes):
        # UDF counterpart of fn.sum(msg='m2', out='v2')
        return {'v2': th.sum(nodes.mailbox['m2'], 1)}

    g = generate_graph()
    # Pre-create every output field with the (10, D) shape of the 'f2'
    # results written into them below (previously v1/v2 were created as
    # (10,) and v3 was not created at all).
    g.set_n_repr({'v1' : th.zeros((10, D)), 'v2' : th.zeros((10, D)),
        'v3' : th.zeros((10, D))})
    fld = 'f2'

    # update all, mix of builtin and UDF
    g.update_all([fn.copy_src(src=fld, out='m1'), message_func],
                 [fn.sum(msg='m1', out='v1'), reduce_func],
                 None)
    v1 = g.ndata['v1']
    v2 = g.ndata['v2']
    assert th.allclose(v1, v2)

    # run builtin with single message and reduce
    g.update_all(fn.copy_src(src=fld, out='m'), fn.sum(msg='m', out='v1'), None)
    v1 = g.ndata['v1']
    assert th.allclose(v1, v2)

    # 1 message, 2 reduces
    g.update_all(fn.copy_src(src=fld, out='m'),
                 [fn.sum(msg='m', out='v2'), fn.sum(msg='m', out='v3')],
                 None)
    v2 = g.ndata['v2']
    v3 = g.ndata['v3']
    assert th.allclose(v1, v2)
    assert th.allclose(v1, v3)

    # update all with edge weights, 2 messages, 3 reduces
    g.update_all([fn.src_mul_edge(src=fld, edge='e1', out='m1'),
                  fn.src_mul_edge(src=fld, edge='e2', out='m2')],
                 [fn.sum(msg='m1', out='v1'), fn.sum(msg='m2', out='v2'),
                  fn.sum(msg='m1', out='v3')],
                 None)
    v1 = g.ndata['v1']
    v2 = g.ndata['v2']
    v3 = g.ndata['v3']
    assert th.allclose(v1, v2)
    assert th.allclose(v1, v3)

    # run UDF with single message and reduce
    g.update_all(message_func_edge, reduce_func, None)
    v2 = g.ndata['v2']
    assert th.allclose(v1, v2)

def test_send_and_recv_multi_fn():
    """``send_and_recv`` accepting lists of message/reduce functions
    (builtins, UDFs, or a mix) must produce the same results as
    single-function runs, using the 2-D feature 'f2' on a subset of edges.
    """
    # Edges (u[i] -> v[i]) that participate in message passing.
    u = th.tensor([0, 0, 0, 3, 4, 9])
    v = th.tensor([1, 2, 3, 9, 9, 0])

    def message_func(edges):
        # UDF counterpart of fn.copy_src(src='f2', out='m2')
        return {'m2': edges.src['f2']}

    def message_func_edge(edges):
        # UDF counterpart of fn.src_mul_edge(src='f2', edge='e2', out='m2')
        return {'m2': edges.src['f2'] * edges.data['e2']}

    def reduce_func(nodes):
        # UDF counterpart of fn.sum(msg='m2', out='v2')
        return {'v2' : th.sum(nodes.mailbox['m2'], 1)}

    g = generate_graph()
    # Nodes not targeted by (u, v) keep these zeros, so every variant being
    # compared must start from identical output fields.
    g.set_n_repr({'v1' : th.zeros((10, D)), 'v2' : th.zeros((10, D)),
                  'v3' : th.zeros((10, D))})
    fld = 'f2'

    # send and recv, mix of builtin and UDF
    g.send_and_recv((u, v),
                    [fn.copy_src(src=fld, out='m1'), message_func],
                    [fn.sum(msg='m1', out='v1'), reduce_func],
                    None)
    v1 = g.ndata['v1']
    v2 = g.ndata['v2']
    assert th.allclose(v1, v2)

    # run builtin with single message and reduce
    g.send_and_recv((u, v), fn.copy_src(src=fld, out='m'), fn.sum(msg='m', out='v1'),
                    None)
    v1 = g.ndata['v1']
    assert th.allclose(v1, v2)

    # 1 message, 2 reduces
    g.send_and_recv((u, v),
                    fn.copy_src(src=fld, out='m'),
                    [fn.sum(msg='m', out='v2'), fn.sum(msg='m', out='v3')],
                    None)
    v2 = g.ndata['v2']
    v3 = g.ndata['v3']
    assert th.allclose(v1, v2)
    assert th.allclose(v1, v3)

    # send and recv with edge weights, 2 messages, 3 reduces
    g.send_and_recv((u, v),
                    [fn.src_mul_edge(src=fld, edge='e1', out='m1'),
                     fn.src_mul_edge(src=fld, edge='e2', out='m2')],
                    [fn.sum(msg='m1', out='v1'), fn.sum(msg='m2', out='v2'),
                     fn.sum(msg='m1', out='v3')],
                    None)
    v1 = g.ndata['v1']
    v2 = g.ndata['v2']
    v3 = g.ndata['v3']
    assert th.allclose(v1, v2)
    assert th.allclose(v1, v3)

    # run UDF with single message and reduce
    g.send_and_recv((u, v), message_func_edge,
                    reduce_func, None)
    v2 = g.ndata['v2']
    assert th.allclose(v1, v2)

if __name__ == '__main__':
    # Allow running the suite directly without a test runner.
    test_update_all()
    test_send_and_recv()
    test_update_all_multi_fn()
    test_send_and_recv_multi_fn()