test_sparse.py 15.9 KB
Newer Older
1
import random
2
3
import unittest

4
import backend as F
5
6
import numpy as np
import pytest
7
import torch
8
9
10
11
12
from test_utils import parametrize_idtype
from test_utils.graph_cases import get_cases

import dgl
from dgl.ops import edge_softmax, gsddmm, gspmm, segment_reduce
13
from dgl.utils import version
14

15
random.seed(42)
16
17
18
np.random.seed(42)

# Message-function UDFs mirroring the built-in gspmm message ops; used as the
# reference implementation (via g.update_all) when checking gspmm in test_spmm.
# Each reads the source-node feature "x" and/or the edge feature "w" and
# writes the message under key "m".
udf_msg = {
    "add": lambda edges: {"m": edges.src["x"] + edges.data["w"]},
    "sub": lambda edges: {"m": edges.src["x"] - edges.data["w"]},
    "mul": lambda edges: {"m": edges.src["x"] * edges.data["w"]},
    "div": lambda edges: {"m": edges.src["x"] / edges.data["w"]},
    "copy_lhs": lambda edges: {"m": edges.src["x"]},
    "copy_rhs": lambda edges: {"m": edges.data["w"]},
}

27

28
def select(target, src, edge, dst):
    """Return the value associated with an SpMM/SDDMM target code.

    Parameters
    ----------
    target : str
        One of ``"u"`` (source), ``"e"`` (edge) or ``"v"`` (destination).
    src, edge, dst
        Candidate values returned for ``"u"``, ``"e"`` and ``"v"``
        respectively.

    Returns
    -------
    The candidate matching ``target``.

    Raises
    ------
    ValueError
        If ``target`` is not one of ``"u"``, ``"v"`` or ``"e"``.  The
        original implementation silently fell through and returned ``None``,
        which would only surface later as a confusing AttributeError.
    """
    if target == "u":
        return src
    elif target == "v":
        return dst
    elif target == "e":
        return edge
    raise ValueError(
        "unknown target: {!r} (expected 'u', 'v' or 'e')".format(target)
    )

36

37
def binary_op(msg, x, y):
    """Reference implementation of the named binary message op on x and y.

    ``"dot"`` reduces the trailing feature dimension (kept as size 1 via
    keepdims); the ``copy_*`` ops simply pass one operand through.
    """
    dispatch = {
        "add": lambda lhs, rhs: lhs + rhs,
        "sub": lambda lhs, rhs: lhs - rhs,
        "mul": lambda lhs, rhs: lhs * rhs,
        "div": lambda lhs, rhs: lhs / rhs,
        "dot": lambda lhs, rhs: F.sum(lhs * rhs, -1, keepdims=True),
        "copy_lhs": lambda lhs, rhs: lhs,
        "copy_rhs": lambda lhs, rhs: rhs,
    }
    op = dispatch.get(msg)
    if op is None:
        # Mirror the original if/elif chain, which fell through to None
        # for an unrecognized op name.
        return None
    return op(x, y)

53

54
55
56
def edge_func(lhs_target, rhs_target, msg):
    """Build an apply-edges UDF for one (lhs, rhs, op) combination.

    The returned UDF gathers feature "x" from ``lhs_target`` and feature
    "y" from ``rhs_target`` (each one of u/e/v), combines them with
    ``binary_op(msg, ...)`` and emits the result under key "m".
    """

    def udf(edges):
        frames = (edges.src, edges.data, edges.dst)
        lhs = select(lhs_target, *frames)["x"]
        rhs = select(rhs_target, *frames)["y"]
        return {"m": binary_op(msg, lhs, rhs)}

    return udf

66

67
# Reference apply-edges UDFs for every (lhs target, op, rhs target)
# combination exercised by gsddmm, keyed as "<lhs>_<op>_<rhs>"
# (e.g. "u_dot_v"); looked up by test_sddmm.
udf_apply_edges = {
    lhs_target
    + "_"
    + msg
    + "_"
    + rhs_target: edge_func(lhs_target, rhs_target, msg)
    for lhs_target in ["u", "v", "e"]
    for rhs_target in ["u", "v", "e"]
    for msg in ["add", "sub", "mul", "div", "dot", "copy_lhs", "copy_rhs"]
}

# Reference reduce-function UDFs: reduce the mailbox along the neighbor axis
# (dim 1) and store the result in destination-node field "v".  These mirror
# the built-in gspmm reducers used in test_spmm.
udf_reduce = {
    "sum": lambda nodes: {"v": F.sum(nodes.mailbox["m"], 1)},
    "min": lambda nodes: {"v": F.min(nodes.mailbox["m"], 1)},
    "max": lambda nodes: {"v": F.max(nodes.mailbox["m"], 1)},
}

# Graphs used to parametrize the SpMM/SDDMM tests: a random homogeneous
# graph and a random bipartite graph (distinct src/dst node sets).
graphs = [
    #    dgl.rand_graph(30, 0),
    dgl.rand_graph(30, 100),
    dgl.rand_bipartite("_U", "_E", "_V", 30, 40, 300),
]

# (node-feature shape, edge-feature shape) pairs for test_spmm; the two
# trailing shapes in each pair are mutually broadcastable, including the
# scalar-feature case ((), ()).
spmm_shapes = [
    ((1, 2, 1, 3, 1), (4, 1, 3, 1, 1)),
    ((3, 3), (1, 3)),
    ((1,), (3,)),
    ((3,), (1,)),
    ((1,), (1,)),
    ((), ()),
]

# (lhs-feature shape, rhs-feature shape) pairs for test_sddmm; broadcastable
# pairs, with matching last dimensions so the "dot" op is well-defined.
sddmm_shapes = [
    ((1, 2, 1, 3, 1), (4, 1, 3, 1, 1)),
    ((5, 3, 1, 7), (1, 3, 7, 7)),
    ((1, 3, 3), (4, 1, 3)),
    ((3,), (3,)),
    ((1,), (1,)),
]

107
# Per-edge feature shapes exercised by test_edge_softmax.
edge_softmax_shapes = [(1,), (1, 3), (3, 4, 5)]
108

109
110
111
112
113
114
115

@pytest.mark.parametrize("g", graphs)
@pytest.mark.parametrize("shp", spmm_shapes)
@pytest.mark.parametrize(
    "msg", ["add", "sub", "mul", "div", "copy_lhs", "copy_rhs"]
)
@pytest.mark.parametrize("reducer", ["sum", "min", "max"])
@parametrize_idtype
def test_spmm(idtype, g, shp, msg, reducer):
    """Check gspmm forward and backward against a UDF-based update_all.

    Runs the fused kernel (gspmm) and the message-passing reference
    (udf_msg/udf_reduce) on identical inputs and asserts that outputs and
    input gradients match.
    """
    g = g.astype(idtype).to(F.ctx())
    print(g)
    print(g.idtype)
    # Features in [1, 2) — presumably shifted by +1 to keep "div" away from
    # near-zero denominators; TODO confirm.
    hu = F.tensor(np.random.rand(*((g.number_of_src_nodes(),) + shp[0])) + 1)
    he = F.tensor(np.random.rand(*((g.number_of_edges(),) + shp[1])) + 1)
    print("u shape: {}, e shape: {}".format(F.shape(hu), F.shape(he)))
    g.srcdata["x"] = F.attach_grad(F.clone(hu))
    g.edata["w"] = F.attach_grad(F.clone(he))
    print("SpMM(message func: {}, reduce func: {})".format(msg, reducer))

    # Fused-kernel path: separate grad-tracked copies of the same features.
    u = F.attach_grad(F.clone(hu))
    e = F.attach_grad(F.clone(he))
    with F.record_grad():
        v = gspmm(g, msg, reducer, u, e)
        if reducer in ["max", "min"]:
            # Replace inf entries (presumably from nodes with no incoming
            # edges under min/max) with zero to match the UDF reference.
            v = F.replace_inf_with_zero(v)
        if g.number_of_edges() > 0:
            F.backward(F.reduce_sum(v))
            # copy_rhs ignores u and copy_lhs ignores e, so only grab the
            # gradients that are actually defined.
            if msg != "copy_rhs":
                grad_u = F.grad(u)
            if msg != "copy_lhs":
                grad_e = F.grad(e)

    # Reference path: message passing with the equivalent UDFs.
    with F.record_grad():
        g.update_all(udf_msg[msg], udf_reduce[reducer])
        if g.number_of_edges() > 0:
            v1 = g.dstdata["v"]
            assert F.allclose(v, v1)
            print("forward passed")

            F.backward(F.reduce_sum(v1))
            if msg != "copy_rhs":
                if reducer in [
                    "min",
                    "max",
                ]:  # there might be some numerical errors
                    # Compare the relative L1 error of the gradients instead
                    # of elementwise closeness, since min/max tie-breaking
                    # can differ between the two implementations.
                    rate = F.reduce_sum(
                        F.abs(F.grad(g.srcdata["x"]) - grad_u)
                    ) / F.reduce_sum(F.abs(grad_u))
                    assert F.as_scalar(rate) < 1e-2, rate
                else:
                    assert F.allclose(F.grad(g.srcdata["x"]), grad_u)
            if msg != "copy_lhs":
                if reducer in ["min", "max"]:
                    rate = F.reduce_sum(
                        F.abs(F.grad(g.edata["w"]) - grad_e)
                    ) / F.reduce_sum(F.abs(grad_e))
                    assert F.as_scalar(rate) < 1e-2, rate
                else:
                    assert F.allclose(F.grad(g.edata["w"]), grad_e)
            print("backward passed")

    # Remove graph-bound features so state does not leak into the next
    # parametrized case (the same graph objects are reused).
    g.srcdata.pop("x")
    g.edata.pop("w")
    if "v" in g.dstdata:
        g.dstdata.pop("v")


178
179
180
181
182
183
184
185
186
187
188
189
190
191
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch",
    reason="Only support PyTorch for now."
)
@unittest.skipIf(
    F._default_context_str == "cpu",
    reason="Don't support half precision on CPU."
)
@parametrize_idtype
@pytest.mark.parametrize(
    "dtype, rtol, atol",
    [(torch.float16, 1e-3, 0.5), (torch.bfloat16, 4e-3, 2.)]
)
def test_half_spmm(idtype, dtype, rtol, atol):
    """Compare fp16/bf16 copy_u_sum SpMM against the fp32 result.

    Per-dtype tolerances account for the reduced precision of the half
    formats when summing many values.
    """
    if version.parse(torch.version.cuda) < version.parse("11.0") \
        and dtype == torch.bfloat16:
        pytest.skip("BF16 requires CUDA >= 11.0.")

    # make sure the spmm result is < 512 to match the rtol/atol we set.
    # 900 sources all pointing at destination node 0, so node 0 sums 900
    # values in [0, 1).
    g = dgl.graph((torch.arange(900), torch.tensor([0] * 900)),
                  idtype=idtype, device=F.ctx())
    # NOTE(review): .to(0) hardcodes GPU 0 while the graph lives on F.ctx();
    # verify these always refer to the same device.
    feat_fp32 = torch.rand((g.num_src_nodes(), 32)).to(0)
    feat_half = feat_fp32.to(dtype)

    # test SpMMCSR
    g = g.formats(['csc'])
    res_fp32 = dgl.ops.copy_u_sum(g, feat_fp32)[0]
    res_half = dgl.ops.copy_u_sum(g, feat_half)[0].float()
    assert torch.allclose(res_fp32, res_half, rtol=rtol, atol=atol)

    # test SpMMCOO
    # TODO(Xin): half-precision SpMMCoo is temporally disabled.
    # g = g.formats(['coo'])
    # res_fp32 = dgl.ops.copy_u_sum(g, feat_fp32)[0]
    # res_half = dgl.ops.copy_u_sum(g, feat_half)[0].float()
    # assert torch.allclose(res_fp32, res_half, rtol=rtol, atol=atol)


216
217
218
219
220
221
222
@pytest.mark.parametrize("g", graphs)
@pytest.mark.parametrize("shp", sddmm_shapes)
@pytest.mark.parametrize("lhs_target", ["u", "v", "e"])
@pytest.mark.parametrize("rhs_target", ["u", "v", "e"])
@pytest.mark.parametrize(
    "msg", ["add", "sub", "mul", "div", "dot", "copy_lhs", "copy_rhs"]
)
@parametrize_idtype
def test_sddmm(g, shp, lhs_target, rhs_target, msg, idtype):
    """Check gsddmm forward and backward against a UDF-based apply_edges.

    Runs the fused kernel (gsddmm) and the reference UDF from
    udf_apply_edges on identical inputs and asserts matching outputs and
    input gradients.
    """
    # Cases with identical lhs/rhs targets are skipped here.
    if lhs_target == rhs_target:
        return
    g = g.astype(idtype).to(F.ctx())
    if dgl.backend.backend_name == "mxnet" and g.number_of_edges() == 0:
        pytest.skip()  # mxnet do not support zero shape tensor
    print(g)
    print(g.idtype)

    # The leading dimension of each operand follows its target:
    # #src nodes for "u", #edges for "e", #dst nodes for "v".
    len_lhs = select(
        lhs_target,
        g.number_of_src_nodes(),
        g.number_of_edges(),
        g.number_of_dst_nodes(),
    )
    lhs_shp = (len_lhs,) + shp[0]
    len_rhs = select(
        rhs_target,
        g.number_of_src_nodes(),
        g.number_of_edges(),
        g.number_of_dst_nodes(),
    )
    rhs_shp = (len_rhs,) + shp[1]
    # Features in [1, 2) — presumably to keep "div" away from near-zero
    # denominators; TODO confirm.
    feat_lhs = F.tensor(np.random.rand(*lhs_shp) + 1)
    feat_rhs = F.tensor(np.random.rand(*rhs_shp) + 1)
    print(
        "lhs shape: {}, rhs shape: {}".format(
            F.shape(feat_lhs), F.shape(feat_rhs)
        )
    )

    # Attach grad-tracked copies to the frames the reference UDF reads from.
    lhs_frame = select(lhs_target, g.srcdata, g.edata, g.dstdata)
    rhs_frame = select(rhs_target, g.srcdata, g.edata, g.dstdata)
    lhs_frame["x"] = F.attach_grad(F.clone(feat_lhs))
    rhs_frame["y"] = F.attach_grad(F.clone(feat_rhs))
    msg_func = lhs_target + "_" + msg + "_" + rhs_target
    print("SDDMM(message func: {})".format(msg_func))

    # Fused-kernel path on independent grad-tracked copies.
    lhs = F.attach_grad(F.clone(feat_lhs))
    rhs = F.attach_grad(F.clone(feat_rhs))
    with F.record_grad():
        e = gsddmm(
            g, msg, lhs, rhs, lhs_target=lhs_target, rhs_target=rhs_target
        )
        F.backward(F.reduce_sum(e))
        grad_lhs = F.grad(lhs)
        grad_rhs = F.grad(rhs)

    # Reference path via apply_edges with the equivalent UDF.
    with F.record_grad():
        g.apply_edges(udf_apply_edges[msg_func])
        if g.number_of_edges() > 0:
            e1 = g.edata["m"]
            assert F.allclose(e, e1)
            print("forward passed")

            F.backward(F.reduce_sum(e1))
            # copy_rhs ignores lhs and copy_lhs ignores rhs, so only the
            # relevant gradients are compared.
            if msg != "copy_rhs":
                assert F.allclose(F.grad(lhs_frame["x"]), grad_lhs)
            if msg != "copy_lhs":
                assert F.allclose(F.grad(rhs_frame["y"]), grad_rhs)
            print("backward passed")

    # Remove graph-bound features so state does not leak into the next
    # parametrized case (the same graph objects are reused).
    lhs_frame.pop("x")
    rhs_frame.pop("y")
    if "m" in g.edata:
        g.edata.pop("m")


@pytest.mark.parametrize("g", get_cases(["clique"]))
@pytest.mark.parametrize("norm_by", ["src", "dst"])
@pytest.mark.parametrize("shp", edge_softmax_shapes)
@parametrize_idtype
def test_edge_softmax(g, norm_by, shp, idtype):
    """Check edge_softmax against a dense softmax on a clique graph.

    Uses clique graphs so the edge set is dense and the per-edge scores can
    be reshaped into a (num_src, num_dst, ...) grid, where edge_softmax
    reduces to a plain softmax along one grid axis.
    """
    g = g.astype(idtype).to(F.ctx())
    edata = F.tensor(np.random.rand(g.number_of_edges(), *shp))
    e1 = F.attach_grad(F.clone(edata))

    # Path 1: the dgl edge_softmax operator.
    with F.record_grad():
        score1 = edge_softmax(g, e1, norm_by=norm_by)
        F.backward(F.reduce_sum(score1))
        grad_edata = F.grad(e1)

    # Path 2: dense reference via reshape + softmax.
    with F.record_grad():
        e2 = F.attach_grad(F.clone(edata))
        # Assumes edge ordering matches the (src, dst) grid layout of the
        # clique cases — TODO confirm against get_cases.
        e2_2d = F.reshape(
            e2,
            (g.number_of_src_nodes(), g.number_of_dst_nodes(), *e2.shape[1:]),
        )
        if norm_by == "src":
            score2 = F.softmax(e2_2d, 1)
            score2 = F.reshape(score2, (-1, *e2.shape[1:]))
        if norm_by == "dst":
            score2 = F.softmax(e2_2d, 0)
            score2 = F.reshape(score2, (-1, *e2.shape[1:]))
        assert F.allclose(score1, score2)
        print("forward passed")

        F.backward(F.reduce_sum(score2))
        assert F.allclose(F.grad(e2), grad_edata)
        print("backward passed")

325

326
@pytest.mark.parametrize("reducer", ["sum", "max", "min", "mean"])
def test_segment_reduce(reducer):
    """Check segment_reduce against an equivalent gspmm on a built graph.

    Encodes the segments as a bipartite graph (row i -> its segment id) so
    that gspmm("copy_lhs", reducer) computes the same per-segment reduction,
    then compares forward results and gradients.
    """
    ctx = F.ctx()
    value = F.tensor(np.random.rand(10, 5))
    v1 = F.attach_grad(F.clone(value))
    v2 = F.attach_grad(F.clone(value))
    # Segment lengths, including empty segments; they sum to 10 (the number
    # of rows in `value`).
    seglen = F.tensor([2, 3, 0, 4, 1, 0, 0])
    # Edge list: source node i is row i; its destination is the segment it
    # belongs to (segment id repeated seglen[i] times).
    u = F.copy_to(F.arange(0, F.shape(value)[0], F.int32), ctx)
    v = F.repeat(
        F.copy_to(F.arange(0, len(seglen), F.int32), ctx), seglen, dim=0
    )

    num_nodes = {"_U": len(u), "_V": len(seglen)}
    g = dgl.convert.heterograph(
        {("_U", "_E", "_V"): (u, v)}, num_nodes_dict=num_nodes
    )
    with F.record_grad():
        rst1 = gspmm(g, "copy_lhs", reducer, v1, None)
        if reducer in ["max", "min"]:
            # Replace inf entries (presumably produced by empty segments
            # under min/max) with zero to match segment_reduce's output.
            rst1 = F.replace_inf_with_zero(rst1)
        F.backward(F.reduce_sum(rst1))
        grad1 = F.grad(v1)

    with F.record_grad():
        rst2 = segment_reduce(seglen, v2, reducer=reducer)
        F.backward(F.reduce_sum(rst2))
        assert F.allclose(rst1, rst2)
        print("forward passed")

        grad2 = F.grad(v2)
        assert F.allclose(grad1, grad2)
        print("backward passed")

359

360
361
362
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@parametrize_idtype
@pytest.mark.parametrize("feat_size", [1, 8, 16, 64, 256])
@pytest.mark.parametrize(
    "dtype, tol",
    [(torch.float16, 1e-2), (torch.bfloat16, 1e-2),
     (torch.float32, 3e-3), (torch.float64, 1e-4)],
)
def test_segment_mm(idtype, feat_size, dtype, tol):
    """Check dgl.ops.segment_mm forward/backward against per-segment matmul.

    Ground truth multiplies each row segment of `a` by its own weight matrix
    b[i] and compares both outputs and the gradients w.r.t. `a` and `b`,
    with per-dtype tolerances.
    """
    if F._default_context_str == "cpu" and dtype in (torch.float16, torch.bfloat16):
        pytest.skip(
            "Only support float32 and float64 on CPU."
        )
    if F._default_context_str == "gpu" \
        and version.parse(torch.version.cuda) < version.parse("11.0") \
        and dtype == torch.bfloat16:
        pytest.skip(
            "BF16 requires CUDA >= 11.0."
        )
    dev = F.ctx()
    # input
    a = torch.tensor(np.random.rand(100, feat_size)).to(dev).to(dtype)
    a.requires_grad_()
    b = (
        torch.tensor(np.random.rand(10, feat_size, feat_size + 1))
        .to(dev)
        .to(dtype)
    )
    b.requires_grad_()
    # Segment lengths over the 100 rows of `a` (empty segments included);
    # they sum to 100.
    seglen_a = torch.tensor([10, 15, 8, 0, 1, 9, 18, 24, 15, 0]).to(idtype)
    # Upstream gradient used for both backward passes.
    dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev).to(dtype)
    # compute
    c = dgl.ops.segment_mm(a, b, seglen_a)
    c.backward(dc)
    da = a.grad.clone()
    db = b.grad.clone()
    # ground truth
    c_t = []
    off = 0
    for i, l in enumerate(seglen_a):
        c_t.append(a[off : off + l] @ b[i])
        off += l
    c_t = torch.cat(c_t).to(dtype)
    # Reset grads so the reference backward does not accumulate on top of
    # the segment_mm gradients.
    a.grad.zero_()
    b.grad.zero_()
    c_t.backward(dc)
    da_t = a.grad
    db_t = b.grad

    assert torch.allclose(c, c_t, atol=tol, rtol=tol)
    assert torch.allclose(da, da_t, atol=tol, rtol=tol)
    assert torch.allclose(db, db_t, atol=tol, rtol=tol)
414

415
416
417
418
419

@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@pytest.mark.parametrize("feat_size", [1, 8, 16, 64, 256])
@pytest.mark.parametrize(
    "dtype, tol",
    [(torch.float16, 1e-2), (torch.bfloat16, 2e-2),
     (torch.float32, 3e-3), (torch.float64, 1e-4)]
)
def test_gather_mm_idx_b(feat_size, dtype, tol):
    """Check dgl.ops.gather_mm(a, b, idx_b=idx) against torch.bmm.

    Each row of `a` is multiplied by the weight matrix b[idx[row]]; forward
    results and gradients w.r.t. `a` and `b` are compared with per-dtype
    tolerances.
    """
    if F._default_context_str == "cpu" and dtype in (torch.float16, torch.bfloat16):
        pytest.skip("Only support float32 and float64 on CPU.")
    if F._default_context_str == "gpu" \
        and version.parse(torch.version.cuda) < version.parse("11.0") \
        and dtype == torch.bfloat16:
        pytest.skip("BF16 requires CUDA >= 11.0.")

    dev = F.ctx()
    # input
    a = torch.tensor(np.random.rand(100, feat_size)).to(dev).to(dtype)
    a.requires_grad_()
    b = torch.tensor(np.random.rand(10, feat_size, feat_size + 1)).to(dev).to(dtype)
    b.requires_grad_()
    # idx maps each of the 100 rows of `a` to one of the 10 matrices in `b`.
    idx = torch.tensor(np.random.randint(0, 10, 100)).to(dev).long()
    # Upstream gradient used for both backward passes.
    dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev).to(dtype)
    # compute
    c = dgl.ops.gather_mm(a, b, idx_b=idx)
    c.backward(dc)
    da = a.grad.clone()
    db = b.grad.clone()
    # ground truth: batch-multiply each row (as 1 x feat) by its gathered
    # weight matrix.
    c_t = torch.bmm(a.unsqueeze(1), b[idx]).squeeze(1)
    # Reset grads so the reference backward does not accumulate.
    a.grad.zero_()
    b.grad.zero_()
    c_t.backward(dc)
    da_t = a.grad
    db_t = b.grad

    assert torch.allclose(c, c_t, atol=tol, rtol=tol)
    assert torch.allclose(da, da_t, atol=tol, rtol=tol)
    assert torch.allclose(db, db_t, atol=tol, rtol=tol)
457

458
459
460
461

@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@parametrize_idtype
@pytest.mark.parametrize("feat_size", [1, 8, 16, 64, 256])
def _test_gather_mm_idx_a(idtype, feat_size):
    """Check dgl.ops.gather_mm(a, b, idx_a=idx) against torch.bmm.

    Disabled: the leading underscore keeps pytest from collecting it.
    Each output row multiplies the gathered row a[idx[row]] by the matching
    weight matrix b[row].
    """
    # TODO(minjie): currently disabled due to bugs in the CUDA kernel. Need to fix it later.
    import torch

    dev = F.ctx()
    # input
    a = torch.tensor(np.random.rand(10, feat_size)).to(dev)
    a.requires_grad_()
    b = torch.tensor(np.random.rand(100, feat_size, feat_size + 1)).to(dev)
    b.requires_grad_()
    # idx maps each of the 100 outputs to one of the 10 rows of `a`.
    idx = torch.tensor(np.random.randint(0, 10, 100)).to(dev)
    # Upstream gradient used for both backward passes.
    dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev)
    # compute
    c = dgl.ops.gather_mm(a, b, idx_a=idx)
    c.backward(dc)
    da = a.grad.clone()
    db = b.grad.clone()
    # ground truth: gather rows of `a`, then batch-multiply by `b`.
    c_t = torch.bmm(a[idx].unsqueeze(1), b).squeeze(1)
    # Reset grads so the reference backward does not accumulate.
    a.grad.zero_()
    b.grad.zero_()
    c_t.backward(dc)
    da_t = a.grad
    db_t = b.grad

    assert torch.allclose(c, c_t, atol=1e-4, rtol=1e-4)
    assert torch.allclose(da, da_t, atol=1e-4, rtol=1e-4)
    assert torch.allclose(db, db_t, atol=1e-4, rtol=1e-4)
492

493
494
495
496
497
498
499

@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@unittest.skipIf(
    F._default_context_str == "gpu", reason="Libxsmm only fit in CPU."
)
def test_use_libxsmm_switch():
    """Toggle the libxsmm switch and check the flag plus that SpMM runs.

    Exercises dgl.use_libxsmm(True/False), asserts is_libxsmm_enabled()
    reflects each setting, and runs u_mul_e_sum in every state to make sure
    both code paths execute.
    """
    import torch

    g = dgl.graph(([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]))
    x = torch.ones(3, 2, requires_grad=True)
    y = torch.arange(1, 13).float().view(6, 2).requires_grad_()

    assert dgl.is_libxsmm_enabled()
    dgl.ops.u_mul_e_sum(g, x, y)
    dgl.use_libxsmm(False)
    # Bug fix: the original used bitwise `~` on a bool; ~False == -1 and
    # ~True == -2 are both truthy, so the assertion could never fail.
    # Logical `not` actually verifies the switch was turned off.
    assert not dgl.is_libxsmm_enabled()
    dgl.ops.u_mul_e_sum(g, x, y)
    dgl.use_libxsmm(True)
    assert dgl.is_libxsmm_enabled()
    dgl.ops.u_mul_e_sum(g, x, y)