import sys

import backend as F
import pytest
import torch

from dgl.sparse import bsddmm, sddmm

from .utils import clone_detach_and_grad, rand_coo, rand_csc, rand_csr


@pytest.mark.parametrize("create_func", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("shape", [(5, 5), (5, 4)])
@pytest.mark.parametrize("nnz", [2, 10])
@pytest.mark.parametrize("hidden", [1, 5])
def test_sddmm(create_func, shape, nnz, hidden):
    """Check sparse sddmm(A, B, C) against a dense reference.

    Computes (B @ C) restricted to A's nonzero positions and scaled by
    A's values, once through the sparse kernel and once with plain dense
    ops, then verifies that the forward values and the gradients of A's
    values, B, and C all agree.
    """
    dev = F.ctx()
    A = create_func(shape, nnz, dev)
    if hidden > 1:
        B = torch.rand(shape[0], hidden, requires_grad=True, device=dev)
        C = torch.rand(hidden, shape[1], requires_grad=True, device=dev)
    else:
        # hidden == 1 exercises the 1-D operand code path.
        B = torch.rand(shape[0], requires_grad=True, device=dev)
        C = torch.rand(shape[1], requires_grad=True, device=dev)

    # Independent leaf copies so the sparse and dense paths accumulate
    # gradients separately.
    A_val_clone = clone_detach_and_grad(A.val)
    dense_B = clone_detach_and_grad(B)
    dense_C = clone_detach_and_grad(C)

    sparse_result = sddmm(A, B, C)

    grad = torch.rand_like(sparse_result.val)
    sparse_result.val.backward(grad)

    # Dense reference: full matmul (lift 1-D operands to outer product),
    # then gather the entries at A's nonzero coordinates.
    if hidden == 1:
        dense_result = dense_B.view(-1, 1) @ dense_C.view(1, -1)
    else:
        dense_result = dense_B @ dense_C

    row, col = A.coo()
    dense_val = dense_result[row, col] * A_val_clone
    dense_val.backward(grad)

    assert torch.allclose(dense_val, sparse_result.val, atol=1e-05)
    assert torch.allclose(dense_C.grad, C.grad, atol=1e-05)
    assert torch.allclose(dense_B.grad, B.grad, atol=1e-05)
    assert torch.allclose(A_val_clone.grad, A.val.grad, atol=1e-05)


@pytest.mark.parametrize("create_func", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("shape", [(5, 5), (5, 4)])
@pytest.mark.parametrize("nnz", [2, 10])
@pytest.mark.parametrize("nz_dim", [2, 10])
def test_bsddmm(create_func, shape, nnz, nz_dim):
    """Batched sddmm must match a dense batched-matmul reference in both
    forward values and the gradients of A's values, B, and C."""
    dev = F.ctx()
    hidden = 2
    A = create_func(shape, nnz, dev, nz_dim)
    B = torch.rand(shape[0], hidden, nz_dim, requires_grad=True, device=dev)
    C = torch.rand(hidden, shape[1], nz_dim, requires_grad=True, device=dev)

    # Leaf copies so the dense reference accumulates its own gradients.
    val_copy = clone_detach_and_grad(A.val)
    B_copy = clone_detach_and_grad(B)
    C_copy = clone_detach_and_grad(C)

    out = bsddmm(A, B, C)

    out_grad = torch.rand_like(out.val)
    out.val.backward(out_grad)

    # Dense reference: move the trailing nz dimension to the front,
    # batch-matmul, then move it back to the trailing position.
    batched = B_copy.permute(2, 0, 1) @ C_copy.permute(2, 0, 1)
    ref = batched.permute(1, 2, 0)

    rows, cols = A.coo()
    ref_val = ref[rows, cols] * val_copy
    ref_val.backward(out_grad)

    assert torch.allclose(ref_val, out.val, atol=1e-05)
    assert torch.allclose(C_copy.grad, C.grad, atol=1e-05)
    assert torch.allclose(B_copy.grad, B.grad, atol=1e-05)
    assert torch.allclose(val_copy.grad, A.val.grad, atol=1e-05)