import random
import unittest

import backend as F

import dgl
import numpy as np
import pytest
import torch
from dgl.ops import gather_mm, gsddmm, gspmm, segment_reduce

from utils import parametrize_idtype
from utils.graph_cases import get_cases

# Set seeds to make tests fully reproducible.
SEED = 12345  # random.randint(1, 99999)
random.seed(SEED)
np.random.seed(SEED)
dgl.seed(SEED)
F.seed(SEED)

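# Reference implementations written as plain user-defined functions (UDFs).
# The tests below run the fused gspmm/gsddmm kernels and compare their
# forward results and gradients against message passing with these UDFs.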
udf_msg = {
    "add": lambda edges: {"m": edges.src["x"] + edges.data["w"]},
    "sub": lambda edges: {"m": edges.src["x"] - edges.data["w"]},
    "mul": lambda edges: {"m": edges.src["x"] * edges.data["w"]},
    "div": lambda edges: {"m": edges.src["x"] / edges.data["w"]},
    "copy_lhs": lambda edges: {"m": edges.src["x"]},
    "copy_rhs": lambda edges: {"m": edges.data["w"]},
}


def select(target, src, edge, dst):
    if target == "u":
        return src
    elif target == "v":
        return dst
    elif target == "e":
        return edge


def binary_op(msg, x, y):
    if msg == "add":
        return x + y
    elif msg == "sub":
        return x - y
    elif msg == "mul":
        return x * y
    elif msg == "div":
        return x / y
    elif msg == "dot":
        return F.sum(x * y, -1, keepdims=True)
    elif msg == "copy_lhs":
        return x
    elif msg == "copy_rhs":
        return y


def edge_func(lhs_target, rhs_target, msg):
    def foo(edges):
        return {
            "m": binary_op(
                msg,
                select(lhs_target, edges.src, edges.data, edges.dst)["x"],
                select(rhs_target, edges.src, edges.data, edges.dst)["y"],
            )
        }

    return foo


udf_apply_edges = {
    lhs_target
    + "_"
    + msg
    + "_"
    + rhs_target: edge_func(lhs_target, rhs_target, msg)
    for lhs_target in ["u", "v", "e"]
    for rhs_target in ["u", "v", "e"]
    for msg in ["add", "sub", "mul", "div", "dot", "copy_lhs", "copy_rhs"]
}

udf_reduce = {
    "sum": lambda nodes: {"v": F.sum(nodes.mailbox["m"], 1)},
    "min": lambda nodes: {"v": F.min(nodes.mailbox["m"], 1)},
    "max": lambda nodes: {"v": F.max(nodes.mailbox["m"], 1)},
}

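# Test graphs: one random homogeneous graph and one random bipartite graph
# (the zero-edge case is currently commented out).  The shape pairs below
# exercise broadcasting between the lhs and rhs operands, including scalar
# features.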
graphs = [
    #    dgl.rand_graph(30, 0),
    dgl.rand_graph(30, 100),
    dgl.rand_bipartite("_U", "_E", "_V", 30, 40, 300),
]

spmm_shapes = [
    ((1, 2, 1, 3, 1), (4, 1, 3, 1, 1)),
    ((3, 3), (1, 3)),
    ((1,), (3,)),
    ((3,), (1,)),
    ((1,), (1,)),
    ((), ()),
]

sddmm_shapes = [
    ((1, 2, 1, 3, 1), (4, 1, 3, 1, 1)),
    ((5, 3, 1, 7), (1, 3, 7, 7)),
    ((1, 3, 3), (4, 1, 3)),
    ((3,), (3,)),
    ((1,), (1,)),
]


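# test_spmm: run the fused gspmm kernel, then recompute the same message
# passing with update_all and the UDFs above, and compare forward outputs and
# input gradients.  For min/max reducers the gradient check uses a 1e-2
# relative error budget instead of an exact allclose.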
@pytest.mark.parametrize("g", graphs)
@pytest.mark.parametrize("shp", spmm_shapes)
@pytest.mark.parametrize(
    "msg", ["add", "sub", "mul", "div", "copy_lhs", "copy_rhs"]
)
@pytest.mark.parametrize("reducer", ["sum", "min", "max"])
@parametrize_idtype
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_spmm(idtype, dtype, g, shp, msg, reducer):
    g = g.astype(idtype).to(F.ctx())
    print(g)
    print(g.idtype)

    hu = F.tensor(
        np.random.rand(*((g.number_of_src_nodes(),) + shp[0])).astype(dtype) + 1
    )
    he = F.tensor(
        np.random.rand(*((g.num_edges(),) + shp[1])).astype(dtype) + 1
    )
    print("u shape: {}, e shape: {}".format(F.shape(hu), F.shape(he)))

    g.srcdata["x"] = F.attach_grad(F.clone(hu))
    g.edata["w"] = F.attach_grad(F.clone(he))
    print("SpMM(message func: {}, reduce func: {})".format(msg, reducer))

    u = F.attach_grad(F.clone(hu))
    e = F.attach_grad(F.clone(he))
    with F.record_grad():
        v = gspmm(g, msg, reducer, u, e)
        if reducer in ["max", "min"]:
            v = F.replace_inf_with_zero(v)
        if g.num_edges() > 0:
            F.backward(F.reduce_sum(v))
            if msg != "copy_rhs":
                grad_u = F.grad(u)
            if msg != "copy_lhs":
                grad_e = F.grad(e)

    with F.record_grad():
        g.update_all(udf_msg[msg], udf_reduce[reducer])
        if g.num_edges() > 0:
            v1 = g.dstdata["v"]
            assert F.allclose(v, v1)
            print("forward passed")

            F.backward(F.reduce_sum(v1))
            if msg != "copy_rhs":
                if reducer in [
                    "min",
                    "max",
                ]:  # there might be some numerical errors
                    rate = F.reduce_sum(
                        F.abs(F.grad(g.srcdata["x"]) - grad_u)
                    ) / F.reduce_sum(F.abs(grad_u))
                    assert F.as_scalar(rate) < 1e-2, rate
                else:
                    assert F.allclose(F.grad(g.srcdata["x"]), grad_u)
            if msg != "copy_lhs":
                if reducer in ["min", "max"]:
                    rate = F.reduce_sum(
                        F.abs(F.grad(g.edata["w"]) - grad_e)
                    ) / F.reduce_sum(F.abs(grad_e))
                    assert F.as_scalar(rate) < 1e-2, rate
                else:
                    assert F.allclose(F.grad(g.edata["w"]), grad_e)
            print("backward passed")

    g.srcdata.pop("x")
    g.edata.pop("w")
    if "v" in g.dstdata:
        g.dstdata.pop("v")


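# Half-precision SpMM: all 900 edges point to node 0, so copy_u_sum reduces
# 900 values into a single row; the fp16/bf16 result is compared against the
# fp32 result within the parametrized rtol/atol.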
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch",
    reason="Only support PyTorch for now.",
)
@parametrize_idtype
@pytest.mark.parametrize(
    "dtype, rtol, atol",
    [(torch.float16, 1e-3, 0.5), (torch.bfloat16, 4e-3, 2.0)],
)
def test_half_spmm(idtype, dtype, rtol, atol):
    if F._default_context_str == "cpu" and dtype == torch.float16:
        pytest.skip("float16 is not supported on CPU.")
    if (
        F._default_context_str == "gpu"
        and dtype == torch.bfloat16
        and not torch.cuda.is_bf16_supported()
    ):
        pytest.skip("BF16 is not supported.")

    # make sure the spmm result is < 512 to match the rtol/atol we set.
    g = dgl.graph(
        (torch.arange(900), torch.tensor([0] * 900)),
        idtype=idtype,
        device=F.ctx(),
    )
    feat_fp32 = torch.rand((g.num_src_nodes(), 32)).to(F.ctx())
    feat_half = feat_fp32.to(dtype)

    # test SpMMCSR
    g = g.formats(["csc"])
    res_fp32 = dgl.ops.copy_u_sum(g, feat_fp32)[0]
    res_half = dgl.ops.copy_u_sum(g, feat_half)[0].float()
    assert torch.allclose(res_fp32, res_half, rtol=rtol, atol=atol)

    # test SpMMCOO
    # TODO(Xin): half-precision SpMMCoo is temporarily disabled.
    # g = g.formats(['coo'])
    # res_fp32 = dgl.ops.copy_u_sum(g, feat_fp32)[0]
    # res_half = dgl.ops.copy_u_sum(g, feat_half)[0].float()
    # assert torch.allclose(res_fp32, res_half, rtol=rtol, atol=atol)


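# test_sddmm: run the fused gsddmm kernel and the equivalent apply_edges with
# the matching UDF for each lhs/rhs target combination (u, v, e), and compare
# forward outputs and gradients.  Identical lhs/rhs targets are skipped.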
@pytest.mark.parametrize("g", graphs)
@pytest.mark.parametrize("shp", sddmm_shapes)
@pytest.mark.parametrize("lhs_target", ["u", "v", "e"])
@pytest.mark.parametrize("rhs_target", ["u", "v", "e"])
@pytest.mark.parametrize(
    "msg", ["add", "sub", "mul", "div", "dot", "copy_lhs", "copy_rhs"]
)
@parametrize_idtype
def test_sddmm(g, shp, lhs_target, rhs_target, msg, idtype):
    if lhs_target == rhs_target:
        return
    g = g.astype(idtype).to(F.ctx())
    if dgl.backend.backend_name == "mxnet" and g.num_edges() == 0:
        pytest.skip()  # mxnet does not support zero-shape tensors
    print(g)
    print(g.idtype)

    len_lhs = select(
        lhs_target,
        g.number_of_src_nodes(),
        g.num_edges(),
        g.number_of_dst_nodes(),
    )
    lhs_shp = (len_lhs,) + shp[0]
    len_rhs = select(
        rhs_target,
        g.number_of_src_nodes(),
        g.num_edges(),
        g.number_of_dst_nodes(),
    )
    rhs_shp = (len_rhs,) + shp[1]
    feat_lhs = F.tensor(np.random.rand(*lhs_shp) + 1)
    feat_rhs = F.tensor(np.random.rand(*rhs_shp) + 1)
    print(
        "lhs shape: {}, rhs shape: {}".format(
            F.shape(feat_lhs), F.shape(feat_rhs)
        )
    )

    lhs_frame = select(lhs_target, g.srcdata, g.edata, g.dstdata)
    rhs_frame = select(rhs_target, g.srcdata, g.edata, g.dstdata)
    lhs_frame["x"] = F.attach_grad(F.clone(feat_lhs))
    rhs_frame["y"] = F.attach_grad(F.clone(feat_rhs))
    msg_func = lhs_target + "_" + msg + "_" + rhs_target
    print("SDDMM(message func: {})".format(msg_func))

    lhs = F.attach_grad(F.clone(feat_lhs))
    rhs = F.attach_grad(F.clone(feat_rhs))
    with F.record_grad():
        e = gsddmm(
            g, msg, lhs, rhs, lhs_target=lhs_target, rhs_target=rhs_target
        )
        F.backward(F.reduce_sum(e))
        grad_lhs = F.grad(lhs)
        grad_rhs = F.grad(rhs)

    with F.record_grad():
        g.apply_edges(udf_apply_edges[msg_func])
        if g.num_edges() > 0:
            e1 = g.edata["m"]
            assert F.allclose(e, e1)
            print("forward passed")

            F.backward(F.reduce_sum(e1))
            if msg != "copy_rhs":
                assert F.allclose(F.grad(lhs_frame["x"]), grad_lhs)
            if msg != "copy_lhs":
                assert F.allclose(F.grad(rhs_frame["y"]), grad_rhs)
            print("backward passed")

    lhs_frame.pop("x")
    rhs_frame.pop("y")
    if "m" in g.edata:
        g.edata.pop("m")


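# test_segment_reduce: compute the same reduction twice, once via
# gspmm(copy_lhs) on a bipartite graph whose destination nodes correspond to
# the segments, and once via segment_reduce; results and gradients must match.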
@pytest.mark.parametrize("reducer", ["sum", "max", "min", "mean"])
def test_segment_reduce(reducer):
    ctx = F.ctx()
    value = F.tensor(np.random.rand(10, 5))
    v1 = F.attach_grad(F.clone(value))
    v2 = F.attach_grad(F.clone(value))
    seglen = F.tensor([2, 3, 0, 4, 1, 0, 0])
    u = F.copy_to(F.arange(0, F.shape(value)[0], F.int32), ctx)
    v = F.repeat(
        F.copy_to(F.arange(0, len(seglen), F.int32), ctx), seglen, dim=0
    )

    num_nodes = {"_U": len(u), "_V": len(seglen)}
    g = dgl.convert.heterograph(
        {("_U", "_E", "_V"): (u, v)}, num_nodes_dict=num_nodes
    )
    with F.record_grad():
        rst1 = gspmm(g, "copy_lhs", reducer, v1, None)
        if reducer in ["max", "min"]:
            rst1 = F.replace_inf_with_zero(rst1)
        F.backward(F.reduce_sum(rst1))
        grad1 = F.grad(v1)

    with F.record_grad():
        rst2 = segment_reduce(seglen, v2, reducer=reducer)
        F.backward(F.reduce_sum(rst2))
        assert F.allclose(rst1, rst2)
        print("forward passed")

        grad2 = F.grad(v2)
        assert F.allclose(grad1, grad2)
        print("backward passed")


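# test_segment_mm: 'a' is split along dim 0 into segments of lengths
# seglen_a, and segment i is multiplied by its own weight matrix b[i].  The
# output and the gradients w.r.t. a and b are checked against a Python loop.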
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@parametrize_idtype
@pytest.mark.parametrize("feat_size", [1, 8, 16, 64, 256])
@pytest.mark.parametrize(
    "dtype, tol",
    [
        (torch.float16, 1e-2),
        (torch.bfloat16, 1e-2),
        (torch.float32, 3e-3),
        (torch.float64, 1e-4),
    ],
)
def test_segment_mm(idtype, feat_size, dtype, tol):
    if F._default_context_str == "cpu" and dtype == torch.float16:
        pytest.skip("float16 is not supported on CPU.")
    if (
        F._default_context_str == "gpu"
        and dtype == torch.bfloat16
        and not torch.cuda.is_bf16_supported()
    ):
        pytest.skip("BF16 is not supported.")
    dev = F.ctx()
    # input
    a = torch.tensor(np.random.rand(100, feat_size)).to(dev).to(dtype)
    a.requires_grad_()
    b = (
        torch.tensor(np.random.rand(10, feat_size, feat_size + 1))
        .to(dev)
        .to(dtype)
    )
    b.requires_grad_()
    seglen_a = torch.tensor([10, 15, 8, 0, 1, 9, 18, 24, 15, 0]).to(idtype)
    dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev).to(dtype)
    # compute
    c = dgl.ops.segment_mm(a, b, seglen_a)
    c.backward(dc)
    da = a.grad.clone()
    db = b.grad.clone()
    # ground truth
    c_t = []
    off = 0
    for i, l in enumerate(seglen_a):
        c_t.append(a[off : off + l] @ b[i])
        off += l
    c_t = torch.cat(c_t).to(dtype)
    a.grad.zero_()
    b.grad.zero_()
    c_t.backward(dc)
    da_t = a.grad
    db_t = b.grad

    assert torch.allclose(c, c_t, atol=tol, rtol=tol)
    assert torch.allclose(da, da_t, atol=tol, rtol=tol)
    assert torch.allclose(db, db_t, atol=tol, rtol=tol)


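# test_gather_mm_idx_b: row i of 'a' is multiplied by the gathered weight
# matrix b[idx[i]]; output and gradients are checked against torch.bmm on the
# gathered weights.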
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@pytest.mark.parametrize("feat_size", [1, 8, 16, 64, 256])
@pytest.mark.parametrize(
    "dtype, tol",
    [
        (torch.float16, 1e-2),
        (torch.bfloat16, 2e-2),
        (torch.float32, 3e-3),
        (torch.float64, 1e-4),
    ],
)
def test_gather_mm_idx_b(feat_size, dtype, tol):
    if F._default_context_str == "cpu" and dtype == torch.float16:
        pytest.skip("float16 is not supported on CPU.")
    if (
        F._default_context_str == "gpu"
        and dtype == torch.bfloat16
        and not torch.cuda.is_bf16_supported()
    ):
        pytest.skip("BF16 is not supported.")

    if (
        F._default_context_str == "gpu"
        and dtype == torch.float16
        and torch.cuda.get_device_capability() < (7, 0)
    ):
        pytest.skip(
            f"FP16 is not supported for atomic operations on GPU with "
            f"cuda capability ({torch.cuda.get_device_capability()})."
        )

    dev = F.ctx()
    # input
    a = torch.tensor(np.random.rand(100, feat_size)).to(dev).to(dtype)
    a.requires_grad_()
    b = (
        torch.tensor(np.random.rand(10, feat_size, feat_size + 1))
        .to(dev)
        .to(dtype)
    )
    b.requires_grad_()
    idx = torch.tensor(np.random.randint(0, 10, 100)).to(dev).long()
    dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev).to(dtype)
    # compute
    c = gather_mm(a, b, idx_b=idx)
    c.backward(dc)
    da = a.grad.clone()
    db = b.grad.clone()
    # ground truth
    c_t = torch.bmm(a.unsqueeze(1), b[idx]).squeeze(1)
    a.grad.zero_()
    b.grad.zero_()
    c_t.backward(dc)
    da_t = a.grad
    db_t = b.grad

    assert torch.allclose(c, c_t, atol=tol, rtol=tol)
    assert torch.allclose(da, da_t, atol=tol, rtol=tol)
    assert torch.allclose(db, db_t, atol=tol, rtol=tol)


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@parametrize_idtype
@pytest.mark.parametrize("feat_size", [1, 8, 16, 64, 256])
def _test_gather_mm_idx_a(idtype, feat_size):
    # TODO(minjie): currently disabled due to bugs in the CUDA kernel. Need to fix it later.
    import torch

    dev = F.ctx()
    # input
    a = torch.tensor(np.random.rand(10, feat_size)).to(dev)
    a.requires_grad_()
    b = torch.tensor(np.random.rand(100, feat_size, feat_size + 1)).to(dev)
    b.requires_grad_()
    idx = torch.tensor(np.random.randint(0, 10, 100)).to(dev)
    dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev)
    # compute
    c = gather_mm(a, b, idx_a=idx)
    c.backward(dc)
    da = a.grad.clone()
    db = b.grad.clone()
    # ground truth
    c_t = torch.bmm(a[idx].unsqueeze(1), b).squeeze(1)
    a.grad.zero_()
    b.grad.zero_()
    c_t.backward(dc)
    da_t = a.grad
    db_t = b.grad

    assert torch.allclose(c, c_t, atol=1e-4, rtol=1e-4)
    assert torch.allclose(da, da_t, atol=1e-4, rtol=1e-4)
    assert torch.allclose(db, db_t, atol=1e-4, rtol=1e-4)


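# Toggle the libxsmm CPU kernel at runtime and make sure u_mul_e_sum still
# executes with it disabled and enabled.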
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@unittest.skipIf(
    F._default_context_str == "gpu", reason="Libxsmm only fit in CPU."
)
def test_use_libxsmm_switch():
    import torch

    g = dgl.graph(([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]))
    x = torch.ones(3, 2, requires_grad=True)
    y = torch.arange(1, 13).float().view(6, 2).requires_grad_()

    dgl.use_libxsmm(False)
    assert not dgl.is_libxsmm_enabled()
    dgl.ops.u_mul_e_sum(g, x, y)
    dgl.use_libxsmm(True)
    assert dgl.is_libxsmm_enabled()
    dgl.ops.u_mul_e_sum(g, x, y)