Commit 4443f0d8 authored by rusty1s's avatar rusty1s
Browse files

complete tests

parent 296b5048
...@@ -2,74 +2,47 @@ from itertools import product ...@@ -2,74 +2,47 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch.autograd import gradcheck
from torch_sparse import SparseTensor, spspmm, to_value from torch_sparse import SparseTensor, spspmm, to_value
from .utils import dtypes, devices, tensor from .utils import dtypes, devices, tensor
# Test fixtures: each dict describes one sparse-matrix pair (A, B) to multiply.
# 'indexA'/'indexB' are COO [row, col] index lists, 'valueA'/'valueB' the
# matching nonzero values, 'sizeA'/'sizeB' the dense shapes.
tests = [{
    'name': 'Test coalesced input',
    'indexA': [[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]],
    'valueA': [1, 2, 3, 4, 5],
    'sizeA': [3, 3],
    'indexB': [[0, 2], [1, 0]],
    'valueB': [2, 4],
    'sizeB': [3, 2],
}, {
    # Duplicate index pairs (e.g. (2, 1) twice) exercise the uncoalesced path.
    'name': 'Test uncoalesced input',
    'indexA': [[2, 2, 1, 0, 2, 0], [1, 1, 0, 2, 0, 1]],
    'valueA': [3, 2, 3, 2, 4, 1],
    'sizeA': [3, 3],
    'indexB': [[2, 0, 2], [0, 1, 0]],
    'valueB': [2, 2, 2],
    'sizeB': [3, 2],
}]


@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_spspmm(test, dtype, device):
    """Check sparse-sparse matmul (forward and backward) against dense matmul.

    Builds A and B both as SparseTensors and as dense tensors, multiplies
    each pair, and asserts the results — and the gradients w.r.t. A's
    values — agree.
    """
    indexA = torch.tensor(test['indexA'], device=device)
    valueA = tensor(test['valueA'], dtype, device, requires_grad=True)
    sizeA = torch.Size(test['sizeA'])
    A = SparseTensor(indexA, valueA, sizeA)
    # Detached dense copy with its own grad so the two graphs are independent.
    denseA = A.detach().to_dense().requires_grad_()

    indexB = torch.tensor(test['indexB'], device=device)
    valueB = tensor(test['valueB'], dtype, device, requires_grad=True)
    sizeB = torch.Size(test['sizeB'])
    B = SparseTensor(indexB, valueB, sizeB)
    denseB = B.detach().to_dense().requires_grad_()

    # Forward: sparse product must densify to the dense product.
    C = spspmm(A, B)
    denseC = torch.matmul(denseA, denseB)
    assert C.detach().to_dense().tolist() == denseC.tolist()

    # Backward: gradient of sum(C) w.r.t. A's nonzero values must match the
    # dense gradient gathered at A's nonzero positions.
    to_value(C).sum().backward()
    denseC.sum().backward()
    assert valueA.grad.tolist() == denseA.grad[indexA[0], indexA[1]].tolist()
...@@ -53,7 +53,7 @@ def to_scipy(A): ...@@ -53,7 +53,7 @@ def to_scipy(A):
def from_scipy(A):
    """Convert a scipy sparse matrix to a torch ``sparse_coo_tensor``.

    Accepts any scipy sparse format (converted to COO first); index dtype is
    cast to ``long`` as required by ``torch.sparse_coo_tensor``.
    """
    A = A.tocoo()
    row, col, value, size = A.row, A.col, A.data, torch.Size(A.shape)
    # scipy stores indices as int32; torch sparse indices must be int64.
    row, col = torch.from_numpy(row).long(), torch.from_numpy(col).long()
    value = torch.from_numpy(value)
    index = torch.stack([row, col], dim=0)
    return torch.sparse_coo_tensor(index, value, size)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment