test_specialization.py 7.34 KB
Newer Older
1
import torch as th
Minjie Wang's avatar
Minjie Wang committed
2
import numpy as np
3
4
import dgl
import dgl.function as fn
5
import utils as U
6

Minjie Wang's avatar
Minjie Wang committed
7
8
# Width of the 2-D node feature 'f2' built in generate_graph(); the 1-D
# feature 'f1' has no such trailing dimension.
D = 5

9
def generate_graph():
    """Build the small directed graph shared by all tests in this file.

    Topology: node 0 has an edge to each of nodes 1..8, each of those has an
    edge to node 9, and a single back edge 9 -> 0 closes the loop, for
    2 * 8 + 1 = 17 edges in total. Every node has at least one in-edge.

    Node features:
        'f1': random tensor of shape (num_nodes,)
        'f2': random tensor of shape (num_nodes, D)
    Edge features:
        'e1': random weights of shape (num_edges,)
        'e2': the same weights with a trailing unit dimension, (num_edges, 1)

    Returns:
        The populated ``dgl.DGLGraph``.
    """
    g = dgl.DGLGraph()
    g.add_nodes(10)
    # create a graph where 0 is the source and 9 is the sink
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    # add a back flow from 9 to 0
    g.add_edge(9, 0)

    num_nodes = g.number_of_nodes()
    g.set_n_repr({'f1' : th.randn(num_nodes,), 'f2' : th.randn(num_nodes, D)})
    # Size the edge weights from the graph itself instead of a hard-coded 17,
    # so the edge features cannot drift out of sync with the edges added above.
    weights = th.randn(g.number_of_edges(),)
    g.set_e_repr({'e1': weights, 'e2': th.unsqueeze(weights, 1)})
    return g

23
24
def test_update_all():
    """Builtin and user-defined message passing must agree under update_all.

    Checked for both the 1-D node feature 'f1' and the 2-D node feature 'f2',
    first without and then with edge weights.
    """
    def _test(fld):
        def message_func(edges):
            # UDF counterpart of fn.copy_src(src=fld, out='m').
            return {'m' : edges.src[fld]}

        def message_func_edge(edges):
            # UDF counterpart of fn.src_mul_edge: use the edge weight whose
            # rank suits the node feature ('e1' for 1-D, 'e2' otherwise).
            weight_key = 'e1' if len(edges.src[fld].shape) == 1 else 'e2'
            return {'m' : edges.src[fld] * edges.data[weight_key]}

        def reduce_func(nodes):
            # UDF counterpart of fn.sum(msg='m', out=fld).
            return {fld : th.sum(nodes.mailbox['m'], 1)}

        def apply_func(nodes):
            return {fld : 2 * nodes.data[fld]}

        g = generate_graph()

        # --- without edge weights: builtin vs. UDF ---
        start = g.ndata[fld]
        g.update_all(fn.copy_src(src=fld, out='m'), fn.sum(msg='m', out=fld), apply_func)
        via_builtin = g.ndata[fld]
        g.set_n_repr({fld : start})
        g.update_all(message_func, reduce_func, apply_func)
        via_udf = g.ndata[fld]
        assert U.allclose(via_builtin, via_udf)

        # --- with edge weights: 'e1' builtin vs. 'e2' builtin vs. UDF ---
        # Each run restarts from the same snapshot of the node features.
        start = g.ndata[fld]
        g.update_all(fn.src_mul_edge(src=fld, edge='e1', out='m'),
                     fn.sum(msg='m', out=fld), apply_func)
        via_e1 = g.ndata[fld]
        g.set_n_repr({fld : start})
        g.update_all(fn.src_mul_edge(src=fld, edge='e2', out='m'),
                     fn.sum(msg='m', out=fld), apply_func)
        via_e2 = g.ndata[fld]
        g.set_n_repr({fld : start})
        g.update_all(message_func_edge, reduce_func, apply_func)
        via_udf = g.ndata[fld]

        assert U.allclose(via_e1, via_e2)
        assert U.allclose(via_e2, via_udf)

    # 1-D node features first, then 2-D node features.
    for feature in ('f1', 'f2'):
        _test(feature)

def test_send_and_recv():
    """Builtin and user-defined message passing must agree under send_and_recv.

    Only a fixed subset of edges (0->1, 0->2, 0->3, 3->9, 4->9, 9->0) carries
    messages. Checked for the 1-D feature 'f1' and the 2-D feature 'f2'.
    """
    src_ids = th.tensor([0, 0, 0, 3, 4, 9])
    dst_ids = th.tensor([1, 2, 3, 9, 9, 0])

    def _test(fld):
        def message_func(edges):
            # UDF counterpart of fn.copy_src(src=fld, out='m').
            return {'m' : edges.src[fld]}

        def message_func_edge(edges):
            # UDF counterpart of fn.src_mul_edge: use the edge weight whose
            # rank suits the node feature ('e1' for 1-D, 'e2' otherwise).
            weight_key = 'e1' if len(edges.src[fld].shape) == 1 else 'e2'
            return {'m' : edges.src[fld] * edges.data[weight_key]}

        def reduce_func(nodes):
            # UDF counterpart of fn.sum(msg='m', out=fld).
            return {fld : th.sum(nodes.mailbox['m'], 1)}

        def apply_func(nodes):
            return {fld : 2 * nodes.data[fld]}

        g = generate_graph()
        selected = (src_ids, dst_ids)

        # --- without edge weights: builtin vs. UDF ---
        start = g.ndata[fld]
        g.send_and_recv(selected, fn.copy_src(src=fld, out='m'),
                        fn.sum(msg='m', out=fld), apply_func)
        via_builtin = g.ndata[fld]
        g.set_n_repr({fld : start})
        g.send_and_recv(selected, message_func, reduce_func, apply_func)
        via_udf = g.ndata[fld]
        assert U.allclose(via_builtin, via_udf)

        # --- with edge weights: 'e1' builtin vs. 'e2' builtin vs. UDF ---
        # Each run restarts from the same snapshot of the node features.
        start = g.ndata[fld]
        g.send_and_recv(selected, fn.src_mul_edge(src=fld, edge='e1', out='m'),
                        fn.sum(msg='m', out=fld), apply_func)
        via_e1 = g.ndata[fld]
        g.set_n_repr({fld : start})
        g.send_and_recv(selected, fn.src_mul_edge(src=fld, edge='e2', out='m'),
                        fn.sum(msg='m', out=fld), apply_func)
        via_e2 = g.ndata[fld]
        g.set_n_repr({fld : start})
        g.send_and_recv(selected, message_func_edge, reduce_func, apply_func)
        via_udf = g.ndata[fld]

        assert U.allclose(via_e1, via_e2)
        assert U.allclose(via_e2, via_udf)

    # 1-D node features first, then 2-D node features.
    for feature in ('f1', 'f2'):
        _test(feature)
113

114
def test_update_all_multi_fn():
    """Check update_all with lists of message and/or reduce functions.

    Mixes builtin and UDF functions, feeds one message into several reduces,
    and checks that every output field equals the single-function result.
    All reductions operate on the 2-D node feature 'f2'.
    """
    def message_func(edges):
        return {'m2': edges.src['f2']}

    def message_func_edge(edges):
        return {'m2': edges.src['f2'] * edges.data['e2']}

    def reduce_func(nodes):
        return {'v2': th.sum(nodes.mailbox['m2'], 1)}

    g = generate_graph()
    # The outputs are reductions of the (10, D) feature 'f2', so allocate the
    # output buffers with that same shape. This replaces 1-D (10,) placeholders
    # that did not match what the reductions write. 'v3' is also pre-allocated,
    # like 'v1'/'v2', instead of being created implicitly.
    g.set_n_repr({'v1' : th.zeros((10, D)), 'v2' : th.zeros((10, D)),
        'v3' : th.zeros((10, D))})
    fld = 'f2'

    # update all, mix of builtin and UDF
    g.update_all([fn.copy_src(src=fld, out='m1'), message_func],
                 [fn.sum(msg='m1', out='v1'), reduce_func],
                 None)
    v1 = g.ndata['v1']
    v2 = g.ndata['v2']
    assert U.allclose(v1, v2)

    # run builtin with single message and reduce
    g.update_all(fn.copy_src(src=fld, out='m'), fn.sum(msg='m', out='v1'), None)
    v1 = g.ndata['v1']
    assert U.allclose(v1, v2)

    # 1 message, 2 reduces
    g.update_all(fn.copy_src(src=fld, out='m'),
                 [fn.sum(msg='m', out='v2'), fn.sum(msg='m', out='v3')],
                 None)
    v2 = g.ndata['v2']
    v3 = g.ndata['v3']
    assert U.allclose(v1, v2)
    assert U.allclose(v1, v3)

    # update all with edge weights, 2 messages, 3 reduces
    g.update_all([fn.src_mul_edge(src=fld, edge='e1', out='m1'),
                  fn.src_mul_edge(src=fld, edge='e2', out='m2')],
                 [fn.sum(msg='m1', out='v1'), fn.sum(msg='m2', out='v2'),
                  fn.sum(msg='m1', out='v3')],
                 None)
    v1 = g.ndata['v1']
    v2 = g.ndata['v2']
    v3 = g.ndata['v3']
    assert U.allclose(v1, v2)
    assert U.allclose(v1, v3)

    # run UDF with single message and reduce
    g.update_all(message_func_edge, reduce_func, None)
    v2 = g.ndata['v2']
    assert U.allclose(v1, v2)
161
162
163
164
165

def test_send_and_recv_multi_fn():
    """Check send_and_recv with lists of message and/or reduce functions.

    Messages travel only along 0->1, 0->2, 0->3, 3->9, 4->9 and 9->0. The
    test mixes builtin and UDF functions, feeds one message into several
    reduces, and checks every output against the single-function result.
    """
    selected = (th.tensor([0, 0, 0, 3, 4, 9]), th.tensor([1, 2, 3, 9, 9, 0]))

    def message_func(edges):
        return {'m2': edges.src['f2']}

    def message_func_edge(edges):
        return {'m2': edges.src['f2'] * edges.data['e2']}

    def reduce_func(nodes):
        return {'v2' : th.sum(nodes.mailbox['m2'], 1)}

    g = generate_graph()
    # Only some nodes receive messages, so the output buffers must already
    # have the (10, D) shape of the reduced 2-D feature.
    buffers = {name : th.zeros((10, D)) for name in ('v1', 'v2', 'v3')}
    g.set_n_repr(buffers)
    feat = 'f2'

    # Mixed builtin + UDF lists.
    g.send_and_recv(selected,
                    [fn.copy_src(src=feat, out='m1'), message_func],
                    [fn.sum(msg='m1', out='v1'), reduce_func],
                    None)
    out1, out2 = g.ndata['v1'], g.ndata['v2']
    assert U.allclose(out1, out2)

    # A single builtin message and reduce.
    g.send_and_recv(selected, fn.copy_src(src=feat, out='m'),
                    fn.sum(msg='m', out='v1'), None)
    out1 = g.ndata['v1']
    assert U.allclose(out1, out2)

    # One message consumed by two reduces.
    g.send_and_recv(selected,
                    fn.copy_src(src=feat, out='m'),
                    [fn.sum(msg='m', out='v2'), fn.sum(msg='m', out='v3')],
                    None)
    out2, out3 = g.ndata['v2'], g.ndata['v3']
    assert U.allclose(out1, out2)
    assert U.allclose(out1, out3)

    # Edge-weighted: two messages consumed by three reduces.
    g.send_and_recv(selected,
                    [fn.src_mul_edge(src=feat, edge='e1', out='m1'),
                     fn.src_mul_edge(src=feat, edge='e2', out='m2')],
                    [fn.sum(msg='m1', out='v1'), fn.sum(msg='m2', out='v2'),
                     fn.sum(msg='m1', out='v3')],
                    None)
    out1, out2, out3 = g.ndata['v1'], g.ndata['v2'], g.ndata['v3']
    assert U.allclose(out1, out2)
    assert U.allclose(out1, out3)

    # A single edge-weighted UDF message and reduce.
    g.send_and_recv(selected, message_func_edge, reduce_func, None)
    out2 = g.ndata['v2']
    assert U.allclose(out1, out2)
221

222
if __name__ == '__main__':
    # Run every test in this module, in declaration order.
    for _case in (test_update_all, test_send_and_recv,
                  test_update_all_multi_fn, test_send_and_recv_multi_fn):
        _case()