test_spspmm.py 1.24 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from itertools import product

import pytest
import torch
from torch_sparse import spspmm

from .utils import dtypes, devices, tensor


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_spspmm(dtype, device):
    """Check ``spspmm`` output and gradients against a dense reference.

    The test runs once per ``(dtype, device)`` pair. It first verifies the
    indices and values returned for a fixed pair of sparse operands, then
    compares the gradients ``spspmm`` propagates to the value tensors with
    the gradients of ``torch.matmul`` on the densified operands, gathered at
    the same non-zero positions.
    """
    # Trailing size arguments handed to ``spspmm``; they agree with the
    # 3x3 and 3x2 shapes used below to build the dense reference.
    dims = (3, 3, 2)
    shape_a = torch.Size(dims[:2])
    shape_b = torch.Size(dims[1:])

    # Operands in coordinate form: row 0 of each index holds row positions,
    # row 1 holds column positions.
    indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]], device=device)
    valueA = tensor([1, 2, 3, 4, 5], dtype, device)
    indexB = torch.tensor([[0, 2], [1, 0]], device=device)
    valueB = tensor([2, 4], dtype, device)

    # Forward result: only (0, 0), (1, 1) and (2, 1) are non-zero.
    out_index, out_value = spspmm(indexA, valueA, indexB, valueB, *dims)
    assert out_index.tolist() == [[0, 1, 2], [0, 1, 1]]
    assert out_value.tolist() == [8, 6, 8]

    # Dense reference gradients. The dense copies are created before the
    # value tensors start tracking gradients, so they are independent leaves.
    dense_a = torch.sparse_coo_tensor(
        indexA, valueA, shape_a, device=device).to_dense()
    dense_a.requires_grad_()
    dense_b = torch.sparse_coo_tensor(
        indexB, valueB, shape_b, device=device).to_dense()
    dense_b.requires_grad_()
    torch.matmul(dense_a, dense_b).sum().backward()

    # Sparse gradients through ``spspmm``; the returned indices are not
    # needed for this part of the check.
    valueA.requires_grad_()
    valueB.requires_grad_()
    _, grad_probe = spspmm(indexA, valueA, indexB, valueB, *dims)
    grad_probe.sum().backward()

    # Each stored value must receive the gradient of its dense entry.
    row_a, col_a = indexA
    row_b, col_b = indexB
    assert valueA.grad.tolist() == dense_a.grad[row_a, col_a].tolist()
    assert valueB.grad.tolist() == dense_b.grad[row_b, col_b].tolist()