Commit 872938af authored by rusty1s's avatar rusty1s
Browse files

overload for cat

parent 468aea5b
import pytest import pytest
import torch import torch
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
from torch_sparse.cat import cat, cat_diag from torch_sparse.cat import cat
from .utils import devices, tensor from .utils import devices, tensor
...@@ -31,7 +31,7 @@ def test_cat(device): ...@@ -31,7 +31,7 @@ def test_cat(device):
assert not out.storage.has_rowptr() assert not out.storage.has_rowptr()
assert out.storage.num_cached_keys() == 2 assert out.storage.num_cached_keys() == 2
out = cat_diag([mat1, mat2]) out = cat([mat1, mat2], dim=(0, 1))
assert out.to_dense().tolist() == [[1, 1, 0, 0, 0], [0, 0, 1, 0, 0], assert out.to_dense().tolist() == [[1, 1, 0, 0, 0], [0, 0, 1, 0, 0],
[0, 0, 0, 1, 1], [0, 0, 0, 0, 1], [0, 0, 0, 1, 1], [0, 0, 0, 0, 1],
[0, 0, 0, 1, 0]] [0, 0, 0, 1, 0]]
......
...@@ -44,7 +44,7 @@ from .add import add, add_, add_nnz, add_nnz_ # noqa ...@@ -44,7 +44,7 @@ from .add import add, add_, add_nnz, add_nnz_ # noqa
from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa
from .reduce import sum, mean, min, max # noqa from .reduce import sum, mean, min, max # noqa
from .matmul import matmul # noqa from .matmul import matmul # noqa
from .cat import cat, cat_diag # noqa from .cat import cat # noqa
from .rw import random_walk # noqa from .rw import random_walk # noqa
from .metis import partition # noqa from .metis import partition # noqa
from .bandwidth import reverse_cuthill_mckee # noqa from .bandwidth import reverse_cuthill_mckee # noqa
...@@ -89,7 +89,6 @@ __all__ = [ ...@@ -89,7 +89,6 @@ __all__ = [
'max', 'max',
'matmul', 'matmul',
'cat', 'cat',
'cat_diag',
'random_walk', 'random_walk',
'partition', 'partition',
'reverse_cuthill_mckee', 'reverse_cuthill_mckee',
......
from typing import Optional, List from typing import Optional, List, Tuple
import torch import torch
from torch_sparse.storage import SparseStorage from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor: @torch.jit._overload # noqa: F811
def cat(tensors, dim): # noqa: F811
# type: (List[SparseTensor], int) -> SparseTensor
pass
@torch.jit._overload # noqa: F811
def cat(tensors, dim): # noqa: F811
# type: (List[SparseTensor], Tuple[int, int]) -> SparseTensor
pass
@torch.jit._overload # noqa: F811
def cat(tensors, dim): # noqa: F811
# type: (List[SparseTensor], List[int]) -> SparseTensor
pass
def cat(tensors, dim): # noqa: F811
assert len(tensors) > 0 assert len(tensors) > 0
if dim < 0:
dim = tensors[0].dim() + dim if isinstance(dim, int):
dim = tensors[0].dim() + dim if dim < 0 else dim
if dim == 0: if dim == 0:
return cat_first(tensors)
elif dim == 1:
return cat_second(tensors)
pass
elif dim > 1 and dim < tensors[0].dim():
values = []
for tensor in tensors:
value = tensor.storage.value()
assert value is not None
values.append(value)
value = torch.cat(values, dim=dim - 1)
return tensors[0].set_value(value, layout='coo')
else:
raise IndexError(
(f'Dimension out of range: Expected to be in range of '
f'[{-tensors[0].dim()}, {tensors[0].dim() - 1}], but got '
f'{dim}.'))
else:
assert isinstance(dim, (tuple, list))
assert len(dim) == 2
assert sorted(dim) == [0, 1]
return cat_diag(tensors)
def cat_first(tensors: List[SparseTensor]) -> SparseTensor:
rows: List[torch.Tensor] = [] rows: List[torch.Tensor] = []
rowptrs: List[torch.Tensor] = [] rowptrs: List[torch.Tensor] = []
cols: List[torch.Tensor] = [] cols: List[torch.Tensor] = []
...@@ -26,9 +73,7 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor: ...@@ -26,9 +73,7 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
rowptr = tensor.storage._rowptr rowptr = tensor.storage._rowptr
if rowptr is not None: if rowptr is not None:
if len(rowptrs) > 0: rowptrs.append(rowptr[1:] + nnz if len(rowptrs) > 0 else rowptr)
rowptr = rowptr[1:]
rowptrs.append(rowptr + nnz)
cols.append(tensor.storage._col) cols.append(tensor.storage._col)
...@@ -63,12 +108,13 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor: ...@@ -63,12 +108,13 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
rowcount = torch.cat(rowcounts, dim=0) rowcount = torch.cat(rowcounts, dim=0)
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value, storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=rowcount, sparse_sizes=(sparse_sizes[0], sparse_sizes[1]),
colptr=None, colcount=None, csr2csc=None, rowcount=rowcount, colptr=None, colcount=None,
csc2csr=None, is_sorted=True) csr2csc=None, csc2csr=None, is_sorted=True)
return tensors[0].from_storage(storage) return tensors[0].from_storage(storage)
elif dim == 1:
def cat_second(tensors: List[SparseTensor]) -> SparseTensor:
rows: List[torch.Tensor] = [] rows: List[torch.Tensor] = []
cols: List[torch.Tensor] = [] cols: List[torch.Tensor] = []
values: List[torch.Tensor] = [] values: List[torch.Tensor] = []
...@@ -79,9 +125,7 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor: ...@@ -79,9 +125,7 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
nnz: int = 0 nnz: int = 0
for tensor in tensors: for tensor in tensors:
row, col, value = tensor.coo() row, col, value = tensor.coo()
rows.append(row) rows.append(row)
cols.append(tensor.storage._col + sparse_sizes[1]) cols.append(tensor.storage._col + sparse_sizes[1])
if value is not None: if value is not None:
...@@ -89,9 +133,7 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor: ...@@ -89,9 +133,7 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
colptr = tensor.storage._colptr colptr = tensor.storage._colptr
if colptr is not None: if colptr is not None:
if len(colptrs) > 0: colptrs.append(colptr[1:] + nnz if len(colptrs) > 0 else colptr)
colptr = colptr[1:]
colptrs.append(colptr + nnz)
colcount = tensor.storage._colcount colcount = tensor.storage._colcount
if colcount is not None: if colcount is not None:
...@@ -102,7 +144,6 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor: ...@@ -102,7 +144,6 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
nnz += tensor.nnz() nnz += tensor.nnz()
row = torch.cat(rows, dim=0) row = torch.cat(rows, dim=0)
col = torch.cat(cols, dim=0) col = torch.cat(cols, dim=0)
value: Optional[torch.Tensor] = None value: Optional[torch.Tensor] = None
...@@ -118,28 +159,11 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor: ...@@ -118,28 +159,11 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
colcount = torch.cat(colcounts, dim=0) colcount = torch.cat(colcounts, dim=0)
storage = SparseStorage(row=row, rowptr=None, col=col, value=value, storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=None, sparse_sizes=(sparse_sizes[0], sparse_sizes[1]),
colptr=colptr, colcount=colcount, csr2csc=None, rowcount=None, colptr=colptr, colcount=colcount,
csc2csr=None, is_sorted=False) csr2csc=None, csc2csr=None, is_sorted=False)
return tensors[0].from_storage(storage) return tensors[0].from_storage(storage)
elif dim > 1 and dim < tensors[0].dim():
values: List[torch.Tensor] = []
for tensor in tensors:
value = tensor.storage.value()
if value is not None:
values.append(value)
value: Optional[torch.Tensor] = None
if len(values) == len(tensors):
value = torch.cat(values, dim=dim - 1)
return tensors[0].set_value(value, layout='coo')
else:
raise IndexError(
(f'Dimension out of range: Expected to be in range of '
f'[{-tensors[0].dim()}, {tensors[0].dim() - 1}], but got {dim}.'))
def cat_diag(tensors: List[SparseTensor]) -> SparseTensor: def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
assert len(tensors) > 0 assert len(tensors) > 0
...@@ -163,9 +187,7 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor: ...@@ -163,9 +187,7 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
rowptr = tensor.storage._rowptr rowptr = tensor.storage._rowptr
if rowptr is not None: if rowptr is not None:
if len(rowptrs) > 0: rowptrs.append(rowptr[1:] + nnz if len(rowptrs) > 0 else rowptr)
rowptr = rowptr[1:]
rowptrs.append(rowptr + nnz)
cols.append(tensor.storage._col + sparse_sizes[1]) cols.append(tensor.storage._col + sparse_sizes[1])
...@@ -179,9 +201,7 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor: ...@@ -179,9 +201,7 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
colptr = tensor.storage._colptr colptr = tensor.storage._colptr
if colptr is not None: if colptr is not None:
if len(colptrs) > 0: colptrs.append(colptr[1:] + nnz if len(colptrs) > 0 else colptr)
colptr = colptr[1:]
colptrs.append(colptr + nnz)
colcount = tensor.storage._colcount colcount = tensor.storage._colcount
if colcount is not None: if colcount is not None:
...@@ -234,7 +254,8 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor: ...@@ -234,7 +254,8 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
csc2csr = torch.cat(csc2csrs, dim=0) csc2csr = torch.cat(csc2csrs, dim=0)
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value, storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=rowcount, sparse_sizes=(sparse_sizes[0], sparse_sizes[1]),
colptr=colptr, colcount=colcount, csr2csc=csr2csc, rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True) csc2csr=csc2csr, is_sorted=True)
return tensors[0].from_storage(storage) return tensors[0].from_storage(storage)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment