Commit bc522dd9 authored by rusty1s's avatar rusty1s
Browse files

cat tests

parent 69cab8ac
from itertools import product
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.cat import cat
from .utils import dtypes, devices, tensor
from .utils import devices, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_cat(dtype, device):
@pytest.mark.parametrize('device', devices)
def test_cat(device):
index = tensor([[0, 0, 1], [0, 1, 2]], torch.long, device)
mat1 = SparseTensor(index)
mat1.fill_cache_()
index = tensor([[0, 0, 1, 2], [0, 1, 1, 0]], torch.long, device)
mat2 = SparseTensor(index)
mat2.fill_cache_()
out = cat([mat1, mat2], dim=0)
assert out.to_dense().tolist() == [[1, 1, 0], [0, 0, 1], [1, 1, 0],
[0, 1, 0], [1, 0, 0]]
assert len(out.storage.cached_keys()) == 2
assert out.storage.has_rowcount()
assert out.storage.has_rowptr()
out = cat([mat1, mat2], dim=1)
assert out.to_dense().tolist() == [[1, 1, 0, 1, 1], [0, 0, 1, 0, 1],
[0, 0, 0, 1, 0]]
assert len(out.storage.cached_keys()) == 2
assert out.storage.has_colcount()
assert out.storage.has_colptr()
out = cat([mat1, mat2], dim=(0, 1))
assert out.to_dense().tolist() == [[1, 1, 0, 0, 0], [0, 0, 1, 0, 0],
[0, 0, 0, 1, 1], [0, 0, 0, 0, 1],
[0, 0, 0, 1, 0]]
assert len(out.storage.cached_keys()) == 6
cat([mat1, mat2], dim=(0, 1))
mat1.set_value_(torch.randn((mat1.nnz(), 4), device=device))
out = cat([mat1, mat1], dim=-1)
assert out.storage.value.size() == (mat1.nnz(), 8)
assert len(out.storage.cached_keys()) == 6
......@@ -70,9 +70,11 @@ def cat(tensors, dim):
value=torch.cat(values, dim=0) if has_value else None,
sparse_size=sparse_size,
colcount=torch.cat(colcounts) if has_colcount else None,
colptr=torch.cat(colptrs) if has_colptr else None, is_sorted=False)
colptr=torch.cat(colptrs) if has_colptr else None,
is_sorted=False,
)
elif dim == (0, 1) or (1, 0):
elif dim == (0, 1) or dim == (1, 0):
for tensor in tensors:
(row, col), value = tensor.coo()
rows += [row + sparse_size[0]]
......@@ -115,21 +117,26 @@ def cat(tensors, dim):
colptr=torch.cat(colptrs) if has_colptr else None,
csr2csc=torch.cat(csr2cscs) if has_csr2csc else None,
csc2csr=torch.cat(csc2csrs) if has_csc2csr else None,
is_sorted=True)
is_sorted=True,
)
elif isinstance(dim, int) and dim > 1 and dim < tensors[0].dim():
for tensor in tensors:
values += [tensor.storage.value]
sparse_size[0] = max(sparse_size[0], tensor.sparse_size(0))
sparse_size[1] = max(sparse_size[1], tensor.sparse_size(1))
old_storage = tensors[0].storage
storage = old_storage.storage.__class__(
tensors[0].storage.index, value=torch.cat(values, dim=dim - 1),
sparse_size=sparse_size, rowcount=old_storage._rowcount,
rowptr=old_storage._rowcount, colcount=old_storage._rowcount,
colptr=old_storage._rowcount, csr2csc=old_storage._csr2csc,
csc2csr=old_storage._csc2csr, is_sorted=True)
storage = old_storage.__class__(
tensors[0].storage.index,
value=torch.cat(values, dim=dim - 1),
sparse_size=old_storage.sparse_size(),
rowcount=old_storage._rowcount,
rowptr=old_storage._rowptr,
colcount=old_storage._colcount,
colptr=old_storage._colptr,
csr2csc=old_storage._csr2csc,
csc2csr=old_storage._csc2csr,
is_sorted=True,
)
else:
raise IndexError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment