test_heterograph-kernel.py 13.9 KB
Newer Older
1
2
3
4
from itertools import product

import backend as F

5
6
7
import dgl
import dgl.function as fn
import networkx as nx
8
import numpy as np
9
import pytest
10
from utils import get_cases, parametrize_idtype
11

12
13

def udf_copy_src(edges):
    """Message UDF mirroring ``fn.copy_u``: forward the source-node
    feature ``'u'`` unchanged as the message ``'m'``."""
    src_feat = edges.src["u"]
    return {"m": src_feat}

16
17

def udf_copy_edge(edges):
    """Message UDF mirroring ``fn.copy_e``: forward the edge feature
    ``'e'`` unchanged as the message ``'m'``."""
    edge_feat = edges.data["e"]
    return {"m": edge_feat}

20

21
def udf_mean(nodes):
    """Reduce UDF mirroring ``fn.mean``: average inbox messages over the
    message axis (dim 1) into ``'r2'``."""
    inbox = nodes.mailbox["m"]
    return {"r2": F.mean(inbox, 1)}

24
25

def udf_sum(nodes):
    """Reduce UDF mirroring ``fn.sum``: sum inbox messages over the
    message axis (dim 1) into ``'r2'``."""
    inbox = nodes.mailbox["m"]
    return {"r2": F.sum(inbox, 1)}

28
29

def udf_max(nodes):
    """Reduce UDF mirroring ``fn.max``: take the elementwise max of inbox
    messages over the message axis (dim 1) into ``'r2'``."""
    inbox = nodes.mailbox["m"]
    return {"r2": F.max(inbox, 1)}
31
32
33
34
35


D1 = 5
D2 = 3
D3 = 4
36
37
38
39
D4 = 10  # NOTE(xiang): used to dot feature vector
builtin = {"sum": fn.sum, "max": fn.max, "mean": fn.mean}
udf_reduce = {"sum": udf_sum, "max": udf_max, "mean": udf_mean}
fill_value = {"sum": 0, "max": float("-inf")}
40
41


42
def generate_feature(g, broadcast="none", binary_op="none"):
    """Create src, edge, and dst features for graph ``g``.

    Parameters
    ----------
    g : DGLGraph
        Graph supplying the node/edge counts for the feature tensors.
    broadcast : str
        Which side receives the reduced, broadcastable shape ``(D2, 1)``:
        ``'u'``, ``'e'``, ``'v'``, or ``'none'`` for full shapes everywhere.
    binary_op : str
        When ``'dot'``, a trailing ``D4`` axis is appended to every shape
        so a dot product can be taken over it.

    Returns
    -------
    tuple
        ``(u, v, e)`` float32 tensors drawn from uniform(-1, 1).
    """
    # Fixed seed makes the builtin and UDF code paths see identical data.
    np.random.seed(31)
    nv = g.number_of_nodes()
    ne = g.number_of_edges()

    # The original spelled out 8 near-identical branches; collapse them into
    # a shape table: the broadcast target gets the reduced shape, everything
    # else the full shape, with D4 appended for the 'dot' case.
    tail = (D4,) if binary_op == "dot" else ()
    full_shape = (D1, D2, D3) + tail
    bcast_shape = (D2, 1) + tail
    shapes = {"u": full_shape, "e": full_shape, "v": full_shape}
    if broadcast in shapes:
        shapes[broadcast] = bcast_shape

    # Keep the u, e, v draw order — it fixes the random stream, so the
    # generated values match the original implementation exactly.
    u = F.tensor(np.random.uniform(-1, 1, (nv,) + shapes["u"]))
    e = F.tensor(np.random.uniform(-1, 1, (ne,) + shapes["e"]))
    v = F.tensor(np.random.uniform(-1, 1, (nv,) + shapes["v"]))
    return (
        F.astype(u, F.float32),
        F.astype(v, F.float32),
        F.astype(e, F.float32),
    )
88
89
90


def test_copy_src_reduce():
    """Compare builtin copy_u + {sum, max, mean} against the equivalent
    UDF pair, for both full ``update_all`` and partial ``pull``, checking
    forward results and node-feature gradients."""

    def _run(red, partial):
        g = dgl.DGLGraph(nx.erdos_renyi_graph(100, 0.1))
        # NOTE(zihao): add self-loop to avoid zero-degree nodes.
        # https://github.com/dmlc/dgl/issues/761
        g.add_edges(g.nodes(), g.nodes())
        g = g.to(F.ctx())
        hu, hv, he = generate_feature(g, "none", "none")
        nid = F.tensor(list(range(0, 100, 2)), g.idtype) if partial else None

        def _load_features():
            # (Re)attach fresh grad-tracked copies so each pass starts clean.
            g.ndata["u"] = F.attach_grad(F.clone(hu))
            g.ndata["v"] = F.attach_grad(F.clone(hv))
            g.edata["e"] = F.attach_grad(F.clone(he))

        # Pass 1: builtin message/reduce pair.
        _load_features()
        with F.record_grad():
            if partial:
                g.pull(
                    nid,
                    fn.copy_u(u="u", out="m"),
                    builtin[red](msg="m", out="r1"),
                )
            else:
                g.update_all(
                    fn.copy_u(u="u", out="m"), builtin[red](msg="m", out="r1")
                )
            r1 = g.ndata["r1"]
            F.backward(F.reduce_sum(r1))
            n_grad1 = F.grad(g.ndata["u"])

        # Pass 2: UDF reference implementation on identical inputs.
        _load_features()
        with F.record_grad():
            if partial:
                g.pull(nid, udf_copy_src, udf_reduce[red])
            else:
                g.update_all(udf_copy_src, udf_reduce[red])
            r2 = g.ndata["r2"]
            F.backward(F.reduce_sum(r2))
            n_grad2 = F.grad(g.ndata["u"])

        def _report_mismatch(a, b):
            print("ERROR: Test copy_src_{} partial: {}".format(red, partial))
            flat_a = F.asnumpy(a).flatten()
            flat_b = F.asnumpy(b).flatten()
            for i, (x, y) in enumerate(zip(flat_a, flat_b)):
                if not np.allclose(x, y):
                    print("@{} {} v.s. {}".format(i, x, y))

        if not F.allclose(r1, r2):
            _report_mismatch(r1, r2)
        assert F.allclose(r1, r2)
        if not F.allclose(n_grad1, n_grad2):
            print("node grad")
            _report_mismatch(n_grad1, n_grad2)
        assert F.allclose(n_grad1, n_grad2)

    # Same execution order as before: all reducers full, then all partial.
    for use_partial, reducer in product([False, True], ["sum", "max", "mean"]):
        _run(reducer, use_partial)
156
157


158
def test_copy_edge_reduce():
    """Compare builtin copy_e + {sum, max, mean} against the equivalent
    UDF pair, for both full ``update_all`` and partial ``pull``, checking
    forward results and edge-feature gradients."""

    def _test(red, partial):
        g = dgl.DGLGraph(nx.erdos_renyi_graph(100, 0.1))
        # NOTE(zihao): add self-loop to avoid zero-degree nodes.
        g.add_edges(g.nodes(), g.nodes())
        g = g.to(F.ctx())
        hu, hv, he = generate_feature(g, "none", "none")
        if partial:
            nid = F.tensor(list(range(0, 100, 2)), g.idtype)

        g.ndata["u"] = F.attach_grad(F.clone(hu))
        g.ndata["v"] = F.attach_grad(F.clone(hv))
        g.edata["e"] = F.attach_grad(F.clone(he))

        # Pass 1: builtin message/reduce pair.
        with F.record_grad():
            if partial:
                g.pull(
                    nid,
                    fn.copy_e(e="e", out="m"),
                    builtin[red](msg="m", out="r1"),
                )
            else:
                g.update_all(
                    fn.copy_e(e="e", out="m"), builtin[red](msg="m", out="r1")
                )
            r1 = g.ndata["r1"]
            F.backward(F.reduce_sum(r1))
            e_grad1 = F.grad(g.edata["e"])

        # reset grad
        g.ndata["u"] = F.attach_grad(F.clone(hu))
        g.ndata["v"] = F.attach_grad(F.clone(hv))
        g.edata["e"] = F.attach_grad(F.clone(he))

        # Pass 2: UDF reference implementation on identical inputs.
        with F.record_grad():
            if partial:
                g.pull(nid, udf_copy_edge, udf_reduce[red])
            else:
                g.update_all(udf_copy_edge, udf_reduce[red])
            r2 = g.ndata["r2"]
            F.backward(F.reduce_sum(r2))
            e_grad2 = F.grad(g.edata["e"])

        def _print_error(a, b):
            print("ERROR: Test copy_edge_{} partial: {}".format(red, partial))
            # BUGFIX: deleted an element-wise diff loop that sat after a bare
            # ``return`` and could never execute.  The early return appears
            # deliberate (it mutes the verbose dump that test_copy_src_reduce
            # still prints) — TODO confirm whether the detailed diff should
            # be restored instead of removed.

        if not F.allclose(r1, r2):
            _print_error(r1, r2)
        assert F.allclose(r1, r2)
        if not F.allclose(e_grad1, e_grad2):
            print("edge gradient")
            _print_error(e_grad1, e_grad2)
        assert F.allclose(e_grad1, e_grad2)

    # Exercise every reducer, first without and then with partial pull.
    _test("sum", False)
    _test("max", False)
    _test("mean", False)
    _test("sum", True)
    _test("max", True)
    _test("mean", True)
224
225
226


def test_all_binary_builtins():
    """Cross-check every builtin binary message op against a UDF reference.

    Iterates lhs/rhs targets (u, v, e), ops (add, sub, mul, div), reducers
    (sum, max, min, mean), broadcast sides, and full vs. partial pull,
    comparing forward results and both operand gradients within tolerance.
    """

    def _test(g, lhs, rhs, binary_op, reducer, partial, nid, broadcast="none"):
        # initialize node/edge features with uniform(-1, 1)
        hu, hv, he = generate_feature(g, broadcast, binary_op)
        if binary_op == "div":
            # op = div
            # lhs range: [-1, 1]
            # rhs range: [1, 2]  (shift keeps the divisor away from zero)
            # result range: [-1, 1]
            if rhs == "u":
                hu = (hu + 3) / 2
            elif rhs == "v":
                hv = (hv + 3) / 2
            elif rhs == "e":
                he = (he + 3) / 2

        if binary_op == "add" or binary_op == "sub":
            # op = add, sub
            # lhs range: [-1/2, 1/2]
            # rhs range: [-1/2, 1/2]
            # result range: [-1, 1]
            hu = hu / 2
            hv = hv / 2
            he = he / 2

        g.ndata["u"] = F.attach_grad(F.clone(hu))
        g.ndata["v"] = F.attach_grad(F.clone(hv))
        g.edata["e"] = F.attach_grad(F.clone(he))

        # Resolve the builtin message function by name, e.g. ``u_add_v``.
        builtin_msg_name = "{}_{}_{}".format(lhs, binary_op, rhs)
        builtin_msg = getattr(fn, builtin_msg_name)
        builtin_red = getattr(fn, reducer)

        def target_feature_switch(g, target):
            # Fetch the feature tensor currently stored for a target side.
            if target == "u":
                return g.ndata["u"]
            elif target == "v":
                return g.ndata["v"]
            else:
                return g.edata["e"]

        # Pass 1: builtin message/reduce pair.
        with F.record_grad():
            if partial:
                g.pull(nid, builtin_msg(lhs, rhs, "m"), builtin_red("m", "r1"))
            else:
                g.update_all(builtin_msg(lhs, rhs, "m"), builtin_red("m", "r1"))
            r1 = g.ndata.pop("r1")
            F.backward(F.reduce_sum(r1))
            lhs_grad_1 = F.grad(target_feature_switch(g, lhs))
            rhs_grad_1 = F.grad(target_feature_switch(g, rhs))

        # reset grad
        g.ndata["u"] = F.attach_grad(F.clone(hu))
        g.ndata["v"] = F.attach_grad(F.clone(hv))
        g.edata["e"] = F.attach_grad(F.clone(he))

        def target_switch(edges, target):
            # Map a target name to the matching EdgeBatch view.
            if target == "u":
                return edges.src
            elif target == "v":
                return edges.dst
            elif target == "e":
                return edges.data
            else:
                assert 0, "Unknown target {}".format(target)

        def mfunc(edges):
            # UDF message function replicating the builtin binary op.
            op = getattr(F, binary_op)
            lhs_data = target_switch(edges, lhs)[lhs]
            rhs_data = target_switch(edges, rhs)[rhs]
            # NOTE(zihao): we need to do batched broadcast
            # e.g. (68, 3, 1) op (68, 5, 3, 4)
            while F.ndim(lhs_data) < F.ndim(rhs_data):
                lhs_data = F.unsqueeze(lhs_data, 1)
            while F.ndim(rhs_data) < F.ndim(lhs_data):
                rhs_data = F.unsqueeze(rhs_data, 1)
            return {"m": op(lhs_data, rhs_data)}

        def rfunc(nodes):
            # UDF reducer replicating the builtin reducer over dim 1.
            op = getattr(F, reducer)
            return {"r2": op(nodes.mailbox["m"], 1)}

        # Pass 2: UDF reference implementation on identical inputs.
        with F.record_grad():
            if partial:
                g.pull(nid, mfunc, rfunc)
            else:
                g.update_all(mfunc, rfunc)
            r2 = g.ndata.pop("r2")
            F.backward(F.reduce_sum(r2), F.tensor([1.0]))
            lhs_grad_2 = F.grad(target_feature_switch(g, lhs))
            rhs_grad_2 = F.grad(target_feature_switch(g, rhs))

        # Looser-than-default tolerances for accumulated float error.
        rtol = 1e-4
        atol = 1e-4

        def _print_error(a, b):
            print(
                "ERROR: Test {}_{}_{}_{} broadcast: {} partial: {}".format(
                    lhs, binary_op, rhs, reducer, broadcast, partial
                )
            )
            # BUGFIX: deleted a large unreachable block after a bare
            # ``return`` (it dumped the full lhs/rhs tensors and an
            # element-wise diff).  The early return appears deliberate —
            # the dump is extremely verbose — so only the dead code was
            # removed; restore it locally if per-element diagnostics are
            # needed.

        if not F.allclose(r1, r2, rtol, atol):
            _print_error(r1, r2)
        assert F.allclose(r1, r2, rtol, atol)

        if not F.allclose(lhs_grad_1, lhs_grad_2, rtol, atol):
            print("left grad")
            _print_error(lhs_grad_1, lhs_grad_2)
        assert F.allclose(lhs_grad_1, lhs_grad_2, rtol, atol)

        if not F.allclose(rhs_grad_1, rhs_grad_2, rtol, atol):
            print("right grad")
            _print_error(rhs_grad_1, rhs_grad_2)
        assert F.allclose(rhs_grad_1, rhs_grad_2, rtol, atol)

    # Build a 20-node graph: self-loops on every node (avoids zero-degree
    # nodes), hubs 0/1 fanning out to 2..17, which fan in to 18/19, plus
    # back-edges from 18/19 to 0/1.
    g = dgl.DGLGraph()
    g.add_nodes(20)
    # NOTE(zihao): add self-loop to avoid zero-degree nodes.
    g.add_edges(g.nodes(), g.nodes())
    for i in range(2, 18):
        g.add_edges(0, i)
        g.add_edges(1, i)
        g.add_edges(i, 18)
        g.add_edges(i, 19)
    g.add_edges(18, 0)
    g.add_edges(18, 1)
    g.add_edges(19, 0)
    g.add_edges(19, 1)

    g = g.to(F.ctx())
    nid = F.tensor([0, 1, 4, 5, 7, 12, 14, 15, 18, 19], g.idtype)
    target = ["u", "v", "e"]

    for lhs, rhs in product(target, target):
        if lhs == rhs:
            continue
        for binary_op in ["add", "sub", "mul", "div"]:
            for reducer in ["sum", "max", "min", "mean"]:
                for broadcast in ["none", lhs, rhs]:
                    for partial in [False, True]:
                        print(lhs, rhs, binary_op, reducer, broadcast, partial)
                        _test(
                            g,
                            lhs,
                            rhs,
                            binary_op,
                            reducer,
                            partial,
                            nid,
                            broadcast=broadcast,
                        )

399

nv-dlasalle's avatar
nv-dlasalle committed
400
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo-zero-degree"]))
def test_mean_zero_degree(g, idtype):
    """Mean-reduce must leave zeros at nodes with zero in-degree."""
    g = g.astype(idtype).to(F.ctx())
    num_nodes = g.number_of_nodes()
    g.ndata["h"] = F.ones((num_nodes, 3))
    g.update_all(fn.copy_u("h", "m"), fn.mean("m", "x"))
    in_deg = F.asnumpy(g.in_degrees())
    isolated = F.tensor(np.where(in_deg == 0)[0])
    expected = F.zeros((len(isolated), 3))
    assert F.allclose(F.gather_row(g.ndata["x"], isolated), expected)

410

411
if __name__ == "__main__":
    # Run the directly-callable tests; test_mean_zero_degree is omitted
    # here because it takes pytest-parametrized arguments (g, idtype).
    test_copy_src_reduce()
    test_copy_edge_reduce()
    test_all_binary_builtins()