"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "7f9279397af46465128aef2dd626de9376282b0f"
Commit ee2d323d authored by rusty1s's avatar rusty1s
Browse files

fixed 'index derivative is not defined' message

parent c1cd9753
from itertools import product
import pytest
import torch import torch
from torch_sparse import spmm from torch_sparse import spmm
from .utils import dtypes, devices, tensor
def test_spmm(): @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
row = torch.tensor([0, 0, 1, 2, 2]) def test_spmm(dtype, device):
col = torch.tensor([0, 2, 1, 0, 1]) row = torch.tensor([0, 0, 1, 2, 2], device=device)
col = torch.tensor([0, 2, 1, 0, 1], device=device)
index = torch.stack([row, col], dim=0) index = torch.stack([row, col], dim=0)
value = torch.tensor([1, 2, 4, 1, 3]) value = tensor([1, 2, 4, 1, 3], dtype, device)
x = tensor([[1, 4], [2, 5], [3, 6]], dtype, device)
matrix = torch.tensor([[1, 4], [2, 5], [3, 6]]) out = spmm(index, value, 3, x)
out = spmm(index, value, 3, matrix)
assert out.tolist() == [[7, 16], [8, 20], [7, 19]] assert out.tolist() == [[7, 16], [8, 20], [7, 19]]
from itertools import product
import pytest
import torch
from torch_sparse import spspmm, spmm
from .utils import dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_spmm_spspmm(dtype, device):
row = torch.tensor([0, 0, 1, 2, 2], device=device)
col = torch.tensor([0, 2, 1, 0, 1], device=device)
index = torch.stack([row, col], dim=0)
value = tensor([1, 2, 4, 1, 3], dtype, device)
x = tensor([[1, 4], [2, 5], [3, 6]], dtype, device)
value = value.requires_grad_(True)
out_index, out_value = spspmm(index, value, index, value, 3, 3, 3)
out = spmm(out_index, out_value, 3, x)
assert out.size() == (3, 2)
...@@ -23,7 +23,8 @@ def spspmm(indexA, valueA, indexB, valueB, m, k, n): ...@@ -23,7 +23,8 @@ def spspmm(indexA, valueA, indexB, valueB, m, k, n):
:rtype: (:class:`LongTensor`, :class:`Tensor`) :rtype: (:class:`LongTensor`, :class:`Tensor`)
""" """
return SpSpMM.apply(indexA, valueA, indexB, valueB, m, k, n) index, value = SpSpMM.apply(indexA, valueA, indexB, valueB, m, k, n)
return index.detach(), value
class SpSpMM(torch.autograd.Function): class SpSpMM(torch.autograd.Function):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment