import random
import unittest

import backend as F

import dgl
import numpy as np
import pytest
import torch
from dgl.ops import edge_softmax, gsddmm, gspmm, segment_reduce

from test_utils import parametrize_idtype
from test_utils.graph_cases import get_cases

random.seed(42)
np.random.seed(42)

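# Message UDFs equivalent to the built-in gspmm message ops.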
udf_msg = {
    "add": lambda edges: {"m": edges.src["x"] + edges.data["w"]},
    "sub": lambda edges: {"m": edges.src["x"] - edges.data["w"]},
    "mul": lambda edges: {"m": edges.src["x"] * edges.data["w"]},
    "div": lambda edges: {"m": edges.src["x"] / edges.data["w"]},
    "copy_lhs": lambda edges: {"m": edges.src["x"]},
    "copy_rhs": lambda edges: {"m": edges.data["w"]},
}


def select(target, src, edge, dst):
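    """Return src, edge, or dst depending on whether target is 'u', 'e', or 'v'."""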
    if target == "u":
        return src
    elif target == "v":
        return dst
    elif target == "e":
        return edge


def binary_op(msg, x, y):
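    """Apply the named binary (or copy) op to x and y."""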
    if msg == "add":
        return x + y
    elif msg == "sub":
        return x - y
    elif msg == "mul":
        return x * y
    elif msg == "div":
        return x / y
    elif msg == "dot":
        return F.sum(x * y, -1, keepdims=True)
    elif msg == "copy_lhs":
        return x
    elif msg == "copy_rhs":
        return y


def edge_func(lhs_target, rhs_target, msg):
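    """Build an apply_edges UDF computing msg between lhs_target and rhs_target."""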
    def foo(edges):
        return {
            "m": binary_op(
                msg,
                select(lhs_target, edges.src, edges.data, edges.dst)["x"],
                select(rhs_target, edges.src, edges.data, edges.dst)["y"],
            )
        }

    return foo


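# apply_edges UDFs keyed "<lhs>_<op>_<rhs>", mirroring the gsddmm ops.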
udf_apply_edges = {
    lhs_target
    + "_"
    + msg
    + "_"
    + rhs_target: edge_func(lhs_target, rhs_target, msg)
    for lhs_target in ["u", "v", "e"]
    for rhs_target in ["u", "v", "e"]
    for msg in ["add", "sub", "mul", "div", "dot", "copy_lhs", "copy_rhs"]
}

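# Reduce UDFs equivalent to the built-in gspmm reducers.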
udf_reduce = {
    "sum": lambda nodes: {"v": F.sum(nodes.mailbox["m"], 1)},
    "min": lambda nodes: {"v": F.min(nodes.mailbox["m"], 1)},
    "max": lambda nodes: {"v": F.max(nodes.mailbox["m"], 1)},
}

graphs = [
    #    dgl.rand_graph(30, 0),
    dgl.rand_graph(30, 100),
    dgl.rand_bipartite("_U", "_E", "_V", 30, 40, 300),
]

spmm_shapes = [
    ((1, 2, 1, 3, 1), (4, 1, 3, 1, 1)),
    ((3, 3), (1, 3)),
    ((1,), (3,)),
    ((3,), (1,)),
    ((1,), (1,)),
    ((), ()),
]

sddmm_shapes = [
    ((1, 2, 1, 3, 1), (4, 1, 3, 1, 1)),
    ((5, 3, 1, 7), (1, 3, 7, 7)),
    ((1, 3, 3), (4, 1, 3)),
    ((3,), (3,)),
    ((1,), (1,)),
]

edge_softmax_shapes = [(1,), (1, 3), (3, 4, 5)]


@pytest.mark.parametrize("g", graphs)
@pytest.mark.parametrize("shp", spmm_shapes)
@pytest.mark.parametrize(
    "msg", ["add", "sub", "mul", "div", "copy_lhs", "copy_rhs"]
)
@pytest.mark.parametrize("reducer", ["sum", "min", "max"])
@parametrize_idtype
def test_spmm(idtype, g, shp, msg, reducer):
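    """Check gspmm against update_all with the equivalent UDFs, forward and backward."""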
    g = g.astype(idtype).to(F.ctx())
    print(g)
    print(g.idtype)

    hu = F.tensor(np.random.rand(*((g.number_of_src_nodes(),) + shp[0])) + 1)
    he = F.tensor(np.random.rand(*((g.number_of_edges(),) + shp[1])) + 1)
    print("u shape: {}, e shape: {}".format(F.shape(hu), F.shape(he)))

    g.srcdata["x"] = F.attach_grad(F.clone(hu))
    g.edata["w"] = F.attach_grad(F.clone(he))
    print("SpMM(message func: {}, reduce func: {})".format(msg, reducer))

    u = F.attach_grad(F.clone(hu))
    e = F.attach_grad(F.clone(he))
    with F.record_grad():
        v = gspmm(g, msg, reducer, u, e)
        if reducer in ["max", "min"]:
            v = F.replace_inf_with_zero(v)
        if g.number_of_edges() > 0:
            F.backward(F.reduce_sum(v))
            if msg != "copy_rhs":
                grad_u = F.grad(u)
            if msg != "copy_lhs":
                grad_e = F.grad(e)

    with F.record_grad():
        g.update_all(udf_msg[msg], udf_reduce[reducer])
        if g.number_of_edges() > 0:
            v1 = g.dstdata["v"]
            assert F.allclose(v, v1)
            print("forward passed")

            F.backward(F.reduce_sum(v1))
            if msg != "copy_rhs":
                if reducer in [
                    "min",
                    "max",
                ]:  # there might be some numerical errors
                    rate = F.reduce_sum(
                        F.abs(F.grad(g.srcdata["x"]) - grad_u)
                    ) / F.reduce_sum(F.abs(grad_u))
                    assert F.as_scalar(rate) < 1e-2, rate
                else:
                    assert F.allclose(F.grad(g.srcdata["x"]), grad_u)
            if msg != "copy_lhs":
                if reducer in ["min", "max"]:
                    rate = F.reduce_sum(
                        F.abs(F.grad(g.edata["w"]) - grad_e)
                    ) / F.reduce_sum(F.abs(grad_e))
                    assert F.as_scalar(rate) < 1e-2, rate
                else:
                    assert F.allclose(F.grad(g.edata["w"]), grad_e)
            print("backward passed")

    g.srcdata.pop("x")
    g.edata.pop("w")
    if "v" in g.dstdata:
        g.dstdata.pop("v")


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch",
    reason="Only support PyTorch for now.",
)
@unittest.skipIf(
    F._default_context_str == "cpu",
    reason="Don't support half precision on CPU.",
)
@parametrize_idtype
@pytest.mark.parametrize(
    "dtype, rtol, atol",
    [(torch.float16, 1e-3, 0.5), (torch.bfloat16, 4e-3, 2.0)],
)
def test_half_spmm(idtype, dtype, rtol, atol):
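    """Compare half-precision copy_u_sum with the fp32 result within tolerance."""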
    if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
        pytest.skip("BF16 is not supported.")

    # make sure the spmm result is < 512 to match the rtol/atol we set.
    g = dgl.graph(
        (torch.arange(900), torch.tensor([0] * 900)),
        idtype=idtype,
        device=F.ctx(),
    )
    feat_fp32 = torch.rand((g.num_src_nodes(), 32)).to(0)
    feat_half = feat_fp32.to(dtype)

    # test SpMMCSR
    g = g.formats(["csc"])
    res_fp32 = dgl.ops.copy_u_sum(g, feat_fp32)[0]
    res_half = dgl.ops.copy_u_sum(g, feat_half)[0].float()
    assert torch.allclose(res_fp32, res_half, rtol=rtol, atol=atol)

    # test SpMMCOO
    # TODO(Xin): half-precision SpMMCoo is temporarily disabled.
    # g = g.formats(['coo'])
    # res_fp32 = dgl.ops.copy_u_sum(g, feat_fp32)[0]
    # res_half = dgl.ops.copy_u_sum(g, feat_half)[0].float()
    # assert torch.allclose(res_fp32, res_half, rtol=rtol, atol=atol)


@pytest.mark.parametrize("g", graphs)
@pytest.mark.parametrize("shp", sddmm_shapes)
@pytest.mark.parametrize("lhs_target", ["u", "v", "e"])
@pytest.mark.parametrize("rhs_target", ["u", "v", "e"])
@pytest.mark.parametrize(
    "msg", ["add", "sub", "mul", "div", "dot", "copy_lhs", "copy_rhs"]
)
@parametrize_idtype
def test_sddmm(g, shp, lhs_target, rhs_target, msg, idtype):
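    """Check gsddmm against apply_edges with the equivalent UDF, forward and backward."""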
    if lhs_target == rhs_target:
        return
    g = g.astype(idtype).to(F.ctx())
    if dgl.backend.backend_name == "mxnet" and g.number_of_edges() == 0:
        pytest.skip()  # mxnet does not support zero-shape tensors
    print(g)
    print(g.idtype)

    len_lhs = select(
        lhs_target,
        g.number_of_src_nodes(),
        g.number_of_edges(),
        g.number_of_dst_nodes(),
    )
    lhs_shp = (len_lhs,) + shp[0]
    len_rhs = select(
        rhs_target,
        g.number_of_src_nodes(),
        g.number_of_edges(),
        g.number_of_dst_nodes(),
    )
    rhs_shp = (len_rhs,) + shp[1]
    feat_lhs = F.tensor(np.random.rand(*lhs_shp) + 1)
    feat_rhs = F.tensor(np.random.rand(*rhs_shp) + 1)
    print(
        "lhs shape: {}, rhs shape: {}".format(
            F.shape(feat_lhs), F.shape(feat_rhs)
        )
    )

    lhs_frame = select(lhs_target, g.srcdata, g.edata, g.dstdata)
    rhs_frame = select(rhs_target, g.srcdata, g.edata, g.dstdata)
    lhs_frame["x"] = F.attach_grad(F.clone(feat_lhs))
    rhs_frame["y"] = F.attach_grad(F.clone(feat_rhs))
    msg_func = lhs_target + "_" + msg + "_" + rhs_target
    print("SDDMM(message func: {})".format(msg_func))

    lhs = F.attach_grad(F.clone(feat_lhs))
    rhs = F.attach_grad(F.clone(feat_rhs))
    with F.record_grad():
        e = gsddmm(
            g, msg, lhs, rhs, lhs_target=lhs_target, rhs_target=rhs_target
        )
        F.backward(F.reduce_sum(e))
        grad_lhs = F.grad(lhs)
        grad_rhs = F.grad(rhs)

    with F.record_grad():
        g.apply_edges(udf_apply_edges[msg_func])
        if g.number_of_edges() > 0:
            e1 = g.edata["m"]
            assert F.allclose(e, e1)
            print("forward passed")

            F.backward(F.reduce_sum(e1))
            if msg != "copy_rhs":
                assert F.allclose(F.grad(lhs_frame["x"]), grad_lhs)
            if msg != "copy_lhs":
                assert F.allclose(F.grad(rhs_frame["y"]), grad_rhs)
            print("backward passed")

    lhs_frame.pop("x")
    rhs_frame.pop("y")
    if "m" in g.edata:
        g.edata.pop("m")


@pytest.mark.parametrize("g", get_cases(["clique"]))
@pytest.mark.parametrize("norm_by", ["src", "dst"])
@pytest.mark.parametrize("shp", edge_softmax_shapes)
@parametrize_idtype
def test_edge_softmax(g, norm_by, shp, idtype):
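    """Compare edge_softmax on a clique with a dense softmax over the adjacency axes."""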
    g = g.astype(idtype).to(F.ctx())
    edata = F.tensor(np.random.rand(g.number_of_edges(), *shp))
    e1 = F.attach_grad(F.clone(edata))

    with F.record_grad():
        score1 = edge_softmax(g, e1, norm_by=norm_by)
        F.backward(F.reduce_sum(score1))
        grad_edata = F.grad(e1)

    with F.record_grad():
        e2 = F.attach_grad(F.clone(edata))
        e2_2d = F.reshape(
            e2,
            (g.number_of_src_nodes(), g.number_of_dst_nodes(), *e2.shape[1:]),
        )
        if norm_by == "src":
            score2 = F.softmax(e2_2d, 1)
            score2 = F.reshape(score2, (-1, *e2.shape[1:]))
        if norm_by == "dst":
            score2 = F.softmax(e2_2d, 0)
            score2 = F.reshape(score2, (-1, *e2.shape[1:]))
        assert F.allclose(score1, score2)
        print("forward passed")

        F.backward(F.reduce_sum(score2))
        assert F.allclose(F.grad(e2), grad_edata)
        print("backward passed")


@pytest.mark.parametrize("reducer", ["sum", "max", "min", "mean"])
def test_segment_reduce(reducer):
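    """Check segment_reduce against gspmm on an equivalent bipartite graph."""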
    ctx = F.ctx()
    value = F.tensor(np.random.rand(10, 5))
    v1 = F.attach_grad(F.clone(value))
    v2 = F.attach_grad(F.clone(value))
    seglen = F.tensor([2, 3, 0, 4, 1, 0, 0])
    u = F.copy_to(F.arange(0, F.shape(value)[0], F.int32), ctx)
    v = F.repeat(
        F.copy_to(F.arange(0, len(seglen), F.int32), ctx), seglen, dim=0
    )

    num_nodes = {"_U": len(u), "_V": len(seglen)}
    g = dgl.convert.heterograph(
        {("_U", "_E", "_V"): (u, v)}, num_nodes_dict=num_nodes
    )
    with F.record_grad():
        rst1 = gspmm(g, "copy_lhs", reducer, v1, None)
        if reducer in ["max", "min"]:
            rst1 = F.replace_inf_with_zero(rst1)
        F.backward(F.reduce_sum(rst1))
        grad1 = F.grad(v1)

    with F.record_grad():
        rst2 = segment_reduce(seglen, v2, reducer=reducer)
        F.backward(F.reduce_sum(rst2))
        assert F.allclose(rst1, rst2)
        print("forward passed")

        grad2 = F.grad(v2)
        assert F.allclose(grad1, grad2)
        print("backward passed")


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@parametrize_idtype
@pytest.mark.parametrize("feat_size", [1, 8, 16, 64, 256])
@pytest.mark.parametrize(
    "dtype, tol",
    [
        (torch.float16, 1e-2),
        (torch.bfloat16, 1e-2),
        (torch.float32, 3e-3),
        (torch.float64, 1e-4),
    ],
)
def test_segment_mm(idtype, feat_size, dtype, tol):
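    """Check dgl.ops.segment_mm against a per-segment matmul reference."""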
    if F._default_context_str == "cpu" and dtype in (
        torch.float16,
        torch.bfloat16,
    ):
        pytest.skip("Only support float32 and float64 on CPU.")
    if (
        F._default_context_str == "gpu"
        and dtype == torch.bfloat16
        and not torch.cuda.is_bf16_supported()
    ):
        pytest.skip("BF16 is not supported.")
    dev = F.ctx()
    # input
    a = torch.tensor(np.random.rand(100, feat_size)).to(dev).to(dtype)
    a.requires_grad_()
    b = (
        torch.tensor(np.random.rand(10, feat_size, feat_size + 1))
        .to(dev)
        .to(dtype)
    )
    b.requires_grad_()
    seglen_a = torch.tensor([10, 15, 8, 0, 1, 9, 18, 24, 15, 0]).to(idtype)
    dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev).to(dtype)
    # compute
    c = dgl.ops.segment_mm(a, b, seglen_a)
    c.backward(dc)
    da = a.grad.clone()
    db = b.grad.clone()
    # ground truth
    c_t = []
    off = 0
    for i, l in enumerate(seglen_a):
        c_t.append(a[off : off + l] @ b[i])
        off += l
    c_t = torch.cat(c_t).to(dtype)
    a.grad.zero_()
    b.grad.zero_()
    c_t.backward(dc)
    da_t = a.grad
    db_t = b.grad

    assert torch.allclose(c, c_t, atol=tol, rtol=tol)
    assert torch.allclose(da, da_t, atol=tol, rtol=tol)
    assert torch.allclose(db, db_t, atol=tol, rtol=tol)


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@pytest.mark.parametrize("feat_size", [1, 8, 16, 64, 256])
@pytest.mark.parametrize(
    "dtype, tol",
    [
        (torch.float16, 1e-2),
        (torch.bfloat16, 2e-2),
        (torch.float32, 3e-3),
        (torch.float64, 1e-4),
    ],
)
def test_gather_mm_idx_b(feat_size, dtype, tol):
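    """Check dgl.ops.gather_mm with idx_b against a torch.bmm reference."""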
    if F._default_context_str == "cpu" and dtype in (
        torch.float16,
        torch.bfloat16,
    ):
        pytest.skip("Only support float32 and float64 on CPU.")
    if (
        F._default_context_str == "gpu"
        and dtype == torch.bfloat16
        and not torch.cuda.is_bf16_supported()
    ):
        pytest.skip("BF16 is not supported.")

    dev = F.ctx()
    # input
    a = torch.tensor(np.random.rand(100, feat_size)).to(dev).to(dtype)
    a.requires_grad_()
    b = (
        torch.tensor(np.random.rand(10, feat_size, feat_size + 1))
        .to(dev)
        .to(dtype)
    )
    b.requires_grad_()
    idx = torch.tensor(np.random.randint(0, 10, 100)).to(dev).long()
    dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev).to(dtype)
    # compute
    c = dgl.ops.gather_mm(a, b, idx_b=idx)
    c.backward(dc)
    da = a.grad.clone()
    db = b.grad.clone()
    # ground truth
    c_t = torch.bmm(a.unsqueeze(1), b[idx]).squeeze(1)
    a.grad.zero_()
    b.grad.zero_()
    c_t.backward(dc)
    da_t = a.grad
    db_t = b.grad

    assert torch.allclose(c, c_t, atol=tol, rtol=tol)
    assert torch.allclose(da, da_t, atol=tol, rtol=tol)
    assert torch.allclose(db, db_t, atol=tol, rtol=tol)


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@parametrize_idtype
@pytest.mark.parametrize("feat_size", [1, 8, 16, 64, 256])
def _test_gather_mm_idx_a(idtype, feat_size):
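    """Check dgl.ops.gather_mm with idx_a against a torch.bmm reference (disabled)."""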
    # TODO(minjie): currently disabled due to bugs in the CUDA kernel. Need to fix it later.
    import torch

    dev = F.ctx()
    # input
    a = torch.tensor(np.random.rand(10, feat_size)).to(dev)
    a.requires_grad_()
    b = torch.tensor(np.random.rand(100, feat_size, feat_size + 1)).to(dev)
    b.requires_grad_()
    idx = torch.tensor(np.random.randint(0, 10, 100)).to(dev)
    dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev)
    # compute
    c = dgl.ops.gather_mm(a, b, idx_a=idx)
    c.backward(dc)
    da = a.grad.clone()
    db = b.grad.clone()
    # ground truth
    c_t = torch.bmm(a[idx].unsqueeze(1), b).squeeze(1)
    a.grad.zero_()
    b.grad.zero_()
    c_t.backward(dc)
    da_t = a.grad
    db_t = b.grad

    assert torch.allclose(c, c_t, atol=1e-4, rtol=1e-4)
    assert torch.allclose(da, da_t, atol=1e-4, rtol=1e-4)
    assert torch.allclose(db, db_t, atol=1e-4, rtol=1e-4)


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@unittest.skipIf(
    F._default_context_str == "gpu", reason="Libxsmm only fit in CPU."
)
def test_use_libxsmm_switch():
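    """Toggle the libxsmm switch and run u_mul_e_sum in both modes."""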
    import torch

    g = dgl.graph(([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]))
    x = torch.ones(3, 2, requires_grad=True)
    y = torch.arange(1, 13).float().view(6, 2).requires_grad_()

    assert dgl.is_libxsmm_enabled()
    dgl.ops.u_mul_e_sum(g, x, y)
    dgl.use_libxsmm(False)
    assert not dgl.is_libxsmm_enabled()
    dgl.ops.u_mul_e_sum(g, x, y)
    dgl.use_libxsmm(True)
    assert dgl.is_libxsmm_enabled()
    dgl.ops.u_mul_e_sum(g, x, y)