# test_matmul.py
# Blame residue summarized: last change "to csr" committed by rusty1s.
from itertools import product

import pytest
rusty1s's avatar
rusty1s committed
4
import torch
rusty1s's avatar
rusty1s committed
5
from torch.autograd import gradcheck
rusty1s's avatar
rusty1s committed
6
from torch_sparse import spspmm
rusty1s's avatar
rusty1s committed
7
from torch_sparse.matmul import SpSpMM
rusty1s's avatar
rusty1s committed
8

rusty1s's avatar
to csr  
rusty1s committed
9
10
from .utils import dtypes, devices, tensor

rusty1s's avatar
rusty1s committed
11
# Override the `dtypes` list imported from .utils so this module only runs in
# double precision (torch.float64 is the same dtype as torch.double).
# Presumably chosen so exact/gradient comparisons are not affected by
# single-precision rounding — TODO confirm.
dtypes = [torch.float64]
rusty1s's avatar
rusty1s committed
12
13


rusty1s's avatar
rusty1s committed
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_coalesced_spspmm(dtype, device):
    """Check ``spspmm`` against dense ``torch.matmul``, forward and backward.

    A 3x3 sparse matrix ``A`` is multiplied by a 3x2 sparse matrix ``B``; both
    are passed to ``spspmm`` as ``(index, value, size)`` triples. The test
    checks that:

    * the sparse product, densified, equals the dense product, and
    * the gradients of ``sum(A @ B)`` with respect to the non-zero values of
      ``A`` and ``B`` equal the dense gradients gathered at the same
      ``(row, col)`` positions.
    """
    indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]], device=device)
    valueA = tensor([1, 2, 3, 4, 5], dtype, device, requires_grad=True)
    sizeA = torch.Size([3, 3])
    A = (indexA, valueA, sizeA)
    # Build the dense reference from *detached* values so it is an independent
    # autograd leaf. Without the detach, A_dense would be a non-leaf tensor
    # (its .grad stays None), and backpropagating through the dense product
    # would also accumulate into valueA.grad, mixing both gradients.
    A_dense = torch.sparse_coo_tensor(indexA, valueA.detach(), sizeA)
    A_dense = A_dense.to_dense().requires_grad_()

    indexB = torch.tensor([[0, 2], [1, 0]], device=device)
    valueB = tensor([2, 4], dtype, device, requires_grad=True)
    sizeB = torch.Size([3, 2])
    B = (indexB, valueB, sizeB)
    B_dense = torch.sparse_coo_tensor(indexB, valueB.detach(), sizeB)
    B_dense = B_dense.to_dense().requires_grad_()

    index, value, size = spspmm(*A, *B)
    expected = torch.matmul(A_dense, B_dense)

    # Forward: the sparse result must densify to the dense product.
    out = torch.sparse_coo_tensor(index, value.detach(), size).to_dense()
    assert out.tolist() == expected.tolist()

    # Backward: d sum(A @ B) / d value of the sparse op must equal the dense
    # gradient read off at the same (row, col) positions.
    value.sum().backward()
    expected.sum().backward()

    assert torch.allclose(valueA.grad, A_dense.grad[indexA[0], indexA[1]])
    assert torch.allclose(valueB.grad, B_dense.grad[indexB[0], indexB[1]])