# test_matmul.py
import sys

import backend as F
import pytest
import torch

from dgl.mock_sparse2 import create_from_coo, val_like

from .utils import (
    clone_detach_and_grad,
    rand_coo,
    rand_csc,
    rand_csr,
    sparse_matrix_to_dense,
    sparse_matrix_to_torch_sparse,
)

# TODO(#4818): These tests only run on Linux for now. The guard below skips
# the whole module on every other platform (Windows, macOS, ...), so the skip
# reason names the actual condition rather than just "win".
if not sys.platform.startswith("linux"):
    pytest.skip(
        "skipping tests on non-linux platforms", allow_module_level=True
    )


@pytest.mark.parametrize("create_func", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("shape", [(2, 7), (5, 2)])
@pytest.mark.parametrize("nnz", [1, 10])
@pytest.mark.parametrize("out_dim", [None, 10])
def test_spmm(create_func, shape, nnz, out_dim):
    """Check ``A @ X`` (sparse ``A``, dense ``X``) against ``torch.sparse.mm``.

    The forward result, the gradient w.r.t. ``X`` and the gradient w.r.t. the
    values of ``A`` are all compared with the PyTorch reference.

    Parameters
    ----------
    create_func : callable
        Factory invoked as ``create_func(shape, nnz, dev)`` to build ``A``.
    shape : tuple of int
        Shape handed to ``create_func``; ``shape[1]`` is the leading
        dimension of ``X``.
    nnz : int
        Forwarded to ``create_func`` (presumably the number of non-zero
        entries -- confirm in ``utils``).
    out_dim : int or None
        ``None`` makes ``X`` a 1-D tensor of length ``shape[1]``; otherwise
        ``X`` has shape ``(shape[1], out_dim)``.
    """
    dev = F.ctx()
    A = create_func(shape, nnz, dev)
    if out_dim is not None:
        X = torch.randn(shape[1], out_dim, requires_grad=True, device=dev)
    else:
        X = torch.randn(shape[1], requires_grad=True, device=dev)

    # Operator under test, back-propagated with a random upstream gradient.
    sparse_result = A @ X
    grad = torch.randn_like(sparse_result)
    sparse_result.backward(grad)

    # Reference computation. XX is a separate tensor with its own ``.grad``
    # so the two backward passes do not accumulate into the same place.
    adj = sparse_matrix_to_torch_sparse(A)
    XX = clone_detach_and_grad(X)
    # In the vector case the operand is fed as a single-column matrix and the
    # product is flattened back so it lines up with ``sparse_result``.
    torch_sparse_result = torch.sparse.mm(
        adj, XX.view(-1, 1) if out_dim is None else XX
    )
    if out_dim is None:
        torch_sparse_result = torch_sparse_result.view(-1)
    torch_sparse_result.backward(grad)

    assert torch.allclose(sparse_result, torch_sparse_result, atol=1e-05)
    assert torch.allclose(X.grad, XX.grad, atol=1e-05)
    # Gradient w.r.t. the sparse values, compared in dense form.
    assert torch.allclose(
        adj.grad.coalesce().to_dense(),
        sparse_matrix_to_dense(val_like(A, A.val.grad)),
        atol=1e-05,
    )


@pytest.mark.parametrize("create_func1", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("create_func2", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("shape_n_m", [(5, 5), (5, 6)])
@pytest.mark.parametrize("shape_k", [3, 4])
@pytest.mark.parametrize("nnz1", [1, 10])
@pytest.mark.parametrize("nnz2", [1, 10])
def test_spspmm(create_func1, create_func2, shape_n_m, shape_k, nnz1, nnz2):
    """Check sparse @ sparse (``A1 @ A2``) against ``torch.sparse.mm``.

    The product and the gradients w.r.t. the values of both operands are
    compared with the PyTorch reference in dense form.

    Parameters
    ----------
    create_func1, create_func2 : callable
        Factories invoked as ``create_func(shape, nnz, dev)`` to build ``A1``
        and ``A2`` respectively.
    shape_n_m : tuple of int
        Shape ``(n, m)`` of ``A1``.
    shape_k : int
        Second dimension of ``A2``, whose shape is ``(m, shape_k)``.
    nnz1, nnz2 : int
        Forwarded to the respective factory (presumably the number of
        non-zero entries -- confirm in ``utils``).
    """
    dev = F.ctx()
    shape1 = shape_n_m
    # The inner dimensions must agree: A1 is (n, m), A2 is (m, k).
    shape2 = (shape_n_m[1], shape_k)
    A1 = create_func1(shape1, nnz1, dev)
    A2 = create_func2(shape2, nnz2, dev)

    # Operator under test, back-propagated through the values of the result.
    A3 = A1 @ A2
    grad = torch.randn_like(A3.val)
    A3.val.backward(grad)

    # Reference computation with PyTorch sparse tensors.
    torch_A1 = sparse_matrix_to_torch_sparse(A1)
    torch_A2 = sparse_matrix_to_torch_sparse(A2)
    torch_A3 = torch.sparse.mm(torch_A1, torch_A2)
    # Wrap the same upstream gradient as a torch sparse tensor built from A3
    # (presumably reusing A3's sparsity pattern -- confirm in ``utils``).
    torch_A3_grad = sparse_matrix_to_torch_sparse(A3, grad)
    torch_A3.backward(torch_A3_grad)

    with torch.no_grad():
        assert torch.allclose(A3.dense(), torch_A3.to_dense(), atol=1e-05)
        assert torch.allclose(
            val_like(A1, A1.val.grad).dense(),
            torch_A1.grad.to_dense(),
            atol=1e-05,
        )
        assert torch.allclose(
            val_like(A2, A2.val.grad).dense(),
            torch_A2.grad.to_dense(),
            atol=1e-05,
        )


def test_spspmm_duplicate():
    """Check that sparse @ sparse raises when an operand has duplicate entries.

    ``A1`` stores the coordinate ``(0, 1)`` twice while ``A2`` has only unique
    coordinates, so the product is expected to fail with ``A1`` on either
    side of the operator.
    """
    dev = F.ctx()

    # (row, col) == (0, 1) appears at positions 1 and 2: a duplicate entry.
    row = torch.tensor([1, 0, 0, 0, 1]).to(dev)
    col = torch.tensor([1, 1, 1, 2, 2]).to(dev)
    val = torch.randn(len(row)).to(dev)
    shape = (4, 4)
    A1 = create_from_coo(row, col, val, shape)

    # All four coordinates are distinct.
    row = torch.tensor([1, 0, 0, 1]).to(dev)
    col = torch.tensor([1, 1, 2, 2]).to(dev)
    val = torch.randn(len(row)).to(dev)
    shape = (4, 4)
    A2 = create_from_coo(row, col, val, shape)

    # ``pytest.raises`` fails the test when nothing is raised and, unlike a
    # bare ``except:``, lets KeyboardInterrupt/SystemExit propagate.
    # NOTE(review): narrow ``Exception`` once the concrete error type raised
    # by the operator is confirmed.
    with pytest.raises(Exception):
        A1 @ A2

    with pytest.raises(Exception):
        A2 @ A1