# test_matmul.py: tests for torch_sparse.matmul.matmul.
from itertools import product

import pytest
import torch
import torch_scatter

from torch_sparse.matmul import matmul
from torch_sparse.tensor import SparseTensor

from .utils import devices, grad_dtypes, reductions


@pytest.mark.parametrize('dtype,device,reduce',
                         product(grad_dtypes, devices, reductions))
def test_spmm(dtype, device, reduce):
    """Check sparse-dense ``matmul`` against a ``torch_scatter`` reference.

    For every ``(dtype, device, reduce)`` combination, the forward output and
    the gradients w.r.t. both the sparse values and the dense operand are
    compared with a reference built from ``index_select`` + ``scatter``.
    """
    # Skip (rather than silently pass) the unsupported combination, and match
    # any CUDA device, not only ``cuda:0``.
    if torch.device(device).type == 'cuda' and dtype == torch.bfloat16:
        pytest.skip('bfloat16 is not yet implemented on CUDA.')

    src = torch.randn((10, 8), dtype=dtype, device=device)
    src[2:4, :] = 0  # Remove multiple rows.
    src[:, 2:4] = 0  # Remove multiple columns.
    src = SparseTensor.from_dense(src).requires_grad_()
    row, col, value = src.coo()

    # Batched dense operand; its dim -2 (size 8) matches the columns of ``src``.
    other = torch.randn((2, 8, 2), dtype=dtype, device=device,
                        requires_grad=True)

    # Reference: pick the rows of ``other`` addressed by ``col``, weight them
    # by the non-zero values and reduce them into output row ``row``.
    src_col = other.index_select(-2, col) * value.unsqueeze(-1)
    expected = torch_scatter.scatter(src_col, row, dim=-2, reduce=reduce)
    # NOTE(review): presumably guards against large fill values that 'min' /
    # 'max' may leave in rows without entries; they are zeroed so they compare
    # equal to ``matmul``'s output for such rows -- confirm against scatter.
    if reduce == 'min':
        expected[expected > 1000] = 0
    if reduce == 'max':
        expected[expected < -1000] = 0

    grad_out = torch.randn_like(expected)

    # Record the reference gradients, then clear them so that the ``matmul``
    # backward pass below accumulates into fresh ``.grad`` slots.
    expected.backward(grad_out)
    expected_grad_value = value.grad
    value.grad = None
    expected_grad_other = other.grad
    other.grad = None

    out = matmul(src, other, reduce)
    out.backward(grad_out)

    # Half-precision types need a much looser absolute tolerance.
    tol = {'atol': 1e-1} if dtype in (torch.float16, torch.bfloat16) else {}
    assert torch.allclose(expected, out, **tol)
    assert torch.allclose(expected_grad_value, value.grad, **tol)
    assert torch.allclose(expected_grad_other, other.grad, **tol)


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_spspmm(dtype, device):
    """Check sparse-sparse ``matmul`` on the 3x3 identity.

    The product is checked once with explicit values and once after the
    values have been dropped, in which case only the sparsity pattern is
    expected to survive.
    """
    # Skip (rather than silently pass) the unsupported combination, and match
    # any CUDA device, not only ``cuda:0``.
    if torch.device(device).type == 'cuda' and dtype == torch.bfloat16:
        pytest.skip('bfloat16 is not yet implemented on CUDA.')

    src = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=dtype,
                       device=device)

    # I @ I == I: one diagonal entry per row, each with value 1.
    src = SparseTensor.from_dense(src)
    out = matmul(src, src)
    assert out.sizes() == [3, 3]
    assert out.has_value()
    rowptr, col, value = out.csr()
    assert rowptr.tolist() == [0, 1, 2, 3]
    assert col.tolist() == [0, 1, 2]
    assert value.tolist() == [1, 1, 1]

    # Without values the result keeps the same pattern but carries no value
    # tensor, so ``csr()`` must hand back ``None`` for it.
    src.set_value_(None)
    out = matmul(src, src)
    assert out.sizes() == [3, 3]
    assert not out.has_value()
    rowptr, col, value = out.csr()
    assert rowptr.tolist() == [0, 1, 2, 3]
    assert col.tolist() == [0, 1, 2]
    assert value is None