"torchvision/csrc/io/decoder/seekable_buffer.cpp" did not exist on "67f5fcf7099aa0857230995277c264d66d2fc0ab"
Commit d30ed1d5 authored by rusty1s's avatar rusty1s
Browse files

cat

parent c2bd0b10
from itertools import product
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.cat import cat
from .utils import dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_cat(dtype, device):
index = tensor([[0, 0, 1], [0, 1, 2]], torch.long, device)
mat1 = SparseTensor(index)
index = tensor([[0, 0, 1, 2], [0, 1, 1, 0]], torch.long, device)
mat2 = SparseTensor(index)
cat([mat1, mat2], dim=(0, 1))
import torch
def cat(tensors, dim):
assert len(tensors) > 0
has_value = tensors[0].has_value()
has_rowcount = tensors[0].storage._rowcount is not None
has_rowptr = tensors[0].storage._rowptr is not None
has_colcount = tensors[0].storage._colcount is not None
has_colptr = tensors[0].storage._colptr is not None
has_csr2csc = tensors[0].storage._csr2csc is not None
has_csc2csr = tensors[0].storage._csc2csr is not None
rows, cols, values, sparse_size = [], [], [], [0, 0]
rowcounts, rowptrs, colcounts, colptrs = [], [], [], []
csr2cscs, csc2csrs, nnzs = [], [], 0
if dim == 0:
for tensor in tensors:
(row, col), value = tensor.coo()
rows += [row + sparse_size[0]]
cols += [col]
values += [value] if has_value else []
sparse_size[0] += tensor.sparse_size(0)
sparse_size[1] = max(sparse_size[1], tensor.sparse_size(1))
rowcounts += [tensor.storage.rowcount] if has_rowcount else []
if has_rowptr:
rowptr = tensor.storage.rowptr
rowptr = rowptr if len(rowptrs) == 0 else rowptr[1:]
rowptrs += [rowptr + nnzs]
nnzs += tensor.nnz()
storage = tensors[0].storage.__class__(
torch.stack([torch.cat(rows), torch.cat(cols)], dim=0),
value=torch.cat(values, dim=0) if has_value else None,
sparse_size=sparse_size,
rowcount=torch.cat(rowcounts) if has_rowcount else None,
rowptr=torch.cat(rowptrs) if has_rowptr else None,
is_sorted=True)
if dim == 1:
raise NotImplementedError
if dim == (0, 1) or (1, 0):
for tensor in tensors:
(row, col), value = tensor.coo()
rows += [row + sparse_size[0]]
cols += [col + sparse_size[1]]
values += [value] if has_value else []
sparse_size[0] += tensor.sparse_size(0)
sparse_size[1] += tensor.sparse_size(1)
rowcounts += [tensor.storage.rowcount] if has_rowcount else []
colcounts += [tensor.storage.colcount] if has_colcount else []
if has_rowptr:
rowptr = tensor.storage.rowptr
rowptr = rowptr if len(rowptrs) == 0 else rowptr[1:]
rowptrs += [rowptr + nnzs]
if has_colptr:
colptr = tensor.storage.colptr
colptr = colptr if len(colptrs) == 0 else colptr[1:]
colptrs += [colptr + nnzs]
csr2cscs += [tensor.storage.csr2csc + nnzs] if has_csr2csc else []
csc2csrs += [tensor.storage.csc2csr + nnzs] if has_csc2csr else []
nnzs += tensor.nnz()
storage = tensors[0].storage.__class__(
torch.stack([torch.cat(rows), torch.cat(cols)], dim=0),
value=torch.cat(values, dim=0) if has_value else None,
sparse_size=sparse_size,
rowcount=torch.cat(rowcounts) if has_rowcount else None,
rowptr=torch.cat(rowptrs) if has_rowptr else None,
colcount=torch.cat(colcounts) if has_colcount else None,
colptr=torch.cat(colptrs) if has_colptr else None,
csr2csc=torch.cat(csr2cscs) if has_csr2csc else None,
csc2csr=torch.cat(csc2csrs) if has_csc2csr else None,
is_sorted=True)
else:
raise NotImplementedError
return tensors[0].__class__.from_storage(storage)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment