Commit 1cb25232 authored by limm's avatar limm
Browse files

push 0.6.15 version

parent e8309f27
from itertools import product
import pytest
import torch
from torch_sparse import SparseTensor, add
from .utils import dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_add(dtype, device):
rowA = torch.tensor([0, 0, 1, 2, 2], device=device)
colA = torch.tensor([0, 2, 1, 0, 1], device=device)
valueA = tensor([1, 2, 4, 1, 3], dtype, device)
A = SparseTensor(row=rowA, col=colA, value=valueA)
rowB = torch.tensor([0, 0, 1, 2, 2], device=device)
colB = torch.tensor([1, 2, 2, 1, 2], device=device)
valueB = tensor([2, 3, 1, 2, 4], dtype, device)
B = SparseTensor(row=rowB, col=colB, value=valueB)
C = A + B
rowC, colC, valueC = C.coo()
assert rowC.tolist() == [0, 0, 0, 1, 1, 2, 2, 2]
assert colC.tolist() == [0, 1, 2, 1, 2, 0, 1, 2]
assert valueC.tolist() == [1, 2, 5, 4, 1, 1, 5, 4]
@torch.jit.script
def jit_add(A: SparseTensor, B: SparseTensor) -> SparseTensor:
return add(A, B)
jit_add(A, B)
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.cat import cat
from .utils import devices, tensor
@pytest.mark.parametrize('device', devices)
def test_cat(device):
row, col = tensor([[0, 0, 1], [0, 1, 2]], torch.long, device)
mat1 = SparseTensor(row=row, col=col)
mat1.fill_cache_()
row, col = tensor([[0, 0, 1, 2], [0, 1, 1, 0]], torch.long, device)
mat2 = SparseTensor(row=row, col=col)
mat2.fill_cache_()
out = cat([mat1, mat2], dim=0)
assert out.to_dense().tolist() == [[1, 1, 0], [0, 0, 1], [1, 1, 0],
[0, 1, 0], [1, 0, 0]]
assert out.storage.has_row()
assert out.storage.has_rowptr()
assert out.storage.has_rowcount()
assert out.storage.num_cached_keys() == 1
out = cat([mat1, mat2], dim=1)
assert out.to_dense().tolist() == [[1, 1, 0, 1, 1], [0, 0, 1, 0, 1],
[0, 0, 0, 1, 0]]
assert out.storage.has_row()
assert not out.storage.has_rowptr()
assert out.storage.num_cached_keys() == 2
out = cat([mat1, mat2], dim=(0, 1))
assert out.to_dense().tolist() == [[1, 1, 0, 0, 0], [0, 0, 1, 0, 0],
[0, 0, 0, 1, 1], [0, 0, 0, 0, 1],
[0, 0, 0, 1, 0]]
assert out.storage.has_row()
assert out.storage.has_rowptr()
assert out.storage.num_cached_keys() == 5
value = torch.randn((mat1.nnz(), 4), device=device)
mat1 = mat1.set_value_(value, layout='coo')
out = cat([mat1, mat1], dim=-1)
assert out.storage.value().size() == (mat1.nnz(), 8)
assert out.storage.has_row()
assert out.storage.has_rowptr()
assert out.storage.num_cached_keys() == 5
import torch
from torch_sparse import coalesce
def test_coalesce():
row = torch.tensor([1, 0, 1, 0, 2, 1])
col = torch.tensor([0, 1, 1, 1, 0, 0])
index = torch.stack([row, col], dim=0)
index, _ = coalesce(index, None, m=3, n=2)
assert index.tolist() == [[0, 1, 1, 2], [1, 0, 1, 0]]
def test_coalesce_add():
row = torch.tensor([1, 0, 1, 0, 2, 1])
col = torch.tensor([0, 1, 1, 1, 0, 0])
index = torch.stack([row, col], dim=0)
value = torch.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]])
index, value = coalesce(index, value, m=3, n=2)
assert index.tolist() == [[0, 1, 1, 2], [1, 0, 1, 0]]
assert value.tolist() == [[6, 8], [7, 9], [3, 4], [5, 6]]
def test_coalesce_max():
row = torch.tensor([1, 0, 1, 0, 2, 1])
col = torch.tensor([0, 1, 1, 1, 0, 0])
index = torch.stack([row, col], dim=0)
value = torch.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]])
index, value = coalesce(index, value, m=3, n=2, op='max')
assert index.tolist() == [[0, 1, 1, 2], [1, 0, 1, 0]]
assert value.tolist() == [[4, 5], [6, 7], [3, 4], [5, 6]]
import torch
from torch_sparse import to_scipy, from_scipy
from torch_sparse import to_torch_sparse, from_torch_sparse
def test_convert_scipy():
index = torch.tensor([[0, 0, 1, 2, 2], [0, 2, 1, 0, 1]])
value = torch.Tensor([1, 2, 4, 1, 3])
N = 3
out = from_scipy(to_scipy(index, value, N, N))
assert out[0].tolist() == index.tolist()
assert out[1].tolist() == value.tolist()
def test_convert_torch_sparse():
index = torch.tensor([[0, 0, 1, 2, 2], [0, 2, 1, 0, 1]])
value = torch.Tensor([1, 2, 4, 1, 3])
N = 3
out = from_torch_sparse(to_torch_sparse(index, value, N, N).coalesce())
assert out[0].tolist() == index.tolist()
assert out[1].tolist() == value.tolist()
from itertools import product
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from .utils import dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_remove_diag(dtype, device):
row, col = tensor([[0, 0, 1, 2], [0, 1, 2, 2]], torch.long, device)
value = tensor([1, 2, 3, 4], dtype, device)
mat = SparseTensor(row=row, col=col, value=value)
mat.fill_cache_()
mat = mat.remove_diag()
assert mat.storage.row().tolist() == [0, 1]
assert mat.storage.col().tolist() == [1, 2]
assert mat.storage.value().tolist() == [2, 3]
assert mat.storage.num_cached_keys() == 2
assert mat.storage.rowcount().tolist() == [1, 1, 0]
assert mat.storage.colcount().tolist() == [0, 1, 1]
mat = SparseTensor(row=row, col=col, value=value)
mat.fill_cache_()
mat = mat.remove_diag(k=1)
assert mat.storage.row().tolist() == [0, 2]
assert mat.storage.col().tolist() == [0, 2]
assert mat.storage.value().tolist() == [1, 4]
assert mat.storage.num_cached_keys() == 2
assert mat.storage.rowcount().tolist() == [1, 0, 1]
assert mat.storage.colcount().tolist() == [1, 0, 1]
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_set_diag(dtype, device):
row, col = tensor([[0, 0, 9, 9], [0, 1, 0, 1]], torch.long, device)
value = tensor([1, 2, 3, 4], dtype, device)
mat = SparseTensor(row=row, col=col, value=value)
mat = mat.set_diag(tensor([-8, -8], dtype, device), k=-1)
mat = mat.set_diag(tensor([-8], dtype, device), k=1)
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_fill_diag(dtype, device):
row, col = tensor([[0, 0, 9, 9], [0, 1, 0, 1]], torch.long, device)
value = tensor([1, 2, 3, 4], dtype, device)
mat = SparseTensor(row=row, col=col, value=value)
mat = mat.fill_diag(-8, k=-1)
mat = mat.fill_diag(-8, k=1)
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_get_diag(dtype, device):
row, col = tensor([[0, 0, 1, 2], [0, 1, 2, 2]], torch.long, device)
value = tensor([[1, 1], [2, 2], [3, 3], [4, 4]], dtype, device)
mat = SparseTensor(row=row, col=col, value=value)
assert mat.get_diag().tolist() == [[1, 1], [0, 0], [4, 4]]
row, col = tensor([[0, 0, 1, 2], [0, 1, 2, 2]], torch.long, device)
mat = SparseTensor(row=row, col=col)
assert mat.get_diag().tolist() == [1, 0, 1]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment