Commit 64a8e2ce authored by Mario Geiger

view

parent 57852a66
from itertools import product

import pytest
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse import view

from .utils import dtypes, devices, tensor


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_view_matrix(dtype, device):
    row = torch.tensor([0, 1, 1], device=device)
    col = torch.tensor([1, 0, 2], device=device)
    index = torch.stack([row, col], dim=0)
    value = tensor([1, 2, 3], dtype, device)

    index, value = view(index, value, m=2, n=3, new_n=2)
    assert index.tolist() == [[0, 1, 2], [1, 1, 1]]
    assert value.tolist() == [1, 2, 3]


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_view_sparse_tensor(dtype, device):
    options = torch.tensor(0, dtype=dtype, device=device)

    mat = SparseTensor.eye(4, options=options).view(2, 8)
    assert mat.storage.sparse_sizes() == (2, 8)
    assert mat.storage.row().tolist() == [0, 0, 1, 1]
    assert mat.storage.col().tolist() == [0, 5, 2, 7]
    assert mat.storage.value().tolist() == [1, 1, 1, 1]
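
As a sanity check that is not part of this commit, the COO pattern asserted in test_view_sparse_tensor matches what a plain dense reshape of the identity matrix produces; the snippet below is only an illustrative sketch using stock PyTorch:

import torch

# Dense reference for the sparse test above: reshaping a 4x4 identity
# matrix to 2x8 places the four ones at (0, 0), (0, 5), (1, 2), (1, 7).
dense = torch.eye(4).view(2, 8)
row, col = dense.nonzero(as_tuple=True)
assert row.tolist() == [0, 0, 1, 1]
assert col.tolist() == [0, 5, 2, 7]
assert dense[row, col].tolist() == [1, 1, 1, 1]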
...
@@ -55,6 +55,7 @@ from .convert import to_torch_sparse, from_torch_sparse  # noqa
 from .convert import to_scipy, from_scipy  # noqa
 from .coalesce import coalesce  # noqa
 from .transpose import transpose  # noqa
+from .view import view  # noqa
 from .eye import eye  # noqa
 from .spmm import spmm  # noqa
 from .spspmm import spspmm  # noqa
@@ -101,6 +102,7 @@ __all__ = [
     'from_scipy',
     'coalesce',
     'transpose',
+    'view',
     'eye',
     'spmm',
     'spspmm',
...
import torch

from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor


def _view(src: SparseTensor, n: int, layout: str = 'csr') -> SparseTensor:
    row, col, value = src.coo()
    sparse_sizes = src.storage.sparse_sizes()

    # The total number of elements must be divisible by the new dimension.
    if sparse_sizes[0] * sparse_sizes[1] % n != 0:
        raise RuntimeError(
            f"shape '[-1, {n}]' is invalid for input of size "
            f"{sparse_sizes[0] * sparse_sizes[1]}")

    assert layout == 'csr' or layout == 'csc'

    if layout == 'csr':
        # Flatten (row, col) to linear indices in row-major order, then
        # unravel them for the new shape [-1, n].
        idx = sparse_sizes[1] * row + col
        row = idx // n
        col = idx % n
        sparse_sizes = (sparse_sizes[0] * sparse_sizes[1] // n, n)

    if layout == 'csc':
        # Flatten (row, col) to linear indices in column-major order, then
        # unravel them for the new shape [n, -1].
        idx = sparse_sizes[0] * col + row
        row = idx % n
        col = idx // n
        sparse_sizes = (n, sparse_sizes[0] * sparse_sizes[1] // n)

    # Build new storage with the remapped indices, passing through the
    # cached layout tensors from the source storage.
    storage = SparseStorage(
        row=row,
        rowptr=src.storage._rowptr,
        col=col,
        value=value,
        sparse_sizes=sparse_sizes,
        rowcount=src.storage._rowcount,
        colptr=src.storage._colptr,
        colcount=src.storage._colcount,
        csr2csc=src.storage._csr2csc,
        csc2csr=src.storage._csc2csr,
        is_sorted=True,
    )
    return src.from_storage(storage)


# `m` is implied by the total number of elements and the new `n`, so only
# `n` is forwarded to `_view`.
SparseTensor.view = lambda self, m, n: _view(self, n, layout='csr')

###############################################################################


def view(index, value, m, n, new_n):
    assert m * n % new_n == 0
    row, col = index

    # Row-major linear indices for the old shape, unraveled for [-1, new_n].
    idx = n * row + col
    row = idx // new_n
    col = idx % new_n

    return torch.stack([row, col], dim=0), value
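
For concreteness, here is a short usage sketch of the functional `view` (not part of the commit). It mirrors the values from test_view_matrix and assumes `view` is importable from `torch_sparse`, as the `__init__.py` change above exports it:

import torch

from torch_sparse import view

# Entries (0, 1), (1, 0), (1, 2) of a 2x3 matrix have row-major linear
# indices 3 * row + col = [1, 3, 5]; reshaping to [-1, 2] therefore gives
# rows idx // 2 = [0, 1, 2] and columns idx % 2 = [1, 1, 1].
index = torch.tensor([[0, 1, 1], [1, 0, 2]])
value = torch.tensor([1.0, 2.0, 3.0])

index, value = view(index, value, m=2, n=3, new_n=2)
assert index.tolist() == [[0, 1, 2], [1, 1, 1]]
assert value.tolist() == [1.0, 2.0, 3.0]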