test_sparse.py 14.6 KB
Newer Older
1
from distutils.version import LooseVersion
2
import random
3
4
import unittest

5
import backend as F
6
7
import numpy as np
import pytest
8
import torch
9
10
11
12
13
from test_utils import parametrize_idtype
from test_utils.graph_cases import get_cases

import dgl
from dgl.ops import edge_softmax, gsddmm, gspmm, segment_reduce
14

15
random.seed(42)
16
17
18
np.random.seed(42)

udf_msg = {
19
20
21
22
23
24
    "add": lambda edges: {"m": edges.src["x"] + edges.data["w"]},
    "sub": lambda edges: {"m": edges.src["x"] - edges.data["w"]},
    "mul": lambda edges: {"m": edges.src["x"] * edges.data["w"]},
    "div": lambda edges: {"m": edges.src["x"] / edges.data["w"]},
    "copy_lhs": lambda edges: {"m": edges.src["x"]},
    "copy_rhs": lambda edges: {"m": edges.data["w"]},
25
26
}

27

28
def select(target, src, edge, dst):
29
    if target == "u":
30
        return src
31
    elif target == "v":
32
        return dst
33
    elif target == "e":
34
35
        return edge

36

37
def binary_op(msg, x, y):
38
    if msg == "add":
39
        return x + y
40
    elif msg == "sub":
41
        return x - y
42
    elif msg == "mul":
43
        return x * y
44
    elif msg == "div":
45
        return x / y
46
    elif msg == "dot":
47
        return F.sum(x * y, -1, keepdims=True)
48
    elif msg == "copy_lhs":
49
        return x
50
    elif msg == "copy_rhs":
51
52
        return y

53

54
55
56
def edge_func(lhs_target, rhs_target, msg):
    def foo(edges):
        return {
57
            "m": binary_op(
58
                msg,
59
60
                select(lhs_target, edges.src, edges.data, edges.dst)["x"],
                select(rhs_target, edges.src, edges.data, edges.dst)["y"],
61
62
            )
        }
63

64
65
    return foo

66

67
udf_apply_edges = {
68
69
70
71
72
73
74
75
    lhs_target
    + "_"
    + msg
    + "_"
    + rhs_target: edge_func(lhs_target, rhs_target, msg)
    for lhs_target in ["u", "v", "e"]
    for rhs_target in ["u", "v", "e"]
    for msg in ["add", "sub", "mul", "div", "dot", "copy_lhs", "copy_rhs"]
76
77
78
}

udf_reduce = {
79
80
81
    "sum": lambda nodes: {"v": F.sum(nodes.mailbox["m"], 1)},
    "min": lambda nodes: {"v": F.min(nodes.mailbox["m"], 1)},
    "max": lambda nodes: {"v": F.max(nodes.mailbox["m"], 1)},
82
83
84
}

graphs = [
85
    #    dgl.rand_graph(30, 0),
86
    dgl.rand_graph(30, 100),
87
    dgl.rand_bipartite("_U", "_E", "_V", 30, 40, 300),
88
89
90
91
92
]

spmm_shapes = [
    ((1, 2, 1, 3, 1), (4, 1, 3, 1, 1)),
    ((3, 3), (1, 3)),
93
94
    ((1,), (3,)),
    ((3,), (1,)),
95
    ((1,), (1,)),
96
    ((), ()),
97
98
99
100
101
102
103
]

sddmm_shapes = [
    ((1, 2, 1, 3, 1), (4, 1, 3, 1, 1)),
    ((5, 3, 1, 7), (1, 3, 7, 7)),
    ((1, 3, 3), (4, 1, 3)),
    ((3,), (3,)),
104
    ((1,), (1,)),
105
106
]

107
edge_softmax_shapes = [(1,), (1, 3), (3, 4, 5)]
108

109
110
111
112
113
114
115

@pytest.mark.parametrize("g", graphs)
@pytest.mark.parametrize("shp", spmm_shapes)
@pytest.mark.parametrize(
    "msg", ["add", "sub", "mul", "div", "copy_lhs", "copy_rhs"]
)
@pytest.mark.parametrize("reducer", ["sum", "min", "max"])
nv-dlasalle's avatar
nv-dlasalle committed
116
@parametrize_idtype
117
118
def test_spmm(idtype, g, shp, msg, reducer):
    g = g.astype(idtype).to(F.ctx())
119
    print(g)
120
    print(g.idtype)
121

122
123
    hu = F.tensor(np.random.rand(*((g.number_of_src_nodes(),) + shp[0])) + 1)
    he = F.tensor(np.random.rand(*((g.number_of_edges(),) + shp[1])) + 1)
124
    print("u shape: {}, e shape: {}".format(F.shape(hu), F.shape(he)))
125

126
127
128
    g.srcdata["x"] = F.attach_grad(F.clone(hu))
    g.edata["w"] = F.attach_grad(F.clone(he))
    print("SpMM(message func: {}, reduce func: {})".format(msg, reducer))
129
130
131
132
133

    u = F.attach_grad(F.clone(hu))
    e = F.attach_grad(F.clone(he))
    with F.record_grad():
        v = gspmm(g, msg, reducer, u, e)
134
        if reducer in ["max", "min"]:
135
            v = F.replace_inf_with_zero(v)
136
137
        if g.number_of_edges() > 0:
            F.backward(F.reduce_sum(v))
138
            if msg != "copy_rhs":
139
                grad_u = F.grad(u)
140
            if msg != "copy_lhs":
141
142
143
144
145
                grad_e = F.grad(e)

    with F.record_grad():
        g.update_all(udf_msg[msg], udf_reduce[reducer])
        if g.number_of_edges() > 0:
146
            v1 = g.dstdata["v"]
147
            assert F.allclose(v, v1)
148
            print("forward passed")
149
150

            F.backward(F.reduce_sum(v1))
151
152
153
154
155
156
157
158
            if msg != "copy_rhs":
                if reducer in [
                    "min",
                    "max",
                ]:  # there might be some numerical errors
                    rate = F.reduce_sum(
                        F.abs(F.grad(g.srcdata["x"]) - grad_u)
                    ) / F.reduce_sum(F.abs(grad_u))
Zihao Ye's avatar
Zihao Ye committed
159
                    assert F.as_scalar(rate) < 1e-2, rate
160
                else:
161
162
163
164
165
166
                    assert F.allclose(F.grad(g.srcdata["x"]), grad_u)
            if msg != "copy_lhs":
                if reducer in ["min", "max"]:
                    rate = F.reduce_sum(
                        F.abs(F.grad(g.edata["w"]) - grad_e)
                    ) / F.reduce_sum(F.abs(grad_e))
Zihao Ye's avatar
Zihao Ye committed
167
                    assert F.as_scalar(rate) < 1e-2, rate
168
                else:
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
                    assert F.allclose(F.grad(g.edata["w"]), grad_e)
            print("backward passed")

    g.srcdata.pop("x")
    g.edata.pop("w")
    if "v" in g.dstdata:
        g.dstdata.pop("v")


@pytest.mark.parametrize("g", graphs)
@pytest.mark.parametrize("shp", sddmm_shapes)
@pytest.mark.parametrize("lhs_target", ["u", "v", "e"])
@pytest.mark.parametrize("rhs_target", ["u", "v", "e"])
@pytest.mark.parametrize(
    "msg", ["add", "sub", "mul", "div", "dot", "copy_lhs", "copy_rhs"]
)
nv-dlasalle's avatar
nv-dlasalle committed
185
@parametrize_idtype
186
187
188
189
def test_sddmm(g, shp, lhs_target, rhs_target, msg, idtype):
    if lhs_target == rhs_target:
        return
    g = g.astype(idtype).to(F.ctx())
190
191
    if dgl.backend.backend_name == "mxnet" and g.number_of_edges() == 0:
        pytest.skip()  # mxnet do not support zero shape tensor
192
    print(g)
193
194
195
196
197
198
    print(g.idtype)

    len_lhs = select(
        lhs_target,
        g.number_of_src_nodes(),
        g.number_of_edges(),
199
200
        g.number_of_dst_nodes(),
    )
201
202
203
204
205
    lhs_shp = (len_lhs,) + shp[0]
    len_rhs = select(
        rhs_target,
        g.number_of_src_nodes(),
        g.number_of_edges(),
206
207
        g.number_of_dst_nodes(),
    )
208
209
210
    rhs_shp = (len_rhs,) + shp[1]
    feat_lhs = F.tensor(np.random.rand(*lhs_shp) + 1)
    feat_rhs = F.tensor(np.random.rand(*rhs_shp) + 1)
211
212
213
214
215
216
217
218
219
220
221
222
    print(
        "lhs shape: {}, rhs shape: {}".format(
            F.shape(feat_lhs), F.shape(feat_rhs)
        )
    )

    lhs_frame = select(lhs_target, g.srcdata, g.edata, g.dstdata)
    rhs_frame = select(rhs_target, g.srcdata, g.edata, g.dstdata)
    lhs_frame["x"] = F.attach_grad(F.clone(feat_lhs))
    rhs_frame["y"] = F.attach_grad(F.clone(feat_rhs))
    msg_func = lhs_target + "_" + msg + "_" + rhs_target
    print("SDDMM(message func: {})".format(msg_func))
223
224
225
226

    lhs = F.attach_grad(F.clone(feat_lhs))
    rhs = F.attach_grad(F.clone(feat_rhs))
    with F.record_grad():
227
228
229
        e = gsddmm(
            g, msg, lhs, rhs, lhs_target=lhs_target, rhs_target=rhs_target
        )
230
231
232
233
234
235
236
        F.backward(F.reduce_sum(e))
        grad_lhs = F.grad(lhs)
        grad_rhs = F.grad(rhs)

    with F.record_grad():
        g.apply_edges(udf_apply_edges[msg_func])
        if g.number_of_edges() > 0:
237
            e1 = g.edata["m"]
238
            assert F.allclose(e, e1)
239
            print("forward passed")
240
241

            F.backward(F.reduce_sum(e1))
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
            if msg != "copy_rhs":
                assert F.allclose(F.grad(lhs_frame["x"]), grad_lhs)
            if msg != "copy_lhs":
                assert F.allclose(F.grad(rhs_frame["y"]), grad_rhs)
            print("backward passed")

    lhs_frame.pop("x")
    rhs_frame.pop("y")
    if "m" in g.edata:
        g.edata.pop("m")


@pytest.mark.parametrize("g", get_cases(["clique"]))
@pytest.mark.parametrize("norm_by", ["src", "dst"])
@pytest.mark.parametrize("shp", edge_softmax_shapes)
nv-dlasalle's avatar
nv-dlasalle committed
257
@parametrize_idtype
258
259
260
261
262
263
264
265
266
267
268
269
270
def test_edge_softmax(g, norm_by, shp, idtype):
    g = g.astype(idtype).to(F.ctx())
    edata = F.tensor(np.random.rand(g.number_of_edges(), *shp))
    e1 = F.attach_grad(F.clone(edata))

    with F.record_grad():
        score1 = edge_softmax(g, e1, norm_by=norm_by)
        F.backward(F.reduce_sum(score1))
        grad_edata = F.grad(e1)

    with F.record_grad():
        e2 = F.attach_grad(F.clone(edata))
        e2_2d = F.reshape(
271
272
273
274
            e2,
            (g.number_of_src_nodes(), g.number_of_dst_nodes(), *e2.shape[1:]),
        )
        if norm_by == "src":
275
276
            score2 = F.softmax(e2_2d, 1)
            score2 = F.reshape(score2, (-1, *e2.shape[1:]))
277
        if norm_by == "dst":
278
279
280
            score2 = F.softmax(e2_2d, 0)
            score2 = F.reshape(score2, (-1, *e2.shape[1:]))
        assert F.allclose(score1, score2)
281
        print("forward passed")
282
283
284

        F.backward(F.reduce_sum(score2))
        assert F.allclose(F.grad(e2), grad_edata)
285
286
        print("backward passed")

287

288
@pytest.mark.parametrize("reducer", ["sum", "max", "min", "mean"])
289
290
291
292
293
def test_segment_reduce(reducer):
    ctx = F.ctx()
    value = F.tensor(np.random.rand(10, 5))
    v1 = F.attach_grad(F.clone(value))
    v2 = F.attach_grad(F.clone(value))
294
    seglen = F.tensor([2, 3, 0, 4, 1, 0, 0])
295
    u = F.copy_to(F.arange(0, F.shape(value)[0], F.int32), ctx)
296
297
298
299
300
301
302
303
    v = F.repeat(
        F.copy_to(F.arange(0, len(seglen), F.int32), ctx), seglen, dim=0
    )

    num_nodes = {"_U": len(u), "_V": len(seglen)}
    g = dgl.convert.heterograph(
        {("_U", "_E", "_V"): (u, v)}, num_nodes_dict=num_nodes
    )
304
    with F.record_grad():
305
306
        rst1 = gspmm(g, "copy_lhs", reducer, v1, None)
        if reducer in ["max", "min"]:
307
            rst1 = F.replace_inf_with_zero(rst1)
308
309
310
311
312
313
314
        F.backward(F.reduce_sum(rst1))
        grad1 = F.grad(v1)

    with F.record_grad():
        rst2 = segment_reduce(seglen, v2, reducer=reducer)
        F.backward(F.reduce_sum(rst2))
        assert F.allclose(rst1, rst2)
315
        print("forward passed")
316
317
318

        grad2 = F.grad(v2)
        assert F.allclose(grad1, grad2)
319
320
        print("backward passed")

321

322
323
324
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
nv-dlasalle's avatar
nv-dlasalle committed
325
@parametrize_idtype
326
327
@pytest.mark.parametrize("feat_size", [1, 8, 16, 64, 256])
@pytest.mark.parametrize(
328
329
330
    "dtype, tol",
    [(torch.float16, 1e-2), (torch.bfloat16, 1e-2),
     (torch.float32, 3e-3), (torch.float64, 1e-4)],
331
)
332
def test_segment_mm(idtype, feat_size, dtype, tol):
333
    if F._default_context_str == "cpu" and dtype in (torch.float16, torch.bfloat16):
334
        pytest.skip(
335
336
337
338
339
340
341
            "Only support float32 and float64 on CPU."
        )
    if F._default_context_str == "gpu" \
        and LooseVersion(torch.version.cuda) < LooseVersion("11.0") \
        and dtype == torch.bfloat16:
        pytest.skip(
            "BF16 requires CUDA >= 11.0."
342
        )
343
344
    dev = F.ctx()
    # input
345
    a = torch.tensor(np.random.rand(100, feat_size)).to(dev).to(dtype)
346
    a.requires_grad_()
347
348
349
350
351
    b = (
        torch.tensor(np.random.rand(10, feat_size, feat_size + 1))
        .to(dev)
        .to(dtype)
    )
352
    b.requires_grad_()
353
    seglen_a = torch.tensor([10, 15, 8, 0, 1, 9, 18, 24, 15, 0]).to(idtype)
354
    dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev).to(dtype)
355
356
357
358
359
360
361
362
363
    # compute
    c = dgl.ops.segment_mm(a, b, seglen_a)
    c.backward(dc)
    da = a.grad.clone()
    db = b.grad.clone()
    # ground truth
    c_t = []
    off = 0
    for i, l in enumerate(seglen_a):
364
        c_t.append(a[off : off + l] @ b[i])
365
        off += l
366
    c_t = torch.cat(c_t).to(dtype)
367
368
369
370
371
372
    a.grad.zero_()
    b.grad.zero_()
    c_t.backward(dc)
    da_t = a.grad
    db_t = b.grad

373
374
375
    assert torch.allclose(c, c_t, atol=tol, rtol=tol)
    assert torch.allclose(da, da_t, atol=tol, rtol=tol)
    assert torch.allclose(db, db_t, atol=tol, rtol=tol)
376

377
378
379
380
381

@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@pytest.mark.parametrize("feat_size", [1, 8, 16, 64, 256])
382
383
384
385
386
387
388
389
390
391
392
393
@pytest.mark.parametrize(
    "dtype, tol",
    [(torch.float16, 1e-2), (torch.bfloat16, 2e-2),
     (torch.float32, 3e-3), (torch.float64, 1e-4)]
)
def test_gather_mm_idx_b(feat_size, dtype, tol):
    if F._default_context_str == "cpu" and dtype in (torch.float16, torch.bfloat16):
        pytest.skip("Only support float32 and float64 on CPU.")
    if F._default_context_str == "gpu" \
        and LooseVersion(torch.version.cuda) < LooseVersion("11.0") \
        and dtype == torch.bfloat16:
        pytest.skip("BF16 requires CUDA >= 11.0.")
394

395
396
    dev = F.ctx()
    # input
397
    a = torch.tensor(np.random.rand(100, feat_size)).to(dev).to(dtype)
398
    a.requires_grad_()
399
    b = torch.tensor(np.random.rand(10, feat_size, feat_size + 1)).to(dev).to(dtype)
400
401
    b.requires_grad_()
    idx = torch.tensor(np.random.randint(0, 10, 100)).to(dev).long()
402
    dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev).to(dtype)
403
404
405
406
407
408
409
410
411
412
413
414
415
    # compute
    c = dgl.ops.gather_mm(a, b, idx_b=idx)
    c.backward(dc)
    da = a.grad.clone()
    db = b.grad.clone()
    # ground truth
    c_t = torch.bmm(a.unsqueeze(1), b[idx]).squeeze(1)
    a.grad.zero_()
    b.grad.zero_()
    c_t.backward(dc)
    da_t = a.grad
    db_t = b.grad

416
417
418
    assert torch.allclose(c, c_t, atol=tol, rtol=tol)
    assert torch.allclose(da, da_t, atol=tol, rtol=tol)
    assert torch.allclose(db, db_t, atol=tol, rtol=tol)
419

420
421
422
423

@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
nv-dlasalle's avatar
nv-dlasalle committed
424
@parametrize_idtype
425
@pytest.mark.parametrize("feat_size", [1, 8, 16, 64, 256])
426
427
428
def _test_gather_mm_idx_a(idtype, feat_size):
    # TODO(minjie): currently disabled due to bugs in the CUDA kernel. Need to fix it later.
    import torch
429

430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
    dev = F.ctx()
    # input
    a = torch.tensor(np.random.rand(10, feat_size)).to(dev)
    a.requires_grad_()
    b = torch.tensor(np.random.rand(100, feat_size, feat_size + 1)).to(dev)
    b.requires_grad_()
    idx = torch.tensor(np.random.randint(0, 10, 100)).to(dev)
    dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev)
    # compute
    c = dgl.ops.gather_mm(a, b, idx_a=idx)
    c.backward(dc)
    da = a.grad.clone()
    db = b.grad.clone()
    # ground truth
    c_t = torch.bmm(a[idx].unsqueeze(1), b).squeeze(1)
    a.grad.zero_()
    b.grad.zero_()
    c_t.backward(dc)
    da_t = a.grad
    db_t = b.grad

    assert torch.allclose(c, c_t, atol=1e-4, rtol=1e-4)
    assert torch.allclose(da, da_t, atol=1e-4, rtol=1e-4)
    assert torch.allclose(db, db_t, atol=1e-4, rtol=1e-4)
454

455
456
457
458
459
460
461

@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@unittest.skipIf(
    F._default_context_str == "gpu", reason="Libxsmm only fit in CPU."
)
462
463
def test_use_libxsmm_switch():
    import torch
464

465
466
467
468
469
470
471
472
473
474
475
476
    g = dgl.graph(([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]))
    x = torch.ones(3, 2, requires_grad=True)
    y = torch.arange(1, 13).float().view(6, 2).requires_grad_()

    assert dgl.is_libxsmm_enabled()
    dgl.ops.u_mul_e_sum(g, x, y)
    dgl.use_libxsmm(False)
    assert ~dgl.is_libxsmm_enabled()
    dgl.ops.u_mul_e_sum(g, x, y)
    dgl.use_libxsmm(True)
    assert dgl.is_libxsmm_enabled()
    dgl.ops.u_mul_e_sum(g, x, y)