Commit 872938af authored by rusty1s's avatar rusty1s
Browse files

overload for cat

parent 468aea5b
import pytest import pytest
import torch import torch
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
from torch_sparse.cat import cat, cat_diag from torch_sparse.cat import cat
from .utils import devices, tensor from .utils import devices, tensor
...@@ -31,7 +31,7 @@ def test_cat(device): ...@@ -31,7 +31,7 @@ def test_cat(device):
assert not out.storage.has_rowptr() assert not out.storage.has_rowptr()
assert out.storage.num_cached_keys() == 2 assert out.storage.num_cached_keys() == 2
out = cat_diag([mat1, mat2]) out = cat([mat1, mat2], dim=(0, 1))
assert out.to_dense().tolist() == [[1, 1, 0, 0, 0], [0, 0, 1, 0, 0], assert out.to_dense().tolist() == [[1, 1, 0, 0, 0], [0, 0, 1, 0, 0],
[0, 0, 0, 1, 1], [0, 0, 0, 0, 1], [0, 0, 0, 1, 1], [0, 0, 0, 0, 1],
[0, 0, 0, 1, 0]] [0, 0, 0, 1, 0]]
......
...@@ -44,7 +44,7 @@ from .add import add, add_, add_nnz, add_nnz_ # noqa ...@@ -44,7 +44,7 @@ from .add import add, add_, add_nnz, add_nnz_ # noqa
from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa
from .reduce import sum, mean, min, max # noqa from .reduce import sum, mean, min, max # noqa
from .matmul import matmul # noqa from .matmul import matmul # noqa
from .cat import cat, cat_diag # noqa from .cat import cat # noqa
from .rw import random_walk # noqa from .rw import random_walk # noqa
from .metis import partition # noqa from .metis import partition # noqa
from .bandwidth import reverse_cuthill_mckee # noqa from .bandwidth import reverse_cuthill_mckee # noqa
...@@ -89,7 +89,6 @@ __all__ = [ ...@@ -89,7 +89,6 @@ __all__ = [
'max', 'max',
'matmul', 'matmul',
'cat', 'cat',
'cat_diag',
'random_walk', 'random_walk',
'partition', 'partition',
'reverse_cuthill_mckee', 'reverse_cuthill_mckee',
......
from typing import Optional, List from typing import Optional, List, Tuple
import torch import torch
from torch_sparse.storage import SparseStorage from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor: @torch.jit._overload # noqa: F811
def cat(tensors, dim): # noqa: F811
# type: (List[SparseTensor], int) -> SparseTensor
pass
@torch.jit._overload # noqa: F811
def cat(tensors, dim): # noqa: F811
# type: (List[SparseTensor], Tuple[int, int]) -> SparseTensor
pass
@torch.jit._overload # noqa: F811
def cat(tensors, dim): # noqa: F811
# type: (List[SparseTensor], List[int]) -> SparseTensor
pass
def cat(tensors, dim): # noqa: F811
assert len(tensors) > 0 assert len(tensors) > 0
if dim < 0:
dim = tensors[0].dim() + dim if isinstance(dim, int):
dim = tensors[0].dim() + dim if dim < 0 else dim
if dim == 0:
rows: List[torch.Tensor] = [] if dim == 0:
rowptrs: List[torch.Tensor] = [] return cat_first(tensors)
cols: List[torch.Tensor] = []
values: List[torch.Tensor] = [] elif dim == 1:
sparse_sizes: List[int] = [0, 0] return cat_second(tensors)
rowcounts: List[torch.Tensor] = [] pass
nnz: int = 0 elif dim > 1 and dim < tensors[0].dim():
for tensor in tensors: values = []
row = tensor.storage._row for tensor in tensors:
if row is not None: value = tensor.storage.value()
rows.append(row + sparse_sizes[0]) assert value is not None
rowptr = tensor.storage._rowptr
if rowptr is not None:
if len(rowptrs) > 0:
rowptr = rowptr[1:]
rowptrs.append(rowptr + nnz)
cols.append(tensor.storage._col)
value = tensor.storage._value
if value is not None:
values.append(value) values.append(value)
value = torch.cat(values, dim=dim - 1)
return tensors[0].set_value(value, layout='coo')
rowcount = tensor.storage._rowcount else:
if rowcount is not None: raise IndexError(
rowcounts.append(rowcount) (f'Dimension out of range: Expected to be in range of '
f'[{-tensors[0].dim()}, {tensors[0].dim() - 1}], but got '
f'{dim}.'))
else:
assert isinstance(dim, (tuple, list))
assert len(dim) == 2
assert sorted(dim) == [0, 1]
return cat_diag(tensors)
sparse_sizes[0] += tensor.sparse_size(0)
sparse_sizes[1] = max(sparse_sizes[1], tensor.sparse_size(1))
nnz += tensor.nnz()
row: Optional[torch.Tensor] = None def cat_first(tensors: List[SparseTensor]) -> SparseTensor:
if len(rows) == len(tensors): rows: List[torch.Tensor] = []
row = torch.cat(rows, dim=0) rowptrs: List[torch.Tensor] = []
cols: List[torch.Tensor] = []
values: List[torch.Tensor] = []
sparse_sizes: List[int] = [0, 0]
rowcounts: List[torch.Tensor] = []
rowptr: Optional[torch.Tensor] = None nnz: int = 0
if len(rowptrs) == len(tensors): for tensor in tensors:
rowptr = torch.cat(rowptrs, dim=0) row = tensor.storage._row
if row is not None:
rows.append(row + sparse_sizes[0])
col = torch.cat(cols, dim=0) rowptr = tensor.storage._rowptr
if rowptr is not None:
rowptrs.append(rowptr[1:] + nnz if len(rowptrs) > 0 else rowptr)
value: Optional[torch.Tensor] = None cols.append(tensor.storage._col)
if len(values) == len(tensors):
value = torch.cat(values, dim=0)
rowcount: Optional[torch.Tensor] = None value = tensor.storage._value
if len(rowcounts) == len(tensors): if value is not None:
rowcount = torch.cat(rowcounts, dim=0) values.append(value)
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value, rowcount = tensor.storage._rowcount
sparse_sizes=sparse_sizes, rowcount=rowcount, if rowcount is not None:
colptr=None, colcount=None, csr2csc=None, rowcounts.append(rowcount)
csc2csr=None, is_sorted=True)
return tensors[0].from_storage(storage)
elif dim == 1: sparse_sizes[0] += tensor.sparse_size(0)
rows: List[torch.Tensor] = [] sparse_sizes[1] = max(sparse_sizes[1], tensor.sparse_size(1))
cols: List[torch.Tensor] = [] nnz += tensor.nnz()
values: List[torch.Tensor] = []
sparse_sizes: List[int] = [0, 0]
colptrs: List[torch.Tensor] = []
colcounts: List[torch.Tensor] = []
nnz: int = 0 row: Optional[torch.Tensor] = None
for tensor in tensors: if len(rows) == len(tensors):
row, col, value = tensor.coo() row = torch.cat(rows, dim=0)
rows.append(row) rowptr: Optional[torch.Tensor] = None
if len(rowptrs) == len(tensors):
rowptr = torch.cat(rowptrs, dim=0)
cols.append(tensor.storage._col + sparse_sizes[1]) col = torch.cat(cols, dim=0)
if value is not None: value: Optional[torch.Tensor] = None
values.append(value) if len(values) == len(tensors):
value = torch.cat(values, dim=0)
colptr = tensor.storage._colptr rowcount: Optional[torch.Tensor] = None
if colptr is not None: if len(rowcounts) == len(tensors):
if len(colptrs) > 0: rowcount = torch.cat(rowcounts, dim=0)
colptr = colptr[1:]
colptrs.append(colptr + nnz)
colcount = tensor.storage._colcount storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
if colcount is not None: sparse_sizes=(sparse_sizes[0], sparse_sizes[1]),
colcounts.append(colcount) rowcount=rowcount, colptr=None, colcount=None,
csr2csc=None, csc2csr=None, is_sorted=True)
return tensors[0].from_storage(storage)
sparse_sizes[0] = max(sparse_sizes[0], tensor.sparse_size(0))
sparse_sizes[1] += tensor.sparse_size(1)
nnz += tensor.nnz()
row = torch.cat(rows, dim=0) def cat_second(tensors: List[SparseTensor]) -> SparseTensor:
rows: List[torch.Tensor] = []
cols: List[torch.Tensor] = []
values: List[torch.Tensor] = []
sparse_sizes: List[int] = [0, 0]
colptrs: List[torch.Tensor] = []
colcounts: List[torch.Tensor] = []
nnz: int = 0
for tensor in tensors:
row, col, value = tensor.coo()
rows.append(row)
cols.append(tensor.storage._col + sparse_sizes[1])
if value is not None:
values.append(value)
col = torch.cat(cols, dim=0) colptr = tensor.storage._colptr
if colptr is not None:
colptrs.append(colptr[1:] + nnz if len(colptrs) > 0 else colptr)
value: Optional[torch.Tensor] = None colcount = tensor.storage._colcount
if len(values) == len(tensors): if colcount is not None:
value = torch.cat(values, dim=0) colcounts.append(colcount)
colptr: Optional[torch.Tensor] = None sparse_sizes[0] = max(sparse_sizes[0], tensor.sparse_size(0))
if len(colptrs) == len(tensors): sparse_sizes[1] += tensor.sparse_size(1)
colptr = torch.cat(colptrs, dim=0) nnz += tensor.nnz()
colcount: Optional[torch.Tensor] = None row = torch.cat(rows, dim=0)
if len(colcounts) == len(tensors): col = torch.cat(cols, dim=0)
colcount = torch.cat(colcounts, dim=0)
storage = SparseStorage(row=row, rowptr=None, col=col, value=value, value: Optional[torch.Tensor] = None
sparse_sizes=sparse_sizes, rowcount=None, if len(values) == len(tensors):
colptr=colptr, colcount=colcount, csr2csc=None, value = torch.cat(values, dim=0)
csc2csr=None, is_sorted=False)
return tensors[0].from_storage(storage)
elif dim > 1 and dim < tensors[0].dim(): colptr: Optional[torch.Tensor] = None
values: List[torch.Tensor] = [] if len(colptrs) == len(tensors):
for tensor in tensors: colptr = torch.cat(colptrs, dim=0)
value = tensor.storage.value()
if value is not None:
values.append(value)
value: Optional[torch.Tensor] = None colcount: Optional[torch.Tensor] = None
if len(values) == len(tensors): if len(colcounts) == len(tensors):
value = torch.cat(values, dim=dim - 1) colcount = torch.cat(colcounts, dim=0)
return tensors[0].set_value(value, layout='coo') storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
else: sparse_sizes=(sparse_sizes[0], sparse_sizes[1]),
raise IndexError( rowcount=None, colptr=colptr, colcount=colcount,
(f'Dimension out of range: Expected to be in range of ' csr2csc=None, csc2csr=None, is_sorted=False)
f'[{-tensors[0].dim()}, {tensors[0].dim() - 1}], but got {dim}.')) return tensors[0].from_storage(storage)
def cat_diag(tensors: List[SparseTensor]) -> SparseTensor: def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
...@@ -163,9 +187,7 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor: ...@@ -163,9 +187,7 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
rowptr = tensor.storage._rowptr rowptr = tensor.storage._rowptr
if rowptr is not None: if rowptr is not None:
if len(rowptrs) > 0: rowptrs.append(rowptr[1:] + nnz if len(rowptrs) > 0 else rowptr)
rowptr = rowptr[1:]
rowptrs.append(rowptr + nnz)
cols.append(tensor.storage._col + sparse_sizes[1]) cols.append(tensor.storage._col + sparse_sizes[1])
...@@ -179,9 +201,7 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor: ...@@ -179,9 +201,7 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
colptr = tensor.storage._colptr colptr = tensor.storage._colptr
if colptr is not None: if colptr is not None:
if len(colptrs) > 0: colptrs.append(colptr[1:] + nnz if len(colptrs) > 0 else colptr)
colptr = colptr[1:]
colptrs.append(colptr + nnz)
colcount = tensor.storage._colcount colcount = tensor.storage._colcount
if colcount is not None: if colcount is not None:
...@@ -234,7 +254,8 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor: ...@@ -234,7 +254,8 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
csc2csr = torch.cat(csc2csrs, dim=0) csc2csr = torch.cat(csc2csrs, dim=0)
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value, storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=rowcount, sparse_sizes=(sparse_sizes[0], sparse_sizes[1]),
colptr=colptr, colcount=colcount, csr2csc=csr2csc, rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True) csc2csr=csc2csr, is_sorted=True)
return tensors[0].from_storage(storage) return tensors[0].from_storage(storage)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment