test_matmul.py 1.66 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
from itertools import product

import pytest
import torch

from torch_sparse.matmul import matmul
from torch_sparse.tensor import SparseTensor
import torch_scatter

rusty1s's avatar
rusty1s committed
10
from .utils import devices, grad_dtypes
rusty1s's avatar
rusty1s committed
11

rusty1s's avatar
rusty1s committed
12
# Reductions exercised by ``test_spmm``.
#
# ``devices`` and ``grad_dtypes`` are intentionally NOT redefined here: the
# values imported from ``.utils`` are used as-is. Hard-coding ``'cuda'`` would
# make every CUDA-parametrized case fail on CPU-only machines.
reductions = ['sum', 'mean', 'min', 'max']


@pytest.mark.parametrize('dtype,device,reduce',
                         product(grad_dtypes, devices, reductions))
def test_spmm(dtype, device, reduce):
    """Check ``matmul(src, other, reduce)`` against a ``torch_scatter`` reference.

    A dense ``(10, 8)`` matrix with two all-zero rows and two all-zero columns
    is converted to a ``SparseTensor`` and multiplied with a dense
    ``(2, 8, 2)`` tensor. The forward result, the gradient w.r.t. the sparse
    values, and the gradient w.r.t. the dense operand are all compared with
    the same computation expressed as ``index_select`` + ``scatter_<reduce>``.
    """
    dense = torch.randn((10, 8), dtype=dtype, device=device)
    dense[2:4, :] = 0  # Remove multiple rows.
    dense[:, 2:4] = 0  # Remove multiple columns.
    # Row count of the sparse matrix. It is passed as ``dim_size`` so the
    # reference always has one output row per sparse row. Without it, the
    # scatter output would only reach ``row.max() + 1``, which is short when
    # trailing rows are empty.
    num_rows = dense.size(0)

    src = SparseTensor.from_dense(dense).requires_grad_()
    (row, col), value = src.coo()

    other = torch.randn((2, 8, 2), dtype=dtype, device=device,
                        requires_grad=True)

    # Reference: gather the rows of ``other`` addressed by ``col``, weight them
    # by the non-zero values, then reduce them into their output ``row``.
    src_col = other.index_select(-2, col) * value.unsqueeze(-1)
    func = 'add' if reduce == 'sum' else reduce
    expected = getattr(torch_scatter, f'scatter_{func}')(
        src_col, row, dim=-2, dim_size=num_rows)
    # Some scatter variants return a tuple; keep only the reduced values.
    expected = expected[0] if isinstance(expected, tuple) else expected
    # NOTE(review): presumably empty rows hold an extreme sentinel for
    # min/max. They are zeroed here to match what ``matmul`` produces.
    if reduce == 'min':
        expected[expected > 1000] = 0
    if reduce == 'max':
        expected[expected < -1000] = 0

    grad_out = torch.randn_like(expected)

    # Record reference gradients, then clear them before the ``matmul`` pass.
    expected.backward(grad_out)
    expected_grad_value = value.grad
    value.grad = None
    expected_grad_other = other.grad
    other.grad = None

    out = matmul(src, other, reduce)
    out = out[0] if isinstance(out, tuple) else out
    out.backward(grad_out)

    assert torch.allclose(expected, out)
    assert torch.allclose(expected_grad_value, value.grad)
    assert torch.allclose(expected_grad_other, other.grad)