import random
import unittest

import backend as F

import networkx as nx
import numpy as np
import pytest
import torch

from test_utils import parametrize_idtype
from test_utils.graph_cases import get_cases

import dgl
from dgl.ops import edge_softmax, gsddmm, gspmm, segment_reduce

random.seed(42)
np.random.seed(42)

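# UDF message functions mirroring the built-in gspmm message operators;
# they serve as the reference implementation in the tests below.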
udf_msg = {
    "add": lambda edges: {"m": edges.src["x"] + edges.data["w"]},
    "sub": lambda edges: {"m": edges.src["x"] - edges.data["w"]},
    "mul": lambda edges: {"m": edges.src["x"] * edges.data["w"]},
    "div": lambda edges: {"m": edges.src["x"] / edges.data["w"]},
    "copy_lhs": lambda edges: {"m": edges.src["x"]},
    "copy_rhs": lambda edges: {"m": edges.data["w"]},
}


def select(target, src, edge, dst):
    if target == "u":
        return src
    elif target == "v":
        return dst
    elif target == "e":
        return edge


def binary_op(msg, x, y):
    if msg == "add":
        return x + y
    elif msg == "sub":
        return x - y
    elif msg == "mul":
        return x * y
    elif msg == "div":
        return x / y
    elif msg == "dot":
        return F.sum(x * y, -1, keepdims=True)
    elif msg == "copy_lhs":
        return x
    elif msg == "copy_rhs":
        return y


def edge_func(lhs_target, rhs_target, msg):
    def foo(edges):
        return {
            "m": binary_op(
                msg,
                select(lhs_target, edges.src, edges.data, edges.dst)["x"],
                select(rhs_target, edges.src, edges.data, edges.dst)["y"],
            )
        }

    return foo


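# UDF apply_edges functions, one per (lhs_target, msg, rhs_target) combination,
# mirroring the built-in gsddmm operators.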
udf_apply_edges = {
    lhs_target
    + "_"
    + msg
    + "_"
    + rhs_target: edge_func(lhs_target, rhs_target, msg)
    for lhs_target in ["u", "v", "e"]
    for rhs_target in ["u", "v", "e"]
    for msg in ["add", "sub", "mul", "div", "dot", "copy_lhs", "copy_rhs"]
}

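# UDF reduce functions mirroring the built-in sum/min/max reducers.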
udf_reduce = {
    "sum": lambda nodes: {"v": F.sum(nodes.mailbox["m"], 1)},
    "min": lambda nodes: {"v": F.min(nodes.mailbox["m"], 1)},
    "max": lambda nodes: {"v": F.max(nodes.mailbox["m"], 1)},
}

graphs = [
    #    dgl.rand_graph(30, 0),
    dgl.rand_graph(30, 100),
    dgl.rand_bipartite("_U", "_E", "_V", 30, 40, 300),
]

spmm_shapes = [
    ((1, 2, 1, 3, 1), (4, 1, 3, 1, 1)),
    ((3, 3), (1, 3)),
    ((1,), (3,)),
    ((3,), (1,)),
    ((1,), (1,)),
    ((), ()),
]

sddmm_shapes = [
    ((1, 2, 1, 3, 1), (4, 1, 3, 1, 1)),
    ((5, 3, 1, 7), (1, 3, 7, 7)),
    ((1, 3, 3), (4, 1, 3)),
    ((3,), (3,)),
    ((1,), (1,)),
]

edge_softmax_shapes = [(1,), (1, 3), (3, 4, 5)]


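# Cross-check the fused gspmm kernel against a pure-UDF update_all:
# forward outputs and gradients w.r.t. node/edge features must agree.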
@pytest.mark.parametrize("g", graphs)
@pytest.mark.parametrize("shp", spmm_shapes)
@pytest.mark.parametrize(
    "msg", ["add", "sub", "mul", "div", "copy_lhs", "copy_rhs"]
)
@pytest.mark.parametrize("reducer", ["sum", "min", "max"])
@parametrize_idtype
def test_spmm(idtype, g, shp, msg, reducer):
    g = g.astype(idtype).to(F.ctx())
    print(g)
    print(g.idtype)

    hu = F.tensor(np.random.rand(*((g.number_of_src_nodes(),) + shp[0])) + 1)
    he = F.tensor(np.random.rand(*((g.number_of_edges(),) + shp[1])) + 1)
    print("u shape: {}, e shape: {}".format(F.shape(hu), F.shape(he)))

    g.srcdata["x"] = F.attach_grad(F.clone(hu))
    g.edata["w"] = F.attach_grad(F.clone(he))
    print("SpMM(message func: {}, reduce func: {})".format(msg, reducer))

    u = F.attach_grad(F.clone(hu))
    e = F.attach_grad(F.clone(he))
    with F.record_grad():
        v = gspmm(g, msg, reducer, u, e)
        if reducer in ["max", "min"]:
            v = F.replace_inf_with_zero(v)
        if g.number_of_edges() > 0:
            F.backward(F.reduce_sum(v))
            if msg != "copy_rhs":
                grad_u = F.grad(u)
            if msg != "copy_lhs":
                grad_e = F.grad(e)

    with F.record_grad():
        g.update_all(udf_msg[msg], udf_reduce[reducer])
        if g.number_of_edges() > 0:
            v1 = g.dstdata["v"]
            assert F.allclose(v, v1)
            print("forward passed")

            F.backward(F.reduce_sum(v1))
            if msg != "copy_rhs":
                if reducer in [
                    "min",
                    "max",
                ]:  # there might be some numerical errors
                    rate = F.reduce_sum(
                        F.abs(F.grad(g.srcdata["x"]) - grad_u)
                    ) / F.reduce_sum(F.abs(grad_u))
                    assert F.as_scalar(rate) < 1e-2, rate
                else:
                    assert F.allclose(F.grad(g.srcdata["x"]), grad_u)
            if msg != "copy_lhs":
                if reducer in ["min", "max"]:
                    rate = F.reduce_sum(
                        F.abs(F.grad(g.edata["w"]) - grad_e)
                    ) / F.reduce_sum(F.abs(grad_e))
                    assert F.as_scalar(rate) < 1e-2, rate
                else:
                    assert F.allclose(F.grad(g.edata["w"]), grad_e)
            print("backward passed")

    g.srcdata.pop("x")
    g.edata.pop("w")
    if "v" in g.dstdata:
        g.dstdata.pop("v")


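# Cross-check the fused gsddmm kernel against a pure-UDF apply_edges for every
# lhs/rhs target combination, on both forward outputs and gradients.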
@pytest.mark.parametrize("g", graphs)
@pytest.mark.parametrize("shp", sddmm_shapes)
@pytest.mark.parametrize("lhs_target", ["u", "v", "e"])
@pytest.mark.parametrize("rhs_target", ["u", "v", "e"])
@pytest.mark.parametrize(
    "msg", ["add", "sub", "mul", "div", "dot", "copy_lhs", "copy_rhs"]
)
@parametrize_idtype
def test_sddmm(g, shp, lhs_target, rhs_target, msg, idtype):
    if lhs_target == rhs_target:
        return
    g = g.astype(idtype).to(F.ctx())
    if dgl.backend.backend_name == "mxnet" and g.number_of_edges() == 0:
        pytest.skip()  # mxnet does not support zero-shape tensors
    print(g)
    print(g.idtype)

    len_lhs = select(
        lhs_target,
        g.number_of_src_nodes(),
        g.number_of_edges(),
        g.number_of_dst_nodes(),
    )
    lhs_shp = (len_lhs,) + shp[0]
    len_rhs = select(
        rhs_target,
        g.number_of_src_nodes(),
        g.number_of_edges(),
        g.number_of_dst_nodes(),
    )
    rhs_shp = (len_rhs,) + shp[1]
    feat_lhs = F.tensor(np.random.rand(*lhs_shp) + 1)
    feat_rhs = F.tensor(np.random.rand(*rhs_shp) + 1)
    print(
        "lhs shape: {}, rhs shape: {}".format(
            F.shape(feat_lhs), F.shape(feat_rhs)
        )
    )

    lhs_frame = select(lhs_target, g.srcdata, g.edata, g.dstdata)
    rhs_frame = select(rhs_target, g.srcdata, g.edata, g.dstdata)
    lhs_frame["x"] = F.attach_grad(F.clone(feat_lhs))
    rhs_frame["y"] = F.attach_grad(F.clone(feat_rhs))
    msg_func = lhs_target + "_" + msg + "_" + rhs_target
    print("SDDMM(message func: {})".format(msg_func))

    lhs = F.attach_grad(F.clone(feat_lhs))
    rhs = F.attach_grad(F.clone(feat_rhs))
    with F.record_grad():
        e = gsddmm(
            g, msg, lhs, rhs, lhs_target=lhs_target, rhs_target=rhs_target
        )
        F.backward(F.reduce_sum(e))
        grad_lhs = F.grad(lhs)
        grad_rhs = F.grad(rhs)

    with F.record_grad():
        g.apply_edges(udf_apply_edges[msg_func])
        if g.number_of_edges() > 0:
            e1 = g.edata["m"]
            assert F.allclose(e, e1)
            print("forward passed")

            F.backward(F.reduce_sum(e1))
            if msg != "copy_rhs":
                assert F.allclose(F.grad(lhs_frame["x"]), grad_lhs)
            if msg != "copy_lhs":
                assert F.allclose(F.grad(rhs_frame["y"]), grad_rhs)
            print("backward passed")

    lhs_frame.pop("x")
    rhs_frame.pop("y")
    if "m" in g.edata:
        g.edata.pop("m")


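# Cross-check edge_softmax on a clique against a dense softmax over the
# reshaped (num_src, num_dst, ...) score tensor, forward and backward.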
@pytest.mark.parametrize("g", get_cases(["clique"]))
@pytest.mark.parametrize("norm_by", ["src", "dst"])
@pytest.mark.parametrize("shp", edge_softmax_shapes)
@parametrize_idtype
def test_edge_softmax(g, norm_by, shp, idtype):
    g = g.astype(idtype).to(F.ctx())
    edata = F.tensor(np.random.rand(g.number_of_edges(), *shp))
    e1 = F.attach_grad(F.clone(edata))

    with F.record_grad():
        score1 = edge_softmax(g, e1, norm_by=norm_by)
        F.backward(F.reduce_sum(score1))
        grad_edata = F.grad(e1)

    with F.record_grad():
        e2 = F.attach_grad(F.clone(edata))
        e2_2d = F.reshape(
            e2,
            (g.number_of_src_nodes(), g.number_of_dst_nodes(), *e2.shape[1:]),
        )
        if norm_by == "src":
            score2 = F.softmax(e2_2d, 1)
            score2 = F.reshape(score2, (-1, *e2.shape[1:]))
        if norm_by == "dst":
            score2 = F.softmax(e2_2d, 0)
            score2 = F.reshape(score2, (-1, *e2.shape[1:]))
        assert F.allclose(score1, score2)
        print("forward passed")

        F.backward(F.reduce_sum(score2))
        assert F.allclose(F.grad(e2), grad_edata)
        print("backward passed")


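# Cross-check segment_reduce against gspmm("copy_lhs", reducer) on an
# equivalent bipartite graph built from the segment lengths.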
@pytest.mark.parametrize("reducer", ["sum", "max", "min", "mean"])
def test_segment_reduce(reducer):
    ctx = F.ctx()
    value = F.tensor(np.random.rand(10, 5))
    v1 = F.attach_grad(F.clone(value))
    v2 = F.attach_grad(F.clone(value))
    seglen = F.tensor([2, 3, 0, 4, 1, 0, 0])
    u = F.copy_to(F.arange(0, F.shape(value)[0], F.int32), ctx)
    v = F.repeat(
        F.copy_to(F.arange(0, len(seglen), F.int32), ctx), seglen, dim=0
    )

    num_nodes = {"_U": len(u), "_V": len(seglen)}
    g = dgl.convert.heterograph(
        {("_U", "_E", "_V"): (u, v)}, num_nodes_dict=num_nodes
    )
    with F.record_grad():
        rst1 = gspmm(g, "copy_lhs", reducer, v1, None)
        if reducer in ["max", "min"]:
            rst1 = F.replace_inf_with_zero(rst1)
        F.backward(F.reduce_sum(rst1))
        grad1 = F.grad(v1)

    with F.record_grad():
        rst2 = segment_reduce(seglen, v2, reducer=reducer)
        F.backward(F.reduce_sum(rst2))
        assert F.allclose(rst1, rst2)
        print("forward passed")

        grad2 = F.grad(v2)
        assert F.allclose(grad1, grad2)
        print("backward passed")


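# Cross-check dgl.ops.segment_mm against a per-segment dense matmul
# (a[off:off + l] @ b[i]) in forward and backward, with per-dtype tolerances.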
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@parametrize_idtype
@pytest.mark.parametrize("feat_size", [1, 8, 16, 64, 256])
@pytest.mark.parametrize(
    "dtype,tol",
    [(torch.float16, 1e-2), (torch.float32, 3e-3), (torch.float64, 1e-4)],
)
def test_segment_mm(idtype, feat_size, dtype, tol):
    if F._default_context_str == "cpu" and dtype == torch.float16:
        pytest.skip(
            "fp16 support for CPU linalg functions has been removed in PyTorch."
        )
    dev = F.ctx()
    # input
    a = torch.tensor(np.random.rand(100, feat_size)).to(dev).to(dtype)
    a.requires_grad_()
    b = (
        torch.tensor(np.random.rand(10, feat_size, feat_size + 1))
        .to(dev)
        .to(dtype)
    )
    b.requires_grad_()
    seglen_a = torch.tensor([10, 15, 8, 0, 1, 9, 18, 24, 15, 0])
    dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev).to(dtype)
    # compute
    c = dgl.ops.segment_mm(a, b, seglen_a)
    c.backward(dc)
    da = a.grad.clone()
    db = b.grad.clone()
    # ground truth
    c_t = []
    off = 0
    for i, l in enumerate(seglen_a):
        c_t.append(a[off : off + l] @ b[i])
        off += l
    c_t = torch.cat(c_t).to(dtype)
    a.grad.zero_()
    b.grad.zero_()
    c_t.backward(dc)
    da_t = a.grad
    db_t = b.grad

    assert torch.allclose(c, c_t, atol=tol, rtol=tol)
    assert torch.allclose(da, da_t, atol=tol, rtol=tol)
    assert torch.allclose(db, db_t, atol=tol, rtol=tol)


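# Cross-check dgl.ops.gather_mm with idx_b against torch.bmm over the
# gathered weight matrices b[idx].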
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@parametrize_idtype
@pytest.mark.parametrize("feat_size", [1, 8, 16, 64, 256])
def test_gather_mm_idx_b(idtype, feat_size):
    import torch

    dev = F.ctx()
    # input
    a = torch.tensor(np.random.rand(100, feat_size)).to(dev)
    a.requires_grad_()
    b = torch.tensor(np.random.rand(10, feat_size, feat_size + 1)).to(dev)
    b.requires_grad_()
    idx = torch.tensor(np.random.randint(0, 10, 100)).to(dev).long()
    dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev)
    # compute
    c = dgl.ops.gather_mm(a, b, idx_b=idx)
    c.backward(dc)
    da = a.grad.clone()
    db = b.grad.clone()
    # ground truth
    c_t = torch.bmm(a.unsqueeze(1), b[idx]).squeeze(1)
    a.grad.zero_()
    b.grad.zero_()
    c_t.backward(dc)
    da_t = a.grad
    db_t = b.grad

    assert torch.allclose(c, c_t, atol=1e-4, rtol=1e-4)
    assert torch.allclose(da, da_t, atol=1e-4, rtol=1e-4)
    assert torch.allclose(db, db_t, atol=1e-4, rtol=1e-4)


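# idx_a counterpart of the test above; kept disabled (leading underscore) per
# the TODO below until the CUDA kernel bug is fixed.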
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@parametrize_idtype
@pytest.mark.parametrize("feat_size", [1, 8, 16, 64, 256])
def _test_gather_mm_idx_a(idtype, feat_size):
    # TODO(minjie): currently disabled due to bugs in the CUDA kernel. Need to fix it later.
    import torch

    dev = F.ctx()
    # input
    a = torch.tensor(np.random.rand(10, feat_size)).to(dev)
    a.requires_grad_()
    b = torch.tensor(np.random.rand(100, feat_size, feat_size + 1)).to(dev)
    b.requires_grad_()
    idx = torch.tensor(np.random.randint(0, 10, 100)).to(dev)
    dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev)
    # compute
    c = dgl.ops.gather_mm(a, b, idx_a=idx)
    c.backward(dc)
    da = a.grad.clone()
    db = b.grad.clone()
    # ground truth
    c_t = torch.bmm(a[idx].unsqueeze(1), b).squeeze(1)
    a.grad.zero_()
    b.grad.zero_()
    c_t.backward(dc)
    da_t = a.grad
    db_t = b.grad

    assert torch.allclose(c, c_t, atol=1e-4, rtol=1e-4)
    assert torch.allclose(da, da_t, atol=1e-4, rtol=1e-4)
    assert torch.allclose(db, db_t, atol=1e-4, rtol=1e-4)


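# Toggle the libxsmm CPU kernels on and off via dgl.use_libxsmm and make sure
# u_mul_e_sum still runs in both modes.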
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@unittest.skipIf(
    F._default_context_str == "gpu", reason="libxsmm is only used on CPU."
)
def test_use_libxsmm_switch():
    import torch

    g = dgl.graph(([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]))
    x = torch.ones(3, 2, requires_grad=True)
    y = torch.arange(1, 13).float().view(6, 2).requires_grad_()

    assert dgl.is_libxsmm_enabled()
    dgl.ops.u_mul_e_sum(g, x, y)
    dgl.use_libxsmm(False)
    assert not dgl.is_libxsmm_enabled()
    dgl.ops.u_mul_e_sum(g, x, y)
    dgl.use_libxsmm(True)
    assert dgl.is_libxsmm_enabled()
    dgl.ops.u_mul_e_sum(g, x, y)