Commit 872938af authored by rusty1s's avatar rusty1s
Browse files

overload for cat

parent 468aea5b
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.cat import cat, cat_diag
from torch_sparse.cat import cat
from .utils import devices, tensor
......@@ -31,7 +31,7 @@ def test_cat(device):
assert not out.storage.has_rowptr()
assert out.storage.num_cached_keys() == 2
out = cat_diag([mat1, mat2])
out = cat([mat1, mat2], dim=(0, 1))
assert out.to_dense().tolist() == [[1, 1, 0, 0, 0], [0, 0, 1, 0, 0],
[0, 0, 0, 1, 1], [0, 0, 0, 0, 1],
[0, 0, 0, 1, 0]]
......
......@@ -44,7 +44,7 @@ from .add import add, add_, add_nnz, add_nnz_ # noqa
from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa
from .reduce import sum, mean, min, max # noqa
from .matmul import matmul # noqa
from .cat import cat, cat_diag # noqa
from .cat import cat # noqa
from .rw import random_walk # noqa
from .metis import partition # noqa
from .bandwidth import reverse_cuthill_mckee # noqa
......@@ -89,7 +89,6 @@ __all__ = [
'max',
'matmul',
'cat',
'cat_diag',
'random_walk',
'partition',
'reverse_cuthill_mckee',
......
from typing import Optional, List
from typing import Optional, List, Tuple
import torch
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
@torch.jit._overload # noqa: F811
def cat(tensors, dim): # noqa: F811
# type: (List[SparseTensor], int) -> SparseTensor
pass
@torch.jit._overload # noqa: F811
def cat(tensors, dim): # noqa: F811
# type: (List[SparseTensor], Tuple[int, int]) -> SparseTensor
pass
@torch.jit._overload # noqa: F811
def cat(tensors, dim): # noqa: F811
# type: (List[SparseTensor], List[int]) -> SparseTensor
pass
def cat(tensors, dim): # noqa: F811
assert len(tensors) > 0
if dim < 0:
dim = tensors[0].dim() + dim
if dim == 0:
rows: List[torch.Tensor] = []
rowptrs: List[torch.Tensor] = []
cols: List[torch.Tensor] = []
values: List[torch.Tensor] = []
sparse_sizes: List[int] = [0, 0]
rowcounts: List[torch.Tensor] = []
nnz: int = 0
for tensor in tensors:
row = tensor.storage._row
if row is not None:
rows.append(row + sparse_sizes[0])
rowptr = tensor.storage._rowptr
if rowptr is not None:
if len(rowptrs) > 0:
rowptr = rowptr[1:]
rowptrs.append(rowptr + nnz)
cols.append(tensor.storage._col)
value = tensor.storage._value
if value is not None:
if isinstance(dim, int):
dim = tensors[0].dim() + dim if dim < 0 else dim
if dim == 0:
return cat_first(tensors)
elif dim == 1:
return cat_second(tensors)
pass
elif dim > 1 and dim < tensors[0].dim():
values = []
for tensor in tensors:
value = tensor.storage.value()
assert value is not None
values.append(value)
value = torch.cat(values, dim=dim - 1)
return tensors[0].set_value(value, layout='coo')
rowcount = tensor.storage._rowcount
if rowcount is not None:
rowcounts.append(rowcount)
else:
raise IndexError(
(f'Dimension out of range: Expected to be in range of '
f'[{-tensors[0].dim()}, {tensors[0].dim() - 1}], but got '
f'{dim}.'))
else:
assert isinstance(dim, (tuple, list))
assert len(dim) == 2
assert sorted(dim) == [0, 1]
return cat_diag(tensors)
sparse_sizes[0] += tensor.sparse_size(0)
sparse_sizes[1] = max(sparse_sizes[1], tensor.sparse_size(1))
nnz += tensor.nnz()
row: Optional[torch.Tensor] = None
if len(rows) == len(tensors):
row = torch.cat(rows, dim=0)
def cat_first(tensors: List[SparseTensor]) -> SparseTensor:
rows: List[torch.Tensor] = []
rowptrs: List[torch.Tensor] = []
cols: List[torch.Tensor] = []
values: List[torch.Tensor] = []
sparse_sizes: List[int] = [0, 0]
rowcounts: List[torch.Tensor] = []
rowptr: Optional[torch.Tensor] = None
if len(rowptrs) == len(tensors):
rowptr = torch.cat(rowptrs, dim=0)
nnz: int = 0
for tensor in tensors:
row = tensor.storage._row
if row is not None:
rows.append(row + sparse_sizes[0])
col = torch.cat(cols, dim=0)
rowptr = tensor.storage._rowptr
if rowptr is not None:
rowptrs.append(rowptr[1:] + nnz if len(rowptrs) > 0 else rowptr)
value: Optional[torch.Tensor] = None
if len(values) == len(tensors):
value = torch.cat(values, dim=0)
cols.append(tensor.storage._col)
rowcount: Optional[torch.Tensor] = None
if len(rowcounts) == len(tensors):
rowcount = torch.cat(rowcounts, dim=0)
value = tensor.storage._value
if value is not None:
values.append(value)
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=rowcount,
colptr=None, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True)
return tensors[0].from_storage(storage)
rowcount = tensor.storage._rowcount
if rowcount is not None:
rowcounts.append(rowcount)
elif dim == 1:
rows: List[torch.Tensor] = []
cols: List[torch.Tensor] = []
values: List[torch.Tensor] = []
sparse_sizes: List[int] = [0, 0]
colptrs: List[torch.Tensor] = []
colcounts: List[torch.Tensor] = []
sparse_sizes[0] += tensor.sparse_size(0)
sparse_sizes[1] = max(sparse_sizes[1], tensor.sparse_size(1))
nnz += tensor.nnz()
nnz: int = 0
for tensor in tensors:
row, col, value = tensor.coo()
row: Optional[torch.Tensor] = None
if len(rows) == len(tensors):
row = torch.cat(rows, dim=0)
rows.append(row)
rowptr: Optional[torch.Tensor] = None
if len(rowptrs) == len(tensors):
rowptr = torch.cat(rowptrs, dim=0)
cols.append(tensor.storage._col + sparse_sizes[1])
col = torch.cat(cols, dim=0)
if value is not None:
values.append(value)
value: Optional[torch.Tensor] = None
if len(values) == len(tensors):
value = torch.cat(values, dim=0)
colptr = tensor.storage._colptr
if colptr is not None:
if len(colptrs) > 0:
colptr = colptr[1:]
colptrs.append(colptr + nnz)
rowcount: Optional[torch.Tensor] = None
if len(rowcounts) == len(tensors):
rowcount = torch.cat(rowcounts, dim=0)
colcount = tensor.storage._colcount
if colcount is not None:
colcounts.append(colcount)
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=(sparse_sizes[0], sparse_sizes[1]),
rowcount=rowcount, colptr=None, colcount=None,
csr2csc=None, csc2csr=None, is_sorted=True)
return tensors[0].from_storage(storage)
sparse_sizes[0] = max(sparse_sizes[0], tensor.sparse_size(0))
sparse_sizes[1] += tensor.sparse_size(1)
nnz += tensor.nnz()
row = torch.cat(rows, dim=0)
def cat_second(tensors: List[SparseTensor]) -> SparseTensor:
rows: List[torch.Tensor] = []
cols: List[torch.Tensor] = []
values: List[torch.Tensor] = []
sparse_sizes: List[int] = [0, 0]
colptrs: List[torch.Tensor] = []
colcounts: List[torch.Tensor] = []
nnz: int = 0
for tensor in tensors:
row, col, value = tensor.coo()
rows.append(row)
cols.append(tensor.storage._col + sparse_sizes[1])
if value is not None:
values.append(value)
col = torch.cat(cols, dim=0)
colptr = tensor.storage._colptr
if colptr is not None:
colptrs.append(colptr[1:] + nnz if len(colptrs) > 0 else colptr)
value: Optional[torch.Tensor] = None
if len(values) == len(tensors):
value = torch.cat(values, dim=0)
colcount = tensor.storage._colcount
if colcount is not None:
colcounts.append(colcount)
colptr: Optional[torch.Tensor] = None
if len(colptrs) == len(tensors):
colptr = torch.cat(colptrs, dim=0)
sparse_sizes[0] = max(sparse_sizes[0], tensor.sparse_size(0))
sparse_sizes[1] += tensor.sparse_size(1)
nnz += tensor.nnz()
colcount: Optional[torch.Tensor] = None
if len(colcounts) == len(tensors):
colcount = torch.cat(colcounts, dim=0)
row = torch.cat(rows, dim=0)
col = torch.cat(cols, dim=0)
storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=None,
colptr=colptr, colcount=colcount, csr2csc=None,
csc2csr=None, is_sorted=False)
return tensors[0].from_storage(storage)
value: Optional[torch.Tensor] = None
if len(values) == len(tensors):
value = torch.cat(values, dim=0)
elif dim > 1 and dim < tensors[0].dim():
values: List[torch.Tensor] = []
for tensor in tensors:
value = tensor.storage.value()
if value is not None:
values.append(value)
colptr: Optional[torch.Tensor] = None
if len(colptrs) == len(tensors):
colptr = torch.cat(colptrs, dim=0)
value: Optional[torch.Tensor] = None
if len(values) == len(tensors):
value = torch.cat(values, dim=dim - 1)
colcount: Optional[torch.Tensor] = None
if len(colcounts) == len(tensors):
colcount = torch.cat(colcounts, dim=0)
return tensors[0].set_value(value, layout='coo')
else:
raise IndexError(
(f'Dimension out of range: Expected to be in range of '
f'[{-tensors[0].dim()}, {tensors[0].dim() - 1}], but got {dim}.'))
storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=(sparse_sizes[0], sparse_sizes[1]),
rowcount=None, colptr=colptr, colcount=colcount,
csr2csc=None, csc2csr=None, is_sorted=False)
return tensors[0].from_storage(storage)
def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
......@@ -163,9 +187,7 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
rowptr = tensor.storage._rowptr
if rowptr is not None:
if len(rowptrs) > 0:
rowptr = rowptr[1:]
rowptrs.append(rowptr + nnz)
rowptrs.append(rowptr[1:] + nnz if len(rowptrs) > 0 else rowptr)
cols.append(tensor.storage._col + sparse_sizes[1])
......@@ -179,9 +201,7 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
colptr = tensor.storage._colptr
if colptr is not None:
if len(colptrs) > 0:
colptr = colptr[1:]
colptrs.append(colptr + nnz)
colptrs.append(colptr[1:] + nnz if len(colptrs) > 0 else colptr)
colcount = tensor.storage._colcount
if colcount is not None:
......@@ -234,7 +254,8 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
csc2csr = torch.cat(csc2csrs, dim=0)
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=rowcount,
colptr=colptr, colcount=colcount, csr2csc=csr2csc,
sparse_sizes=(sparse_sizes[0], sparse_sizes[1]),
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True)
return tensors[0].from_storage(storage)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment