Commit ee2d323d authored by rusty1s's avatar rusty1s
Browse files

fixed 'index derivative is not defined' message

parent c1cd9753
from itertools import product
import pytest
import torch import torch
from torch_sparse import spmm from torch_sparse import spmm
from .utils import dtypes, devices, tensor
def test_spmm(): @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
row = torch.tensor([0, 0, 1, 2, 2]) def test_spmm(dtype, device):
col = torch.tensor([0, 2, 1, 0, 1]) row = torch.tensor([0, 0, 1, 2, 2], device=device)
col = torch.tensor([0, 2, 1, 0, 1], device=device)
index = torch.stack([row, col], dim=0) index = torch.stack([row, col], dim=0)
value = torch.tensor([1, 2, 4, 1, 3]) value = tensor([1, 2, 4, 1, 3], dtype, device)
x = tensor([[1, 4], [2, 5], [3, 6]], dtype, device)
matrix = torch.tensor([[1, 4], [2, 5], [3, 6]]) out = spmm(index, value, 3, x)
out = spmm(index, value, 3, matrix)
assert out.tolist() == [[7, 16], [8, 20], [7, 19]] assert out.tolist() == [[7, 16], [8, 20], [7, 19]]
from itertools import product
import pytest
import torch
from torch_sparse import spspmm, spmm
from .utils import dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_spmm_spspmm(dtype, device):
row = torch.tensor([0, 0, 1, 2, 2], device=device)
col = torch.tensor([0, 2, 1, 0, 1], device=device)
index = torch.stack([row, col], dim=0)
value = tensor([1, 2, 4, 1, 3], dtype, device)
x = tensor([[1, 4], [2, 5], [3, 6]], dtype, device)
value = value.requires_grad_(True)
out_index, out_value = spspmm(index, value, index, value, 3, 3, 3)
out = spmm(out_index, out_value, 3, x)
assert out.size() == (3, 2)
...@@ -23,7 +23,8 @@ def spspmm(indexA, valueA, indexB, valueB, m, k, n): ...@@ -23,7 +23,8 @@ def spspmm(indexA, valueA, indexB, valueB, m, k, n):
:rtype: (:class:`LongTensor`, :class:`Tensor`) :rtype: (:class:`LongTensor`, :class:`Tensor`)
""" """
return SpSpMM.apply(indexA, valueA, indexB, valueB, m, k, n) index, value = SpSpMM.apply(indexA, valueA, indexB, valueB, m, k, n)
return index.detach(), value
class SpSpMM(torch.autograd.Function): class SpSpMM(torch.autograd.Function):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment