test_matmul.py 1.51 KB
Newer Older
rusty1s's avatar
to csr  
rusty1s committed
1
2
3
from itertools import product

import pytest
rusty1s's avatar
rusty1s committed
4
import torch
rusty1s's avatar
docs  
rusty1s committed
5
from torch_sparse import sparse_coo_tensor, spspmm, to_value
rusty1s's avatar
rusty1s committed
6

rusty1s's avatar
to csr  
rusty1s committed
7
8
from .utils import dtypes, devices, tensor

rusty1s's avatar
rusty1s committed
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# Parametrization cases for ``test_spspmm``. Each case describes two sparse COO
# operands A (``sizeA``) and B (``sizeB``):
# - ``index*`` holds two rows (row indices, column indices);
# - ``value*`` holds the matching non-zero values.
# The second case repeats some (row, col) pairs, e.g. (2, 1) in A and (2, 0)
# in B, so the operands are not coalesced.
tests = [
    dict(
        name='Test coalesced input',
        indexA=[[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]],
        valueA=[1, 2, 3, 4, 5],
        sizeA=[3, 3],
        indexB=[[0, 2], [1, 0]],
        valueB=[2, 4],
        sizeB=[3, 2],
    ),
    dict(
        name='Test uncoalesced input',
        indexA=[[2, 2, 1, 0, 2, 0], [1, 1, 0, 2, 0, 1]],
        valueA=[3, 2, 3, 2, 4, 1],
        sizeA=[3, 3],
        indexB=[[2, 0, 2], [0, 1, 0]],
        valueB=[2, 2, 2],
        sizeB=[3, 2],
    ),
]


@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_spspmm(test, dtype, device):
    """Check ``spspmm(A, B)`` against dense ``torch.matmul``.

    The sparse operands are built from the ``test`` case. The following must
    all match the dense reference computed from dense copies of A and B:

    - the forward product;
    - the gradient w.r.t. the non-zero values of A;
    - the gradient w.r.t. the non-zero values of B.
    """
    indexA = torch.tensor(test['indexA'], device=device)
    valueA = tensor(test['valueA'], dtype, device, requires_grad=True)
    sizeA = torch.Size(test['sizeA'])
    A = sparse_coo_tensor(indexA, valueA, sizeA)
    # Detached dense copy of A, made a grad-requiring leaf so the dense
    # reference gets its own gradient to compare against.
    denseA = A.detach().to_dense().requires_grad_()

    indexB = torch.tensor(test['indexB'], device=device)
    valueB = tensor(test['valueB'], dtype, device, requires_grad=True)
    sizeB = torch.Size(test['sizeB'])
    B = sparse_coo_tensor(indexB, valueB, sizeB)
    # Detached dense copy of B, used the same way as denseA.
    denseB = B.detach().to_dense().requires_grad_()

    C = spspmm(A, B)
    denseC = torch.matmul(denseA, denseB)
    assert C.detach().to_dense().tolist() == denseC.tolist()

    to_value(C).sum().backward()
    denseC.sum().backward()
    # Gather the dense gradients at each operand's (row, col) indices. This
    # assumes duplicate entries of an uncoalesced input are summed, so each
    # duplicate receives the gradient of the dense position it contributes
    # to -- TODO confirm against sparse_coo_tensor's coalescing behaviour.
    assert valueA.grad.tolist() == denseA.grad[indexA[0], indexA[1]].tolist()
    # Previously missing: B's gradient was computed but never checked.
    assert valueB.grad.tolist() == denseB.grad[indexB[0], indexB[1]].tolist()