"...text-generation-inference.git" did not exist on "90b226db291769a45ecbccaa4f7384bc6b9bff8a"
Commit 2317ff66 authored by Mario Geiger's avatar Mario Geiger
Browse files

test_spspmm_2

parent 9f034684
...@@ -2,7 +2,7 @@ from itertools import product ...@@ -2,7 +2,7 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch_sparse import spspmm, SparseTensor from torch_sparse import spspmm, SparseTensor, transpose
from .utils import grad_dtypes, devices, tensor from .utils import grad_dtypes, devices, tensor
...@@ -19,6 +19,38 @@ def test_spspmm(dtype, device): ...@@ -19,6 +19,38 @@ def test_spspmm(dtype, device):
assert valueC.tolist() == [8, 6, 8] assert valueC.tolist() == [8, 6, 8]
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_spspmm_2(dtype, device):
row = torch.tensor(
[0, 1, 1, 1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 8, 8, 9, 9],
device=device
)
col = torch.tensor(
[0, 5, 10, 15, 1, 2, 3, 7, 13, 6, 9, 5, 10, 15, 11, 14, 5, 15],
device=device
)
value = torch.tensor(
[1, 3**-0.5, 3**-0.5, 3**-0.5, 1, 1, 1, -2**-0.5, -2**-0.5,
-2**-0.5, -2**-0.5, 6**-0.5, -6**0.5 / 3, 6**-0.5, -2**-0.5,
-2**-0.5, 2**-0.5, -2**-0.5],
dtype=dtype, device=device
)
index = torch.stack([row, col])
m = value.new_zeros(10, 16)
m[index[0], index[1]] = value
index_t, value_t = transpose(index, value, 10, 16)
index, value = spspmm(index, value, index_t, value_t, 10, 16, 10)
mask = value.abs() > 1e-4
index, value = index[:, mask], value[mask]
assert index.tolist() == [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]
assert value.tolist() == [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_sparse_tensor_spspmm(dtype, device): def test_sparse_tensor_spspmm(dtype, device):
x = SparseTensor( x = SparseTensor(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment