# test_matmul.py
import sys

import backend as F
import pytest
import torch

from dgl.sparse import bspmm, diag, from_coo, val_like
from dgl.sparse.matmul import matmul

from .utils import (
    clone_detach_and_grad,
    dense_mask,
    rand_coo,
    rand_csc,
    rand_csr,
    sparse_matrix_to_dense,
    sparse_matrix_to_torch_sparse,
)


@pytest.mark.parametrize("create_func", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("shape", [(2, 7), (5, 2)])
@pytest.mark.parametrize("nnz", [1, 10])
@pytest.mark.parametrize("out_dim", [None, 10])
def test_spmm(create_func, shape, nnz, out_dim):
    """Compare sparse-dense ``matmul`` with a dense ``torch.matmul`` reference.

    The forward result, the gradient of the dense operand ``X`` and the
    gradient of the sparse values ``A.val`` are all checked with
    ``atol=1e-05``.

    Parameters
    ----------
    create_func : callable
        Factory called as ``create_func(shape, nnz, dev)`` to build ``A``.
    shape : tuple of int
        Shape passed to ``create_func``; ``shape[1]`` sizes ``X``.
    nnz : int
        Non-zero count passed to ``create_func``.
    out_dim : int or None
        Number of columns of ``X``; ``None`` makes ``X`` a 1-D vector.
    """
    dev = F.ctx()
    A = create_func(shape, nnz, dev)
    if out_dim is not None:
        X = torch.randn(shape[1], out_dim, requires_grad=True, device=dev)
    else:
        X = torch.randn(shape[1], requires_grad=True, device=dev)

    # Sparse path: forward, then backward with a random upstream gradient.
    sparse_result = matmul(A, X)
    grad = torch.randn_like(sparse_result)
    sparse_result.backward(grad)

    # Dense reference path, driven by the same upstream gradient.
    # NOTE(review): ``adj.grad`` is read below, so sparse_matrix_to_dense is
    # presumably returning a leaf tensor that requires grad - confirm in utils.
    adj = sparse_matrix_to_dense(A)
    XX = clone_detach_and_grad(X)
    dense_result = torch.matmul(adj, XX)
    if out_dim is None:
        dense_result = dense_result.view(-1)
    dense_result.backward(grad)

    assert torch.allclose(sparse_result, dense_result, atol=1e-05)
    assert torch.allclose(X.grad, XX.grad, atol=1e-05)
    # NOTE(review): dense_mask presumably restricts the dense gradient to the
    # entries stored in ``A`` so it is comparable with ``A.val.grad``.
    assert torch.allclose(
        dense_mask(adj.grad, A),
        sparse_matrix_to_dense(val_like(A, A.val.grad)),
        atol=1e-05,
    )


@pytest.mark.parametrize("create_func", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("shape", [(2, 7), (5, 2)])
@pytest.mark.parametrize("nnz", [1, 10])
def test_bspmm(create_func, shape, nnz):
    """Compare ``matmul`` on a sparse ``A`` and a 3-D dense ``X`` with torch.

    The reference moves the trailing dimension of both operands to the front
    and uses the dense ``@`` operator; forward values and the gradients of
    ``X`` and ``A.val`` are compared with ``atol=1e-05``.

    Parameters
    ----------
    create_func : callable
        Factory called as ``create_func(shape, nnz, dev, 2)`` to build ``A``.
    shape : tuple of int
        Shape passed to ``create_func``; ``shape[1]`` sizes ``X``.
    nnz : int
        Non-zero count passed to ``create_func``.
    """
    dev = F.ctx()
    # NOTE(review): the extra ``2`` presumably gives every non-zero a length-2
    # value vector, matching the trailing dimension of ``X`` - confirm in utils.
    A = create_func(shape, nnz, dev, 2)
    X = torch.randn(shape[1], 10, 2, requires_grad=True, device=dev)

    # Sparse path: forward, then backward with a random upstream gradient.
    sparse_result = matmul(A, X)
    grad = torch.randn_like(sparse_result)
    sparse_result.backward(grad)

    # Dense reference: permute(2, 0, 1) puts the last dimension first so that
    # ``@`` multiplies the remaining two dimensions slice by slice.
    XX = clone_detach_and_grad(X)
    torch_A = A.to_dense().clone().detach().requires_grad_()
    torch_result = torch_A.permute(2, 0, 1) @ XX.permute(2, 0, 1)

    torch_result.backward(grad.permute(2, 0, 1))
    assert torch.allclose(
        sparse_result.permute(2, 0, 1), torch_result, atol=1e-05
    )
    assert torch.allclose(X.grad, XX.grad, atol=1e-05)
    assert torch.allclose(
        dense_mask(torch_A.grad, A),
        sparse_matrix_to_dense(val_like(A, A.val.grad)),
        atol=1e-05,
    )


@pytest.mark.parametrize("create_func1", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("create_func2", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("shape_n_m", [(5, 5), (5, 6)])
@pytest.mark.parametrize("shape_k", [3, 4])
@pytest.mark.parametrize("nnz1", [1, 10])
@pytest.mark.parametrize("nnz2", [1, 10])
def test_spspmm(create_func1, create_func2, shape_n_m, shape_k, nnz1, nnz2):
    """Compare sparse-sparse ``matmul`` with ``torch.sparse.mm``.

    ``A1`` has shape ``shape_n_m`` and ``A2`` has shape
    ``(shape_n_m[1], shape_k)``. The product and the gradients of both
    operands' values are compared (densified) with ``atol=1e-05``.

    Parameters
    ----------
    create_func1, create_func2 : callable
        Factories called as ``create_func(shape, nnz, dev)`` for ``A1``/``A2``.
    shape_n_m : tuple of int
        Shape of the left operand.
    shape_k : int
        Number of columns of the right operand.
    nnz1, nnz2 : int
        Non-zero counts passed to the respective factory.
    """
    dev = F.ctx()
    shape1 = shape_n_m
    shape2 = (shape_n_m[1], shape_k)
    A1 = create_func1(shape1, nnz1, dev)
    A2 = create_func2(shape2, nnz2, dev)

    # Sparse path: backward through the values of the product.
    A3 = matmul(A1, A2)
    grad = torch.randn_like(A3.val)
    A3.val.backward(grad)

    # Torch reference path.
    torch_A1 = sparse_matrix_to_torch_sparse(A1)
    torch_A2 = sparse_matrix_to_torch_sparse(A2)
    torch_A3 = torch.sparse.mm(torch_A1, torch_A2)
    # NOTE(review): the two-argument form presumably builds a torch sparse
    # tensor with A3's sparsity pattern and ``grad`` as values - confirm.
    torch_A3_grad = sparse_matrix_to_torch_sparse(A3, grad)
    torch_A3.backward(torch_A3_grad)

    with torch.no_grad():
        assert torch.allclose(A3.to_dense(), torch_A3.to_dense(), atol=1e-05)
        assert torch.allclose(
            val_like(A1, A1.val.grad).to_dense(),
            torch_A1.grad.to_dense(),
            atol=1e-05,
        )
        assert torch.allclose(
            val_like(A2, A2.val.grad).to_dense(),
            torch_A2.grad.to_dense(),
            atol=1e-05,
        )


def test_spspmm_duplicate():
    """Check that ``matmul`` rejects an operand with duplicate coordinates.

    ``A1`` stores the coordinate ``(0, 1)`` twice while ``A2`` has no
    repeated coordinate, so the product must fail whichever side ``A1`` is on.
    """
    dev = F.ctx()
    shape = (4, 4)

    # (row, col) pairs: (1,1), (0,1), (0,1), (0,2), (1,2) - (0,1) is repeated.
    row = torch.tensor([1, 0, 0, 0, 1], device=dev)
    col = torch.tensor([1, 1, 1, 2, 2], device=dev)
    val = torch.randn(len(row), device=dev)
    A1 = from_coo(row, col, val, shape)

    # (row, col) pairs: (1,1), (0,1), (0,2), (1,2) - all distinct.
    row = torch.tensor([1, 0, 0, 1], device=dev)
    col = torch.tensor([1, 1, 2, 2], device=dev)
    val = torch.randn(len(row), device=dev)
    A2 = from_coo(row, col, val, shape)

    # The exact exception type raised by the library is not visible from this
    # file, so only ``Exception`` is required. Unlike a bare ``except:``, this
    # does not swallow KeyboardInterrupt or SystemExit.
    for lhs, rhs in ((A1, A2), (A2, A1)):
        with pytest.raises(Exception):
            matmul(lhs, rhs)


@pytest.mark.parametrize("create_func", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("sparse_shape", [(5, 5), (5, 6)])
@pytest.mark.parametrize("nnz", [1, 10])
def test_sparse_diag_mm(create_func, sparse_shape, nnz):
    """Compare ``matmul(A, D)`` (sparse times diagonal) with ``torch.sparse.mm``.

    ``D`` is a square diagonal matrix whose side is ``sparse_shape[1]`` so it
    can be multiplied on the right of ``A``. The product and the gradients of
    ``A.val`` and ``D.val`` are compared (densified) with ``atol=1e-05``.

    Parameters
    ----------
    create_func : callable
        Factory called as ``create_func(sparse_shape, nnz, dev)`` to build ``A``.
    sparse_shape : tuple of int
        Shape of the sparse operand.
    nnz : int
        Non-zero count passed to ``create_func``.
    """
    dev = F.ctx()
    diag_shape = sparse_shape[1], sparse_shape[1]
    A = create_func(sparse_shape, nnz, dev)
    diag_val = torch.randn(sparse_shape[1], device=dev, requires_grad=True)
    D = diag(diag_val, diag_shape)

    # Path under test: backward through the values of the product.
    B = matmul(A, D)
    grad = torch.randn_like(B.val)
    B.val.backward(grad)

    # Torch reference path; the diagonal matrix is converted to sparse first.
    torch_A = sparse_matrix_to_torch_sparse(A)
    torch_D = sparse_matrix_to_torch_sparse(D.to_sparse())
    torch_B = torch.sparse.mm(torch_A, torch_D)
    torch_B_grad = sparse_matrix_to_torch_sparse(B, grad)
    torch_B.backward(torch_B_grad)

    with torch.no_grad():
        assert torch.allclose(B.to_dense(), torch_B.to_dense(), atol=1e-05)
        assert torch.allclose(
            val_like(A, A.val.grad).to_dense(),
            torch_A.grad.to_dense(),
            atol=1e-05,
        )
        assert torch.allclose(
            diag(D.val.grad, D.shape).to_dense(),
            torch_D.grad.to_dense(),
            atol=1e-05,
        )


@pytest.mark.parametrize("create_func", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("sparse_shape", [(5, 5), (5, 6)])
@pytest.mark.parametrize("nnz", [1, 10])
def test_diag_sparse_mm(create_func, sparse_shape, nnz):
    """Compare ``matmul(D, A)`` (diagonal times sparse) with ``torch.sparse.mm``.

    ``D`` is a square diagonal matrix whose side is ``sparse_shape[0]`` so it
    can be multiplied on the left of ``A``. The product and the gradients of
    ``A.val`` and ``D.val`` are compared (densified) with ``atol=1e-05``.

    Parameters
    ----------
    create_func : callable
        Factory called as ``create_func(sparse_shape, nnz, dev)`` to build ``A``.
    sparse_shape : tuple of int
        Shape of the sparse operand.
    nnz : int
        Non-zero count passed to ``create_func``.
    """
    dev = F.ctx()
    diag_shape = sparse_shape[0], sparse_shape[0]
    A = create_func(sparse_shape, nnz, dev)
    diag_val = torch.randn(sparse_shape[0], device=dev, requires_grad=True)
    D = diag(diag_val, diag_shape)

    # Path under test: backward through the values of the product.
    B = matmul(D, A)
    grad = torch.randn_like(B.val)
    B.val.backward(grad)

    # Torch reference path; the diagonal matrix is converted to sparse first.
    torch_A = sparse_matrix_to_torch_sparse(A)
    torch_D = sparse_matrix_to_torch_sparse(D.to_sparse())
    torch_B = torch.sparse.mm(torch_D, torch_A)
    torch_B_grad = sparse_matrix_to_torch_sparse(B, grad)
    torch_B.backward(torch_B_grad)

    with torch.no_grad():
        assert torch.allclose(B.to_dense(), torch_B.to_dense(), atol=1e-05)
        assert torch.allclose(
            val_like(A, A.val.grad).to_dense(),
            torch_A.grad.to_dense(),
            atol=1e-05,
        )
        assert torch.allclose(
            diag(D.val.grad, D.shape).to_dense(),
            torch_D.grad.to_dense(),
            atol=1e-05,
        )