"vscode:/vscode.git/clone" did not exist on "1bf0d8d4ba50b3ce06456f0757f15b73bbb65250"
test_heterograph-specialization.py 10.3 KB
Newer Older
1
2
import backend as F

import dgl
import dgl.function as fn
import numpy as np
import scipy.sparse as sp
from utils import parametrize_idtype

# Width of the 2-D features used throughout these tests (``f2`` node features,
# ``h`` node features and ``w2`` edge features).
D = 5

def generate_graph(idtype):
    """Build the small test graph shared by the v2v tests.

    Node 0 has an edge to each of nodes 1..8, each of those has an edge to
    node 9, and one back edge 9 -> 0 closes the loop.

    Node features:
        ``f1``: random, shape ``(10,)``.
        ``f2``: random, shape ``(10, D)``.
    Edge features:
        ``e1``: random, one value per edge.
        ``e2``: ``e1`` with an extra axis inserted at position 1.

    Args:
        idtype: dtype passed to ``g.astype`` for the graph's ids.

    Returns:
        The populated graph, moved to ``F.ctx()``.
    """
    g = dgl.DGLGraph()
    g = g.astype(idtype).to(F.ctx())
    g.add_nodes(10)
    # create a graph where 0 is the source and 9 is the sink
    for i in range(1, 9):
        g.add_edges(0, i)
        g.add_edges(i, 9)
    # add a back flow from 9 to 0
    g.add_edges(9, 0)
    g.ndata.update({"f1": F.randn((10,)), "f2": F.randn((10, D))})
    # Size the edge weights from the graph itself (8 * 2 + 1 = 17 edges) so
    # they cannot drift out of sync with the construction above.
    weights = F.randn((g.number_of_edges(),))
    g.edata.update({"e1": weights, "e2": F.unsqueeze(weights, 1)})
    return g

@parametrize_idtype
def test_v2v_update_all(idtype):
    """Check ``update_all`` builtins against equivalent UDFs.

    For both a 1-D node feature (``f1``) and a 2-D one (``f2``), the builtin
    ``copy_u`` + ``sum`` pipeline and the builtin ``u_mul_e`` + ``sum``
    pipeline must each match the corresponding hand-written message/reduce
    UDFs. All runs use the same apply function.
    """

    def _test(fld):
        def message_func(edges):
            return {"m": edges.src[fld]}

        def message_func_edge(edges):
            # 1-D features multiply directly with the flat weights ``e1``.
            # 2-D features use ``e2`` (``e1`` with an extra axis) so the
            # product broadcasts over the feature axis.
            if len(edges.src[fld].shape) == 1:
                return {"m": edges.src[fld] * edges.data["e1"]}
            else:
                return {"m": edges.src[fld] * edges.data["e2"]}

        def reduce_func(nodes):
            return {fld: F.sum(nodes.mailbox["m"], 1)}

        def apply_func(nodes):
            return {fld: 2 * nodes.data[fld]}

        g = generate_graph(idtype)
        # update all
        v1 = g.ndata[fld]
        g.update_all(
            fn.copy_u(u=fld, out="m"), fn.sum(msg="m", out=fld), apply_func
        )
        v2 = g.ndata[fld]
        # restore the input features before running the UDF version
        g.ndata.update({fld: v1})
        g.update_all(message_func, reduce_func, apply_func)
        v3 = g.ndata[fld]
        assert F.allclose(v2, v3)
        # update all with edge weights
        v1 = g.ndata[fld]
        g.update_all(
            fn.u_mul_e(fld, "e1", "m"), fn.sum(msg="m", out=fld), apply_func
        )
        v2 = g.ndata[fld]
        g.ndata.update({fld: v1})
        g.update_all(message_func_edge, reduce_func, apply_func)
        v4 = g.ndata[fld]
        assert F.allclose(v2, v4)

    # test 1d node features
    _test("f1")
    # test 2d node features
    _test("f2")

@parametrize_idtype
def test_v2v_snr(idtype):
    """Check ``send_and_recv`` builtins against equivalent UDFs.

    Messages are sent only along the explicit ``(u, v)`` pairs below. For both
    the 1-D feature ``f1`` and the 2-D feature ``f2``, the builtin
    ``copy_u`` / ``u_mul_e`` + ``sum`` pipelines must match the hand-written
    UDFs.
    """
    u = F.tensor([0, 0, 0, 3, 4, 9], idtype)
    v = F.tensor([1, 2, 3, 9, 9, 0], idtype)

    def _test(fld):
        def message_func(edges):
            return {"m": edges.src[fld]}

        def message_func_edge(edges):
            # 1-D features multiply directly with the flat weights ``e1``.
            # 2-D features use ``e2`` (``e1`` with an extra axis) so the
            # product broadcasts over the feature axis.
            if len(edges.src[fld].shape) == 1:
                return {"m": edges.src[fld] * edges.data["e1"]}
            else:
                return {"m": edges.src[fld] * edges.data["e2"]}

        def reduce_func(nodes):
            return {fld: F.sum(nodes.mailbox["m"], 1)}

        def apply_func(nodes):
            return {fld: 2 * nodes.data[fld]}

        g = generate_graph(idtype)
        # send and recv
        v1 = g.ndata[fld]
        g.send_and_recv(
            (u, v),
            fn.copy_u(u=fld, out="m"),
            fn.sum(msg="m", out=fld),
            apply_func,
        )
        v2 = g.ndata[fld]
        # restore the input features before running the UDF version
        g.ndata.update({fld: v1})
        g.send_and_recv((u, v), message_func, reduce_func, apply_func)
        v3 = g.ndata[fld]
        assert F.allclose(v2, v3)
        # send and recv with edge weights
        v1 = g.ndata[fld]
        g.send_and_recv(
            (u, v),
            fn.u_mul_e(fld, "e1", "m"),
            fn.sum(msg="m", out=fld),
            apply_func,
        )
        v2 = g.ndata[fld]
        g.ndata.update({fld: v1})
        g.send_and_recv((u, v), message_func_edge, reduce_func, apply_func)
        v4 = g.ndata[fld]
        assert F.allclose(v2, v4)

    # test 1d node features
    _test("f1")
    # test 2d node features
    _test("f2")


@parametrize_idtype
def test_v2v_pull(idtype):
    """Check ``pull`` builtins against equivalent UDFs.

    Nodes 1, 2, 3 and 9 are pulled. For both the 1-D feature ``f1`` and the
    2-D feature ``f2``, the builtin ``copy_u`` / ``u_mul_e`` + ``sum``
    pipelines must match the hand-written UDFs.
    """
    nodes = F.tensor([1, 2, 3, 9], idtype)

    def _test(fld):
        def message_func(edges):
            return {"m": edges.src[fld]}

        def message_func_edge(edges):
            # 1-D features multiply directly with the flat weights ``e1``.
            # 2-D features use ``e2`` (``e1`` with an extra axis) so the
            # product broadcasts over the feature axis.
            if len(edges.src[fld].shape) == 1:
                return {"m": edges.src[fld] * edges.data["e1"]}
            else:
                return {"m": edges.src[fld] * edges.data["e2"]}

        def reduce_func(nodes):
            return {fld: F.sum(nodes.mailbox["m"], 1)}

        def apply_func(nodes):
            return {fld: 2 * nodes.data[fld]}

        g = generate_graph(idtype)
        # pull
        v1 = g.ndata[fld]
        g.pull(
            nodes,
            fn.copy_u(u=fld, out="m"),
            fn.sum(msg="m", out=fld),
            apply_func,
        )
        v2 = g.ndata[fld]
        # restore the input features before running the UDF version
        g.ndata[fld] = v1
        g.pull(nodes, message_func, reduce_func, apply_func)
        v3 = g.ndata[fld]
        assert F.allclose(v2, v3)
        # pull with edge weights
        v1 = g.ndata[fld]
        g.pull(
            nodes,
            fn.u_mul_e(fld, "e1", "m"),
            fn.sum(msg="m", out=fld),
            apply_func,
        )
        v2 = g.ndata[fld]
        g.ndata[fld] = v1
        g.pull(nodes, message_func_edge, reduce_func, apply_func)
        v4 = g.ndata[fld]
        assert F.allclose(v2, v4)

    # test 1d node features
    _test("f1")
    # test 2d node features
    _test("f2")

@parametrize_idtype
def test_update_all_multi_fallback(idtype):
    """Check ``update_all`` builtins against UDF ground truth.

    The graph has a node with zero in-degree: node 0 only has outgoing edges.
    Each UDF pipeline below is paired with a builtin pipeline:
        ``o1``: ``u_mul_e`` with the flat weights ``w1``, then ``sum``.
        ``o2``: ``u_mul_e`` with the 2-D weights ``w2``, then ``sum``.
        ``o3``: ``u_mul_e`` with ``w1``, then ``max``.
    """
    # create a graph with zero in degree nodes
    g = dgl.DGLGraph()
    g = g.astype(idtype).to(F.ctx())
    g.add_nodes(10)
    for i in range(1, 9):
        g.add_edges(0, i)
        g.add_edges(i, 9)
    # 8 * 2 = 16 edges
    g.ndata["h"] = F.randn((10, D))
    g.edata["w1"] = F.randn((16,))
    g.edata["w2"] = F.randn((16, D))

    def _mfunc_hxw1(edges):
        return {"m1": edges.src["h"] * F.unsqueeze(edges.data["w1"], 1)}

    def _mfunc_hxw2(edges):
        return {"m2": edges.src["h"] * edges.data["w2"]}

    def _rfunc_m1(nodes):
        return {"o1": F.sum(nodes.mailbox["m1"], 1)}

    def _rfunc_m2(nodes):
        return {"o2": F.sum(nodes.mailbox["m2"], 1)}

    def _rfunc_m1max(nodes):
        return {"o3": F.max(nodes.mailbox["m1"], 1)}

    def _afunc(nodes):
        # Double every output feature (keys starting with "o"); "h" is not
        # returned, so the apply step leaves it untouched.
        ret = {}
        for k, v in nodes.data.items():
            if k.startswith("o"):
                ret[k] = 2 * v
        return ret

    # compute ground truth
    g.update_all(_mfunc_hxw1, _rfunc_m1, _afunc)
    o1 = g.ndata.pop("o1")
    g.update_all(_mfunc_hxw2, _rfunc_m2, _afunc)
    o2 = g.ndata.pop("o2")
    g.update_all(_mfunc_hxw1, _rfunc_m1max, _afunc)
    o3 = g.ndata.pop("o3")
    # v2v spmv
    g.update_all(
        fn.u_mul_e("h", "w1", "m1"), fn.sum(msg="m1", out="o1"), _afunc
    )
    assert F.allclose(o1, g.ndata.pop("o1"))
    # v2v fallback to e2v
    g.update_all(
        fn.u_mul_e("h", "w2", "m2"), fn.sum(msg="m2", out="o2"), _afunc
    )
    assert F.allclose(o2, g.ndata.pop("o2"))
    # v2v max reduce: the ``o3`` ground truth above was previously unchecked
    g.update_all(
        fn.u_mul_e("h", "w1", "m1"), fn.max(msg="m1", out="o3"), _afunc
    )
    assert F.allclose(o3, g.ndata.pop("o3"))

@parametrize_idtype
def test_pull_multi_fallback(idtype):
    """Check ``pull`` builtins against UDF ground truth.

    The graph has a node with zero in-degree: node 0 only has outgoing edges.
    The pull is run for a set of nodes without node 0, then for a set that
    includes it. Each UDF pipeline is paired with a builtin pipeline:
        ``o1``: ``u_mul_e`` with the flat weights ``w1``, then ``sum``.
        ``o2``: ``u_mul_e`` with the 2-D weights ``w2``, then ``sum``.
        ``o3``: ``u_mul_e`` with ``w1``, then ``max``.
    """
    # create a graph with zero in degree nodes
    g = dgl.DGLGraph()
    g = g.astype(idtype).to(F.ctx())
    g.add_nodes(10)
    for i in range(1, 9):
        g.add_edges(0, i)
        g.add_edges(i, 9)
    # 8 * 2 = 16 edges
    g.ndata["h"] = F.randn((10, D))
    g.edata["w1"] = F.randn((16,))
    g.edata["w2"] = F.randn((16, D))

    def _mfunc_hxw1(edges):
        return {"m1": edges.src["h"] * F.unsqueeze(edges.data["w1"], 1)}

    def _mfunc_hxw2(edges):
        return {"m2": edges.src["h"] * edges.data["w2"]}

    def _rfunc_m1(nodes):
        return {"o1": F.sum(nodes.mailbox["m1"], 1)}

    def _rfunc_m2(nodes):
        return {"o2": F.sum(nodes.mailbox["m2"], 1)}

    def _rfunc_m1max(nodes):
        return {"o3": F.max(nodes.mailbox["m1"], 1)}

    def _afunc(nodes):
        # Double every output feature (keys starting with "o"); "h" is not
        # returned, so the apply step leaves it untouched.
        ret = {}
        for k, v in nodes.data.items():
            if k.startswith("o"):
                ret[k] = 2 * v
        return ret

    # nodes to pull
    def _pull_nodes(nodes):
        # compute ground truth
        g.pull(nodes, _mfunc_hxw1, _rfunc_m1, _afunc)
        o1 = g.ndata.pop("o1")
        g.pull(nodes, _mfunc_hxw2, _rfunc_m2, _afunc)
        o2 = g.ndata.pop("o2")
        g.pull(nodes, _mfunc_hxw1, _rfunc_m1max, _afunc)
        o3 = g.ndata.pop("o3")
        # v2v spmv
        g.pull(
            nodes,
            fn.u_mul_e("h", "w1", "m1"),
            fn.sum(msg="m1", out="o1"),
            _afunc,
        )
        assert F.allclose(o1, g.ndata.pop("o1"))
        # v2v fallback to e2v
        g.pull(
            nodes,
            fn.u_mul_e("h", "w2", "m2"),
            fn.sum(msg="m2", out="o2"),
            _afunc,
        )
        assert F.allclose(o2, g.ndata.pop("o2"))
        # v2v max reduce: the ``o3`` ground truth above was previously unchecked
        g.pull(
            nodes,
            fn.u_mul_e("h", "w1", "m1"),
            fn.max(msg="m1", out="o3"),
            _afunc,
        )
        assert F.allclose(o3, g.ndata.pop("o3"))

    # test#1: non-0deg nodes
    nodes = [1, 2, 9]
    _pull_nodes(nodes)
    # test#2: 0deg nodes + non-0deg nodes
    nodes = [0, 1, 2, 9]
    _pull_nodes(nodes)

@parametrize_idtype
def test_spmv_3d_feat(idtype):
    """Check ``u_mul_e`` + ``sum`` on 3-D node features against UDFs.

    Test #1 uses one scalar per edge, reshaped inside the UDF to broadcast
    over the trailing axes. Test #2 uses full ``(m, 5, 5)`` edge features.
    In each test the builtin message/reduce pair (# 1) must match a UDF
    message with a builtin reduce (# 2) and a fully UDF pipeline (# 3).
    """

    def src_mul_edge_udf(edges):
        # Unsqueeze the per-edge scalars twice so they broadcast against the
        # 3-D source features.
        return {
            "sum": edges.src["h"]
            * F.unsqueeze(F.unsqueeze(edges.data["h"], 1), 1)
        }

    def src_mul_edge_udf_e2v(edges):
        # In test #2 the edge features have the same shape as the source
        # features, so the product is elementwise with no reshape.
        return {"sum": edges.src["h"] * edges.data["h"]}

    def sum_udf(nodes):
        return {"h": F.sum(nodes.mailbox["sum"], 1)}

    n = 100
    p = 0.1
    # Random sparse adjacency (density p) whose stored entries are all 1.
    a = sp.random(n, n, p, data_rvs=lambda size: np.ones(size))
    g = dgl.DGLGraph(a)
    g = g.astype(idtype).to(F.ctx())
    m = g.number_of_edges()

    # test#1: v2v with adj data
    h = F.randn((n, 5, 5))
    e = F.randn((m,))

    g.ndata["h"] = h
    g.edata["h"] = e
    g.update_all(
        message_func=fn.u_mul_e("h", "h", "sum"), reduce_func=fn.sum("sum", "h")
    )  # 1
    ans = g.ndata["h"]

    g.ndata["h"] = h
    g.edata["h"] = e
    g.update_all(
        message_func=src_mul_edge_udf, reduce_func=fn.sum("sum", "h")
    )  # 2
    assert F.allclose(g.ndata["h"], ans)

    g.ndata["h"] = h
    g.edata["h"] = e
    g.update_all(message_func=src_mul_edge_udf, reduce_func=sum_udf)  # 3
    assert F.allclose(g.ndata["h"], ans)

    # test#2: e2v
    h = F.randn((n, 5, 5))
    e = F.randn((m, 5, 5))

    g.ndata["h"] = h
    g.edata["h"] = e
    g.update_all(
        message_func=fn.u_mul_e("h", "h", "sum"), reduce_func=fn.sum("sum", "h")
    )  # 1
    ans = g.ndata["h"]

    g.ndata["h"] = h
    g.edata["h"] = e
    g.update_all(
        message_func=src_mul_edge_udf_e2v, reduce_func=fn.sum("sum", "h")
    )  # 2
    assert F.allclose(g.ndata["h"], ans)

    g.ndata["h"] = h
    g.edata["h"] = e
    g.update_all(message_func=src_mul_edge_udf_e2v, reduce_func=sum_udf)  # 3
    assert F.allclose(g.ndata["h"], ans)


if __name__ == "__main__":
    # Every test above is decorated with ``parametrize_idtype`` and takes an
    # ``idtype`` argument, so a plain no-argument call would raise TypeError.
    # Invoke each test explicitly for each index dtype instead.
    # NOTE(review): assumed to mirror the dtypes that ``parametrize_idtype`` in
    # ``utils`` covers; confirm.
    # The previously listed ``*_multi_fn`` tests are not defined in this module
    # and were dropped.
    for _idtype in (F.int32, F.int64):
        test_v2v_update_all(_idtype)
        test_v2v_snr(_idtype)
        test_v2v_pull(_idtype)
        test_update_all_multi_fallback(_idtype)
        test_pull_multi_fallback(_idtype)
        test_spmv_3d_feat(_idtype)