Unverified Commit b1853b60 authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Merge pull request #61 from mariogeiger/view

view
parents 93540a36 0e2ddfad
...@@ -122,3 +122,24 @@ def test_coalesce(dtype, device): ...@@ -122,3 +122,24 @@ def test_coalesce(dtype, device):
assert storage.row().tolist() == [0, 0, 1, 1] assert storage.row().tolist() == [0, 0, 1, 1]
assert storage.col().tolist() == [0, 1, 0, 1] assert storage.col().tolist() == [0, 1, 0, 1]
assert storage.value().tolist() == [1, 2, 3, 4] assert storage.value().tolist() == [1, 2, 3, 4]
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_sparse_reshape(dtype, device):
row, col = tensor([[0, 1, 2, 3], [0, 1, 2, 3]], torch.long, device)
storage = SparseStorage(row=row, col=col)
storage = storage.sparse_reshape(2, 8)
assert storage.sparse_sizes() == (2, 8)
assert storage.row().tolist() == [0, 0, 1, 1]
assert storage.col().tolist() == [0, 5, 2, 7]
storage = storage.sparse_reshape(-1, 4)
assert storage.sparse_sizes() == (4, 4)
assert storage.row().tolist() == [0, 1, 2, 3]
assert storage.col().tolist() == [0, 1, 2, 3]
storage = storage.sparse_reshape(2, -1)
assert storage.sparse_sizes() == (2, 8)
assert storage.row().tolist() == [0, 0, 1, 1]
assert storage.col().tolist() == [0, 5, 2, 7]
...@@ -260,6 +260,31 @@ class SparseStorage(object): ...@@ -260,6 +260,31 @@ class SparseStorage(object):
colcount=colcount, csr2csc=self._csr2csc, colcount=colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True) csc2csr=self._csc2csr, is_sorted=True)
def sparse_reshape(self, num_rows: int, num_cols: int):
assert num_rows > 0 or num_rows == -1
assert num_cols > 0 or num_cols == -1
assert num_rows > 0 or num_cols > 0
total = self.sparse_size(0) * self.sparse_size(1)
if num_rows == -1:
num_rows = total // num_cols
if num_cols == -1:
num_cols = total // num_rows
assert num_rows * num_cols == total
idx = self.sparse_size(1) * self.row() + self.col()
row = idx / num_cols
col = idx % num_cols
return SparseStorage(row=row, rowptr=None, col=col, value=self._value,
sparse_sizes=(num_rows, num_cols), rowcount=None,
colptr=None, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True)
def has_rowcount(self) -> bool: def has_rowcount(self) -> bool:
return self._rowcount is not None return self._rowcount is not None
......
...@@ -171,6 +171,10 @@ class SparseTensor(object): ...@@ -171,6 +171,10 @@ class SparseTensor(object):
def sparse_resize(self, sparse_sizes: Tuple[int, int]): def sparse_resize(self, sparse_sizes: Tuple[int, int]):
return self.from_storage(self.storage.sparse_resize(sparse_sizes)) return self.from_storage(self.storage.sparse_resize(sparse_sizes))
def sparse_reshape(self, num_rows: int, num_cols: int):
return self.from_storage(
self.storage.sparse_reshape(num_rows, num_cols))
def is_coalesced(self) -> bool: def is_coalesced(self) -> bool:
return self.storage.is_coalesced() return self.storage.is_coalesced()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment