"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "a54978bbf4dfc3259afe8eb4423d53b4b62dabd8"
Commit 0e2ddfad authored by rusty1s's avatar rusty1s
Browse files

added view to storage + rename

parent 4dec4df0
...@@ -122,3 +122,24 @@ def test_coalesce(dtype, device): ...@@ -122,3 +122,24 @@ def test_coalesce(dtype, device):
assert storage.row().tolist() == [0, 0, 1, 1] assert storage.row().tolist() == [0, 0, 1, 1]
assert storage.col().tolist() == [0, 1, 0, 1] assert storage.col().tolist() == [0, 1, 0, 1]
assert storage.value().tolist() == [1, 2, 3, 4] assert storage.value().tolist() == [1, 2, 3, 4]
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_sparse_reshape(dtype, device):
    """Reshape a 4x4 diagonal pattern through explicit and inferred (-1) dims."""
    row, col = tensor([[0, 1, 2, 3], [0, 1, 2, 3]], torch.long, device)
    storage = SparseStorage(row=row, col=col)

    # (args, expected sparse_sizes, expected rows, expected cols), applied
    # sequentially to the same storage.
    cases = [
        ((2, 8), (2, 8), [0, 0, 1, 1], [0, 5, 2, 7]),
        ((-1, 4), (4, 4), [0, 1, 2, 3], [0, 1, 2, 3]),
        ((2, -1), (2, 8), [0, 0, 1, 1], [0, 5, 2, 7]),
    ]
    for args, sizes, expected_row, expected_col in cases:
        storage = storage.sparse_reshape(*args)
        assert storage.sparse_sizes() == sizes
        assert storage.row().tolist() == expected_row
        assert storage.col().tolist() == expected_col
from itertools import product
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse import view
from .utils import dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_view_matrix(dtype, device):
    """Functional `view` on a raw COO (index, value) pair."""
    rows = torch.tensor([0, 1, 1], device=device)
    cols = torch.tensor([1, 0, 2], device=device)
    index = torch.stack([rows, cols], dim=0)
    value = tensor([1, 2, 3], dtype, device)

    out_index, out_value = view(index, value, m=2, n=3, new_n=2)

    # Reshaping 2x3 -> 3x2 in row-major order.
    assert out_index.tolist() == [[0, 1, 2], [1, 1, 1]]
    assert out_value.tolist() == [1, 2, 3]
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_view_sparse_tensor(dtype, device):
    """`SparseTensor.view` reshapes a 4x4 identity into a 2x8 matrix."""
    options = torch.tensor(0, dtype=dtype, device=device)
    out = SparseTensor.eye(4, options=options).view(2, 8)

    storage = out.storage
    assert storage.sparse_sizes() == (2, 8)
    assert storage.row().tolist() == [0, 0, 1, 1]
    assert storage.col().tolist() == [0, 5, 2, 7]
    assert storage.value().tolist() == [1, 1, 1, 1]
...@@ -55,7 +55,6 @@ from .convert import to_torch_sparse, from_torch_sparse # noqa ...@@ -55,7 +55,6 @@ from .convert import to_torch_sparse, from_torch_sparse # noqa
from .convert import to_scipy, from_scipy # noqa from .convert import to_scipy, from_scipy # noqa
from .coalesce import coalesce # noqa from .coalesce import coalesce # noqa
from .transpose import transpose # noqa from .transpose import transpose # noqa
from .view import view # noqa
from .eye import eye # noqa from .eye import eye # noqa
from .spmm import spmm # noqa from .spmm import spmm # noqa
from .spspmm import spspmm # noqa from .spspmm import spspmm # noqa
...@@ -102,7 +101,6 @@ __all__ = [ ...@@ -102,7 +101,6 @@ __all__ = [
'from_scipy', 'from_scipy',
'coalesce', 'coalesce',
'transpose', 'transpose',
'view',
'eye', 'eye',
'spmm', 'spmm',
'spspmm', 'spspmm',
......
...@@ -260,6 +260,31 @@ class SparseStorage(object): ...@@ -260,6 +260,31 @@ class SparseStorage(object):
colcount=colcount, csr2csc=self._csr2csc, colcount=colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True) csc2csr=self._csc2csr, is_sorted=True)
def sparse_reshape(self, num_rows: int, num_cols: int):
    """Reshape the sparse dimensions to ``(num_rows, num_cols)``.

    Either dimension may be ``-1``, in which case it is inferred from the
    total number of entries. Non-zero positions are reinterpreted by their
    flattened row-major index, so the row-major order of entries (and thus
    CSR sortedness) is preserved.
    """
    assert num_rows > 0 or num_rows == -1
    assert num_cols > 0 or num_cols == -1
    assert num_rows > 0 or num_cols > 0

    total = self.sparse_size(0) * self.sparse_size(1)
    if num_rows == -1:
        num_rows = total // num_cols
    if num_cols == -1:
        num_cols = total // num_rows
    assert num_rows * num_cols == total

    # Flattened row-major position of each non-zero entry.
    idx = self.sparse_size(1) * self.row() + self.col()
    # Floor division keeps `row` an integer tensor; `idx / num_cols` is true
    # division on modern PyTorch and would yield a float tensor.
    row = idx // num_cols
    col = idx % num_cols

    # Cached CSR/CSC layouts are invalidated by the reshape, so pass None.
    return SparseStorage(row=row, rowptr=None, col=col, value=self._value,
                         sparse_sizes=(num_rows, num_cols), rowcount=None,
                         colptr=None, colcount=None, csr2csc=None,
                         csc2csr=None, is_sorted=True)
def has_rowcount(self) -> bool: def has_rowcount(self) -> bool:
return self._rowcount is not None return self._rowcount is not None
......
...@@ -171,6 +171,10 @@ class SparseTensor(object): ...@@ -171,6 +171,10 @@ class SparseTensor(object):
def sparse_resize(self, sparse_sizes: Tuple[int, int]): def sparse_resize(self, sparse_sizes: Tuple[int, int]):
return self.from_storage(self.storage.sparse_resize(sparse_sizes)) return self.from_storage(self.storage.sparse_resize(sparse_sizes))
def sparse_reshape(self, num_rows: int, num_cols: int):
    # Delegate to the storage-level implementation and re-wrap the result.
    storage = self.storage.sparse_reshape(num_rows, num_cols)
    return self.from_storage(storage)
def is_coalesced(self) -> bool: def is_coalesced(self) -> bool:
return self.storage.is_coalesced() return self.storage.is_coalesced()
......
import torch
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
def _view(src: SparseTensor, n: int, layout: str = 'csr') -> SparseTensor:
    """Reshape `src` so that it has `n` columns ('csr') or `n` rows ('csc').

    Entries are reinterpreted by their flattened index (row-major for 'csr',
    column-major for 'csc'), which keeps the corresponding sort order intact.
    Raises `RuntimeError` when the total number of elements is not divisible
    by `n`.
    """
    row, col, value = src.coo()
    num_rows, num_cols = src.storage.sparse_sizes()
    numel = num_rows * num_cols

    if numel % n != 0:
        raise RuntimeError(
            f"shape '[-1, {n}]' is invalid for input of size "
            f"{numel}")

    assert layout == 'csr' or layout == 'csc'
    if layout == 'csr':
        # Row-major flattening, then split into new (row, col) pairs.
        flat = num_cols * row + col
        row, col = flat // n, flat % n
        new_sizes = (numel // n, n)
    else:
        # Column-major flattening for the CSC layout.
        flat = num_rows * col + row
        row, col = flat % n, flat // n
        new_sizes = (n, numel // n)

    # NOTE(review): the cached _csr2csc/_csc2csr permutations are carried over
    # unchanged across the reshape — confirm they remain valid for the new
    # shape.
    storage = SparseStorage(
        row=row,
        col=col,
        value=value,
        sparse_sizes=new_sizes,
        csr2csc=src.storage._csr2csc,
        csc2csr=src.storage._csc2csr,
        is_sorted=True,
    )
    return src.from_storage(storage)
SparseTensor.view = lambda self, m, n: _view(self, n, layout='csr')
###############################################################################
def view(index, value, m, n, new_n):
    """Reshape an ``m x n`` COO matrix to have ``new_n`` columns.

    Args:
        index (LongTensor): Stacked ``[row, col]`` indices of shape ``[2, nnz]``.
        value (Tensor): Non-zero values, returned unchanged.
        m (int): Current number of rows.
        n (int): Current number of columns.
        new_n (int): Desired number of columns; must divide ``m * n``.

    Returns:
        (LongTensor, Tensor): New stacked indices and the unchanged values.

    Raises:
        RuntimeError: If ``m * n`` is not divisible by ``new_n``.
    """
    # Raise instead of `assert`: asserts are stripped under `-O`, and this
    # matches the error style used elsewhere in the package.
    if (m * n) % new_n != 0:
        raise RuntimeError(
            f"shape '[-1, {new_n}]' is invalid for input of size {m * n}")

    row, col = index
    # Flattened row-major position of each entry, re-split by `new_n` columns.
    idx = n * row + col
    row = idx // new_n
    col = idx % new_n
    return torch.stack([row, col], dim=0), value
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment