"vscode:/vscode.git/clone" did not exist on "40528e9ae7d56740c00d838299198d34111717bb"
Commit 9f034684 authored by Mario Geiger's avatar Mario Geiger
Browse files

test_sparse_tensor_spspmm

parent 57852a66
......@@ -2,7 +2,7 @@ from itertools import product
import pytest
import torch
from torch_sparse import spspmm
from torch_sparse import spspmm, SparseTensor
from .utils import grad_dtypes, devices, tensor
......@@ -17,3 +17,32 @@ def test_spspmm(dtype, device):
indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)
assert indexC.tolist() == [[0, 1, 2], [0, 1, 1]]
assert valueC.tolist() == [8, 6, 8]
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_sparse_tensor_spspmm(dtype, device):
x = SparseTensor(
row=torch.tensor(
[0, 1, 1, 1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 8, 8, 9, 9],
device=device
),
col=torch.tensor(
[0, 5, 10, 15, 1, 2, 3, 7, 13, 6, 9, 5, 10, 15, 11, 14, 5, 15],
device=device
),
value=torch.tensor(
[1, 3**-0.5, 3**-0.5, 3**-0.5, 1, 1, 1, -2**-0.5, -2**-0.5,
-2**-0.5, -2**-0.5, 6**-0.5, -6**0.5 / 3, 6**-0.5, -2**-0.5,
-2**-0.5, 2**-0.5, -2**-0.5],
dtype=dtype, device=device
),
)
i0 = torch.eye(10, dtype=dtype, device=device)
i1 = x @ x.to_dense().t()
assert torch.allclose(i0, i1)
i1 = x @ x.t()
i1 = i1.to_dense()
assert torch.allclose(i0, i1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment