Commit 6e9cd9d6 authored by rusty1s's avatar rusty1s
Browse files

sparse size to tuple

parent e6f5c3f0
......@@ -93,8 +93,8 @@ def test_utility(dtype, device):
storage = storage.set_value(value, layout='coo')
assert storage.value().tolist() == [1, 2, 3, 4]
storage = storage.sparse_resize([3, 3])
assert storage.sparse_sizes() == [3, 3]
storage = storage.sparse_resize((3, 3))
assert storage.sparse_sizes() == (3, 3)
new_storage = storage.copy()
assert new_storage != storage
......
......@@ -20,6 +20,6 @@ def coalesce(index, value, m, n, op="add"):
"""
storage = SparseStorage(row=index[0], col=index[1], value=value,
sparse_sizes=torch.Size([m, n]), is_sorted=False)
sparse_sizes=(m, n), is_sorted=False)
storage = storage.coalesce(reduce=op)
return torch.stack([storage.row(), storage.col()], dim=0), storage.value()
......@@ -5,7 +5,7 @@ from torch import from_numpy
def to_torch_sparse(index, value, m, n):
return torch.sparse_coo_tensor(index.detach(), value, torch.Size([m, n]))
return torch.sparse_coo_tensor(index.detach(), value, (m, n))
def from_torch_sparse(A):
......
......@@ -31,7 +31,7 @@ def index_select(src: SparseTensor, dim: int,
if value is not None:
value = value[perm]
sparse_sizes = torch.Size([idx.size(0), src.sparse_size(1)])
sparse_sizes = (idx.size(0), src.sparse_size(1))
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=rowcount,
......@@ -61,7 +61,7 @@ def index_select(src: SparseTensor, dim: int,
if value is not None:
value = value[perm][csc2csr]
sparse_sizes = torch.Size([src.sparse_size(0), idx.size(0)])
sparse_sizes = (src.sparse_size(0), idx.size(0))
storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=None,
......
......@@ -27,7 +27,7 @@ def masked_select(src: SparseTensor, dim: int,
if value is not None:
value = value[mask]
sparse_sizes = torch.Size([rowcount.size(0), src.sparse_size(1)])
sparse_sizes = (rowcount.size(0), src.sparse_size(1))
storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=rowcount,
......@@ -54,7 +54,7 @@ def masked_select(src: SparseTensor, dim: int,
if value is not None:
value = value[csr2csc][mask][csc2csr]
sparse_sizes = torch.Size([src.sparse_size(0), colcount.size(0)])
sparse_sizes = (src.sparse_size(0), colcount.size(0))
storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=None,
......
......@@ -82,7 +82,7 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum(
rowptrA, colA, valueA, rowptrB, colB, valueB, K)
return SparseTensor(row=None, rowptr=rowptrC, col=colC, value=valueC,
sparse_sizes=torch.Size([M, K]), is_sorted=True)
sparse_sizes=(M, K), is_sorted=True)
def spspmm_add(src: SparseTensor, other: SparseTensor) -> SparseTensor:
......
......@@ -30,7 +30,7 @@ def narrow(src: SparseTensor, dim: int, start: int,
if value is not None:
value = value.narrow(0, row_start, row_length)
sparse_sizes = torch.Size([length, src.sparse_size(1)])
sparse_sizes = (length, src.sparse_size(1))
rowcount = src.storage._rowcount
if rowcount is not None:
......@@ -53,7 +53,7 @@ def narrow(src: SparseTensor, dim: int, start: int,
if value is not None:
value = value[mask]
sparse_sizes = torch.Size([src.sparse_size(0), length])
sparse_sizes = (src.sparse_size(0), length)
colptr = src.storage._colptr
if colptr is not None:
......
......@@ -23,9 +23,9 @@ def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
"""
A = SparseTensor(row=indexA[0], col=indexA[1], value=valueA,
sparse_sizes=torch.Size([m, k]), is_sorted=not coalesced)
sparse_sizes=(m, k), is_sorted=not coalesced)
B = SparseTensor(row=indexB[0], col=indexB[1], value=valueB,
sparse_sizes=torch.Size([k, n]), is_sorted=not coalesced)
sparse_sizes=(k, n), is_sorted=not coalesced)
C = matmul(A, B)
row, col, value = C.coo()
......
import warnings
from typing import Optional, List
from typing import Optional, List, Tuple
import torch
from torch_scatter import segment_csr, scatter_add
......@@ -23,7 +23,7 @@ class SparseStorage(object):
_rowptr: Optional[torch.Tensor]
_col: torch.Tensor
_value: Optional[torch.Tensor]
_sparse_sizes: List[int]
_sparse_sizes: Tuple[int, int]
_rowcount: Optional[torch.Tensor]
_colptr: Optional[torch.Tensor]
_colcount: Optional[torch.Tensor]
......@@ -34,7 +34,7 @@ class SparseStorage(object):
rowptr: Optional[torch.Tensor] = None,
col: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
sparse_sizes: Optional[List[int]] = None,
sparse_sizes: Optional[Tuple[int, int]] = None,
rowcount: Optional[torch.Tensor] = None,
colptr: Optional[torch.Tensor] = None,
colcount: Optional[torch.Tensor] = None,
......@@ -56,7 +56,7 @@ class SparseStorage(object):
else:
raise ValueError
N = col.max().item() + 1
sparse_sizes = torch.Size([int(M), int(N)])
sparse_sizes = (int(M), int(N))
else:
assert len(sparse_sizes) == 2
......@@ -118,7 +118,7 @@ class SparseStorage(object):
self._rowptr = rowptr
self._col = col
self._value = value
self._sparse_sizes = sparse_sizes
self._sparse_sizes = tuple(sparse_sizes)
self._rowcount = rowcount
self._colptr = colptr
self._colcount = colcount
......@@ -218,13 +218,13 @@ class SparseStorage(object):
colcount=self._colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True)
def sparse_sizes(self) -> List[int]:
def sparse_sizes(self) -> Tuple[int, int]:
return self._sparse_sizes
def sparse_size(self, dim: int) -> int:
return self._sparse_sizes[dim]
def sparse_resize(self, sparse_sizes: List[int]):
def sparse_resize(self, sparse_sizes: Tuple[int, int]):
assert len(sparse_sizes) == 2
old_sparse_sizes, nnz = self._sparse_sizes, self._col.numel()
......
......@@ -16,7 +16,8 @@ class SparseTensor(object):
rowptr: Optional[torch.Tensor] = None,
col: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
sparse_sizes: List[int] = None, is_sorted: bool = False):
sparse_sizes: Optional[Tuple[int, int]] = None,
is_sorted: bool = False):
self.storage = SparseStorage(row=row, rowptr=rowptr, col=col,
value=value, sparse_sizes=sparse_sizes,
rowcount=None, colptr=None, colcount=None,
......@@ -45,7 +46,8 @@ class SparseTensor(object):
value = mat[row, col]
return SparseTensor(row=row, rowptr=None, col=col, value=value,
sparse_sizes=mat.size()[:2], is_sorted=True)
sparse_sizes=(mat.size(0), mat.size(1)),
is_sorted=True)
@classmethod
def from_torch_sparse_coo_tensor(self, mat: torch.Tensor,
......@@ -59,7 +61,8 @@ class SparseTensor(object):
value = mat._values()
return SparseTensor(row=row, rowptr=None, col=col, value=value,
sparse_sizes=mat.size()[:2], is_sorted=True)
sparse_sizes=(mat.size(0), mat.size(1)),
is_sorted=True)
@classmethod
def eye(self, M: int, N: Optional[int] = None,
......@@ -105,10 +108,9 @@ class SparseTensor(object):
csr2csc = csc2csr = row
storage: SparseStorage = SparseStorage(
row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=torch.Size([M, N]), rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc, csc2csr=csc2csr,
is_sorted=True)
row=row, rowptr=rowptr, col=col, value=value, sparse_sizes=(M, N),
rowcount=rowcount, colptr=colptr, colcount=colcount,
csr2csc=csr2csc, csc2csr=csc2csr, is_sorted=True)
self = SparseTensor.__new__(SparseTensor)
self.storage = storage
......@@ -160,13 +162,13 @@ class SparseTensor(object):
layout: Optional[str] = None):
return self.from_storage(self.storage.set_value(value, layout))
def sparse_sizes(self) -> List[int]:
def sparse_sizes(self) -> Tuple[int, int]:
return self.storage.sparse_sizes()
def sparse_size(self, dim: int) -> int:
return self.storage.sparse_sizes()[dim]
def sparse_resize(self, sparse_sizes: List[int]):
def sparse_resize(self, sparse_sizes: Tuple[int, int]):
return self.from_storage(self.storage.sparse_resize(sparse_sizes))
def is_coalesced(self) -> bool:
......@@ -206,11 +208,12 @@ class SparseTensor(object):
return self.set_value(value, layout='coo')
def sizes(self) -> List[int]:
sizes = self.sparse_sizes()
sparse_sizes = self.sparse_sizes()
value = self.storage.value()
if value is not None:
sizes = list(sizes) + list(value.size())[1:]
return sizes
return list(sparse_sizes) + list(value.size())[1:]
else:
return list(sparse_sizes)
def size(self, dim: int) -> int:
return self.sizes()[dim]
......@@ -268,7 +271,7 @@ class SparseTensor(object):
N = max(self.size(0), self.size(1))
out = SparseTensor(row=row, rowptr=None, col=col, value=value,
sparse_sizes=torch.Size([N, N]), is_sorted=False)
sparse_sizes=(N, N), is_sorted=False)
out = out.coalesce(reduce)
return out
......
......@@ -19,7 +19,7 @@ def t(src: SparseTensor) -> SparseTensor:
rowptr=src.storage._colptr,
col=row[csr2csc],
value=value,
sparse_sizes=torch.Size([sparse_sizes[1], sparse_sizes[0]]),
sparse_sizes=(sparse_sizes[1], sparse_sizes[0]),
rowcount=src.storage._colcount,
colptr=src.storage._rowptr,
colcount=src.storage._rowcount,
......@@ -53,7 +53,7 @@ def transpose(index, value, m, n, coalesced=True):
row, col = col, row
if coalesced:
sparse_sizes = torch.Size([n, m])
sparse_sizes = (n, m)
storage = SparseStorage(row=row, col=col, value=value,
sparse_sizes=sparse_sizes, is_sorted=False)
storage = storage.coalesce()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment