import sys

import backend as F
import pytest
import torch

from dgl.sparse import bspmm, diag, from_coo, val_like
from dgl.sparse.matmul import matmul

from .utils import (
    clone_detach_and_grad,
    dense_mask,
    rand_coo,
    rand_csc,
    rand_csr,
    rand_stride,
    sparse_matrix_to_dense,
    sparse_matrix_to_torch_sparse,
)


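# Sparse-dense matmul (SpMM): compare matmul(A, X) against a dense
# reference computed with torch.matmul. The forward result, the gradient
# of X, and the gradient of A's nonzero values (masked to A's sparsity
# pattern via dense_mask) must all agree. rand_stride randomizes the
# strides of X so non-contiguous inputs are exercised as well.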
@pytest.mark.parametrize("create_func", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("shape", [(2, 7), (5, 2)])
@pytest.mark.parametrize("nnz", [1, 10])
@pytest.mark.parametrize("out_dim", [None, 10])
def test_spmm(create_func, shape, nnz, out_dim):
    dev = F.ctx()
    A = create_func(shape, nnz, dev)
    if out_dim is not None:
        X = torch.randn(shape[1], out_dim, requires_grad=True, device=dev)
    else:
        X = torch.randn(shape[1], requires_grad=True, device=dev)

34
    X = rand_stride(X)
35
    sparse_result = matmul(A, X)
36
37
38
    grad = torch.randn_like(sparse_result)
    sparse_result.backward(grad)

39
    adj = sparse_matrix_to_dense(A)
40
    XX = clone_detach_and_grad(X)
41
    dense_result = torch.matmul(adj, XX)
42
    if out_dim is None:
43
44
45
        dense_result = dense_result.view(-1)
    dense_result.backward(grad)
    assert torch.allclose(sparse_result, dense_result, atol=1e-05)
46
    assert torch.allclose(X.grad, XX.grad, atol=1e-05)
47
    assert torch.allclose(
48
        dense_mask(adj.grad, A),
49
        sparse_matrix_to_dense(val_like(A, A.val.grad)),
50
        atol=1e-05,
51
    )
52
53


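# Batched SpMM: A carries a trailing batch dimension of size 2, so the
# dense reference permutes the batch dimension to the front and uses a
# plain batched dense matmul (@) for comparison.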
@pytest.mark.parametrize("create_func", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("shape", [(2, 7), (5, 2)])
@pytest.mark.parametrize("nnz", [1, 10])
def test_bspmm(create_func, shape, nnz):
    dev = F.ctx()
    A = create_func(shape, nnz, dev, 2)
    X = torch.randn(shape[1], 10, 2, requires_grad=True, device=dev)
61
    X = rand_stride(X)
62

63
    sparse_result = matmul(A, X)
64
65
66
67
    grad = torch.randn_like(sparse_result)
    sparse_result.backward(grad)

    XX = clone_detach_and_grad(X)
68
    torch_A = A.to_dense().clone().detach().requires_grad_()
69
70
71
72
73
74
75
76
77
78
79
80
81
82
    torch_result = torch_A.permute(2, 0, 1) @ XX.permute(2, 0, 1)

    torch_result.backward(grad.permute(2, 0, 1))
    assert torch.allclose(
        sparse_result.permute(2, 0, 1), torch_result, atol=1e-05
    )
    assert torch.allclose(X.grad, XX.grad, atol=1e-05)
    assert torch.allclose(
        dense_mask(torch_A.grad, A),
        sparse_matrix_to_dense(val_like(A, A.val.grad)),
        atol=1e-05,
    )


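# Sparse-sparse matmul (SpSpMM): compare the forward values and the
# gradients of both operands' nonzero values against torch.sparse.mm.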
@pytest.mark.parametrize("create_func1", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("create_func2", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("shape_n_m", [(5, 5), (5, 6)])
@pytest.mark.parametrize("shape_k", [3, 4])
@pytest.mark.parametrize("nnz1", [1, 10])
@pytest.mark.parametrize("nnz2", [1, 10])
89
def test_spspmm(create_func1, create_func2, shape_n_m, shape_k, nnz1, nnz2):
90
91
92
93
94
    dev = F.ctx()
    shape1 = shape_n_m
    shape2 = (shape_n_m[1], shape_k)
    A1 = create_func1(shape1, nnz1, dev)
    A2 = create_func2(shape2, nnz2, dev)
95
    A3 = matmul(A1, A2)
96
97
98
99
100
101
102
103
104
105
    grad = torch.randn_like(A3.val)
    A3.val.backward(grad)

    torch_A1 = sparse_matrix_to_torch_sparse(A1)
    torch_A2 = sparse_matrix_to_torch_sparse(A2)
    torch_A3 = torch.sparse.mm(torch_A1, torch_A2)
    torch_A3_grad = sparse_matrix_to_torch_sparse(A3, grad)
    torch_A3.backward(torch_A3_grad)

    with torch.no_grad():
106
        assert torch.allclose(A3.to_dense(), torch_A3.to_dense(), atol=1e-05)
107
        assert torch.allclose(
108
            val_like(A1, A1.val.grad).to_dense(),
109
110
111
112
            torch_A1.grad.to_dense(),
            atol=1e-05,
        )
        assert torch.allclose(
113
            val_like(A2, A2.val.grad).to_dense(),
114
115
116
            torch_A2.grad.to_dense(),
            atol=1e-05,
        )
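# matmul should reject sparse matrices that contain duplicate (row, col)
# entries, in either operand position.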


def test_spspmm_duplicate():
    dev = F.ctx()

    # A1 contains a duplicate coordinate: (0, 1) appears twice.
    row = torch.tensor([1, 0, 0, 0, 1]).to(dev)
    col = torch.tensor([1, 1, 1, 2, 2]).to(dev)
    val = torch.randn(len(row)).to(dev)
    shape = (4, 4)
    A1 = from_coo(row, col, val, shape)

    row = torch.tensor([1, 0, 0, 1]).to(dev)
    col = torch.tensor([1, 1, 2, 2]).to(dev)
    val = torch.randn(len(row)).to(dev)
    shape = (4, 4)
    A2 = from_coo(row, col, val, shape)

    # Both operand orders must raise; the exact exception type is not
    # part of the contract, so any Exception is accepted.
    with pytest.raises(Exception):
        matmul(A1, A2)
    with pytest.raises(Exception):
        matmul(A2, A1)
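# Sparse @ diagonal: D is a square diagonal matrix matching A's number of
# columns; forward values and both gradients are checked against
# torch.sparse.mm.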


@pytest.mark.parametrize("create_func", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("sparse_shape", [(5, 5), (5, 6)])
@pytest.mark.parametrize("nnz", [1, 10])
def test_sparse_diag_mm(create_func, sparse_shape, nnz):
    dev = F.ctx()
    diag_shape = sparse_shape[1], sparse_shape[1]
    A = create_func(sparse_shape, nnz, dev)
    diag_val = torch.randn(sparse_shape[1], device=dev, requires_grad=True)
    D = diag(diag_val, diag_shape)
158
    B = matmul(A, D)
159
160
161
162
    grad = torch.randn_like(B.val)
    B.val.backward(grad)

    torch_A = sparse_matrix_to_torch_sparse(A)
163
    torch_D = sparse_matrix_to_torch_sparse(D)
164
165
166
167
168
    torch_B = torch.sparse.mm(torch_A, torch_D)
    torch_B_grad = sparse_matrix_to_torch_sparse(B, grad)
    torch_B.backward(torch_B_grad)

    with torch.no_grad():
169
        assert torch.allclose(B.to_dense(), torch_B.to_dense(), atol=1e-05)
170
        assert torch.allclose(
171
            val_like(A, A.val.grad).to_dense(),
172
173
174
175
            torch_A.grad.to_dense(),
            atol=1e-05,
        )
        assert torch.allclose(
176
            diag(D.val.grad, D.shape).to_dense(),
177
178
179
180
181
182
183
184
185
186
187
188
189
190
            torch_D.grad.to_dense(),
            atol=1e-05,
        )


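# Diagonal @ sparse: the mirror case of test_sparse_diag_mm, with D
# matching A's number of rows.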
@pytest.mark.parametrize("create_func", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("sparse_shape", [(5, 5), (5, 6)])
@pytest.mark.parametrize("nnz", [1, 10])
def test_diag_sparse_mm(create_func, sparse_shape, nnz):
    dev = F.ctx()
    diag_shape = sparse_shape[0], sparse_shape[0]
    A = create_func(sparse_shape, nnz, dev)
    diag_val = torch.randn(sparse_shape[0], device=dev, requires_grad=True)
    D = diag(diag_val, diag_shape)
191
    B = matmul(D, A)
192
193
194
195
    grad = torch.randn_like(B.val)
    B.val.backward(grad)

    torch_A = sparse_matrix_to_torch_sparse(A)
196
    torch_D = sparse_matrix_to_torch_sparse(D)
197
198
199
200
201
    torch_B = torch.sparse.mm(torch_D, torch_A)
    torch_B_grad = sparse_matrix_to_torch_sparse(B, grad)
    torch_B.backward(torch_B_grad)

    with torch.no_grad():
202
        assert torch.allclose(B.to_dense(), torch_B.to_dense(), atol=1e-05)
203
        assert torch.allclose(
204
            val_like(A, A.val.grad).to_dense(),
205
206
207
208
            torch_A.grad.to_dense(),
            atol=1e-05,
        )
        assert torch.allclose(
209
            diag(D.val.grad, D.shape).to_dense(),
210
211
212
            torch_D.grad.to_dense(),
            atol=1e-05,
        )