Commit f87afd09 authored by rusty1s

tensor and storage mostly jittable

parent 631eee37
@@ -30,9 +30,10 @@ class MyCell(torch.nn.Module):
self.linear = torch.nn.Linear(2, 4)
# def forward(self, x: torch.Tensor, ptr: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, adj: SparseStorage) -> torch.Tensor:
out, _ = torch.ops.torch_sparse_cpu.spmm(adj.rowptr(), adj.col(), None,
x, 'sum')
def forward(self, x: torch.Tensor, adj: SparseTensor) -> torch.Tensor:
out, _ = torch.ops.torch_sparse_cpu.spmm(adj.storage.rowptr(),
adj.storage.col(), None, x,
'sum')
return out
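A minimal sketch of exercising the scripted module above, assuming the compiled torch_sparse extension is built and SparseTensor is importable from torch_sparse.tensor:

import torch
from torch_sparse.tensor import SparseTensor

rowptr = torch.tensor([0, 3, 6, 9])
col = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2])
adj = SparseTensor(rowptr=rowptr, col=col)  # 3 x 3, all-ones sparsity pattern

x = torch.randn(3, 2)
cell = torch.jit.script(MyCell())  # forward is annotated, so scripting can type-check it
out = cell(x, adj)                 # spmm with reduce='sum' aggregates neighbor features
assert out.size() == (3, 2)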
@@ -67,7 +68,10 @@ def test_jit():
rowptr = torch.tensor([0, 3, 6, 9])
col = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2])
adj = SparseStorage(rowptr=rowptr, col=col)
adj = SparseTensor(rowptr=rowptr, col=col)
scipy = adj.to_scipy(layout='csr')
mat = SparseTensor.from_scipy(scipy)
mat.fill_value_(2.3)
# adj = {'rowptr': mat.storage.rowptr, 'col': mat.storage.col}
# foo = Foo(mat.storage.rowptr, mat.storage.col)
......
import warnings
from typing import Optional, List, Dict, Union, Any
from typing import Optional, List
import torch
from torch_scatter import segment_csr, scatter_add
from torch_sparse.utils import Final, is_scalar
# __cache__ = {'enabled': True}
# def is_cache_enabled():
# return __cache__['enabled']
# def set_cache_enabled(mode):
# __cache__['enabled'] = mode
# class no_cache(object):
# def __enter__(self):
# self.prev = is_cache_enabled()
# set_cache_enabled(False)
# def __exit__(self, *args):
# set_cache_enabled(self.prev)
# return False
# def __call__(self, func):
# def decorate_no_cache(*args, **kwargs):
# with self:
# return func(*args, **kwargs)
# return decorate_no_cache
def optional(func, src):
return func(src) if src is not None else src
from torch_sparse.utils import Final
layouts: Final[List[str]] = ['coo', 'csr', 'csc']
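To illustrate the optional helper just defined: it applies func only when src is not None, which lets the storage code normalize optional tensors without repeating None checks at every call site. A small sketch:

row = optional(lambda t: t.contiguous(), None)                     # stays None
col = optional(lambda t: t.contiguous(), torch.tensor([2, 0, 1]))  # made contiguous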
@@ -52,7 +23,7 @@ class SparseStorage(object):
_rowptr: Optional[torch.Tensor]
_col: torch.Tensor
_value: Optional[torch.Tensor]
_sparse_size: List[int]
_sparse_sizes: List[int]
_rowcount: Optional[torch.Tensor]
_colptr: Optional[torch.Tensor]
_colcount: Optional[torch.Tensor]
@@ -63,7 +34,7 @@ class SparseStorage(object):
rowptr: Optional[torch.Tensor] = None,
col: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
sparse_size: Optional[List[int]] = None,
sparse_sizes: Optional[List[int]] = None,
rowcount: Optional[torch.Tensor] = None,
colptr: Optional[torch.Tensor] = None,
colcount: Optional[torch.Tensor] = None,
@@ -77,7 +48,7 @@ class SparseStorage(object):
assert col.dim() == 1
col = col.contiguous()
if sparse_size is None:
if sparse_sizes is None:
if rowptr is not None:
M = rowptr.numel() - 1
elif row is not None:
@@ -85,9 +56,9 @@ class SparseStorage(object):
else:
raise ValueError
N = col.max().item() + 1
sparse_size = torch.Size([int(M), int(N)])
sparse_sizes = torch.Size([int(M), int(N)])
else:
assert len(sparse_size) == 2
assert len(sparse_sizes) == 2
if row is not None:
assert row.dtype == torch.long
@@ -100,7 +71,7 @@ class SparseStorage(object):
assert rowptr.dtype == torch.long
assert rowptr.device == col.device
assert rowptr.dim() == 1
assert rowptr.numel() - 1 == sparse_size[0]
assert rowptr.numel() - 1 == sparse_sizes[0]
rowptr = rowptr.contiguous()
if value is not None:
@@ -112,21 +83,21 @@ class SparseStorage(object):
assert rowcount.dtype == torch.long
assert rowcount.device == col.device
assert rowcount.dim() == 1
assert rowcount.numel() == sparse_size[0]
assert rowcount.numel() == sparse_sizes[0]
rowcount = rowcount.contiguous()
if colptr is not None:
assert colptr.dtype == torch.long
assert colptr.device == col.device
assert colptr.dim() == 1
assert colptr.numel() - 1 == sparse_size[1]
assert colptr.numel() - 1 == sparse_sizes[1]
colptr = colptr.contiguous()
if colcount is not None:
assert colcount.dtype == torch.long
assert colcount.device == col.device
assert colcount.dim() == 1
assert colcount.numel() == sparse_size[1]
assert colcount.numel() == sparse_sizes[1]
colcount = colcount.contiguous()
if csr2csc is not None:
@@ -147,7 +118,7 @@ class SparseStorage(object):
self._rowptr = rowptr
self._col = col
self._value = value
self._sparse_size = sparse_size
self._sparse_sizes = sparse_sizes
self._rowcount = rowcount
self._colptr = colptr
self._colcount = colcount
@@ -156,7 +127,7 @@ class SparseStorage(object):
if not is_sorted:
idx = col.new_zeros(col.numel() + 1)
idx[1:] = sparse_size[1] * self.row() + col
idx[1:] = sparse_sizes[1] * self.row() + col
if (idx[1:] < idx[:-1]).any():
perm = idx[1:].argsort()
self._row = self.row()[perm]
@@ -203,10 +174,10 @@ class SparseStorage(object):
if row is not None:
if row.is_cuda:
rowptr = torch.ops.torch_sparse_cuda.ind2ptr(
row, self._sparse_size[0])
row, self._sparse_sizes[0])
else:
rowptr = torch.ops.torch_sparse_cpu.ind2ptr(
row, self._sparse_size[0])
row, self._sparse_sizes[0])
self._rowptr = rowptr
return rowptr
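For reference, ind2ptr turns sorted row indices into a CSR row pointer of length M + 1; a pure-PyTorch sketch of the same mapping (the extension implements this as a native op):

def ind2ptr_reference(row, M):
    # rowptr[i + 1] - rowptr[i] equals the number of entries in row i
    rowptr = torch.zeros(M + 1, dtype=torch.long)
    rowptr[1:] = torch.bincount(row, minlength=M).cumsum(0)
    return rowptr

row = torch.tensor([0, 0, 1, 2, 2, 2])
assert ind2ptr_reference(row, 3).tolist() == [0, 2, 3, 6]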
@@ -243,27 +214,22 @@ class SparseStorage(object):
assert value.size(0) == self._col.numel()
return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
value=value, sparse_size=self._sparse_size,
value=value, sparse_sizes=self._sparse_sizes,
rowcount=self._rowcount, colptr=self._colptr,
colcount=self._colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True)
def fill_value_(self, fill_value: float, dtype: Optional[torch.dtype] = None):
value = torch.empty(self._col.numel(), dtype=dtype, device=self._col.device)
return self.set_value_(value.fill_(fill_value), layout='csr')
def sparse_sizes(self) -> List[int]:
return self._sparse_sizes
def fill_value(self, fill_value: float, dtype: Optional[torch.dtype] = None):
value = torch.empty(self._col.numel(), dtype=dtype, device=self._col.device)
return self.set_value(value.fill_(fill_value), layout='csr')
def sparse_size(self, dim: int) -> int:
return self._sparse_sizes[dim]
def sparse_size(self) -> List[int]:
return self._sparse_size
def sparse_resize(self, sparse_sizes: List[int]):
assert len(sparse_sizes) == 2
old_sparse_sizes, nnz = self._sparse_sizes, self._col.numel()
def sparse_resize(self, sparse_size: List[int]):
assert len(sparse_size) == 2
old_sparse_size, nnz = self._sparse_size, self._col.numel()
diff_0 = sparse_size[0] - old_sparse_size[0]
diff_0 = sparse_sizes[0] - old_sparse_sizes[0]
rowcount, rowptr = self._rowcount, self._rowptr
if diff_0 > 0:
if rowptr is not None:
@@ -276,7 +242,7 @@ class SparseStorage(object):
if rowcount is not None:
rowcount = rowcount[:-diff_0]
diff_1 = sparse_size[1] - old_sparse_size[1]
diff_1 = sparse_sizes[1] - old_sparse_sizes[1]
colcount, colptr = self._colcount, self._colptr
if diff_1 > 0:
if colptr is not None:
@@ -290,7 +256,7 @@ class SparseStorage(object):
colcount = colcount[:-diff_1]
return SparseStorage(row=self._row, rowptr=rowptr, col=self._col,
value=self._value, sparse_size=sparse_size,
value=self._value, sparse_sizes=sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True)
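A sketch of the resize semantics, assuming the constructor from this diff: growing appends empty rows or columns (nnz is unchanged), shrinking trims the cached pointers and counts:

storage = SparseStorage(row=torch.tensor([0, 1]), col=torch.tensor([1, 0]),
                        sparse_sizes=[2, 2], is_sorted=True)
bigger = storage.sparse_resize([4, 2])   # two empty rows appended
assert bigger.sparse_sizes() == [4, 2]
assert bigger.col().numel() == 2         # nnz unchanged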
@@ -319,9 +285,9 @@ class SparseStorage(object):
csr2csc = self._csr2csc
if csr2csc is not None:
colptr = torch.ops.torch_sparse_cpu.ind2ptr(
self._col[csr2csc], self._sparse_size[1])
self._col[csr2csc], self._sparse_sizes[1])
else:
colptr = self._col.new_zeros(self._sparse_size[1] + 1)
colptr = self._col.new_zeros(self._sparse_sizes[1] + 1)
torch.cumsum(self.colcount(), dim=0, out=colptr[1:])
self._colptr = colptr
return colptr
@@ -340,7 +306,7 @@ class SparseStorage(object):
else:
raise NotImplementedError
# colcount = scatter_add(torch.ones_like(self._col), self._col,
# dim_size=self._sparse_size[1])
# dim_size=self._sparse_sizes[1])
self._colcount = colcount
return colcount
@@ -352,7 +318,7 @@ class SparseStorage(object):
if csr2csc is not None:
return csr2csc
idx = self._sparse_size[0] * self._col + self.row()
idx = self._sparse_sizes[0] * self._col + self.row()
csr2csc = idx.argsort()
self._csr2csc = csr2csc
return csr2csc
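To make the permutation concrete: csr2csc sorts entries by the column-major key M * col + row, mapping CSR (row-major) positions to CSC (column-major) positions. Sketch:

row = torch.tensor([0, 1, 1])   # entries (0, 1), (1, 0), (1, 1) in row-major order
col = torch.tensor([1, 0, 1])
perm = (2 * col + row).argsort()   # M = 2
assert perm.tolist() == [1, 0, 2]  # (1, 0), (0, 1), (1, 1) in column-major order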
@@ -371,12 +337,12 @@ class SparseStorage(object):
def is_coalesced(self) -> bool:
idx = self._col.new_full((self._col.numel() + 1, ), -1)
idx[1:] = self._sparse_size[1] * self.row() + self._col
idx[1:] = self._sparse_sizes[1] * self.row() + self._col
return bool((idx[1:] > idx[:-1]).all())
def coalesce(self, reduce: str = "add"):
idx = self._col.new_full((self._col.numel() + 1, ), -1)
idx[1:] = self._sparse_size[1] * self.row() + self._col
idx[1:] = self._sparse_sizes[1] * self.row() + self._col
mask = idx[1:] > idx[:-1]
if mask.all(): # Skip if indices are already coalesced.
@@ -394,7 +360,7 @@ class SparseStorage(object):
value = value[0] if isinstance(value, tuple) else value
return SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_size=self._sparse_size, rowcount=None,
sparse_sizes=self._sparse_sizes, rowcount=None,
colptr=None, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True)
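A rough equivalent of coalescing with reduce='add', using unique/scatter instead of the library's kernels (a sketch, assuming torch_scatter's scatter_add is available):

row = torch.tensor([0, 0, 1])
col = torch.tensor([1, 1, 0])
value = torch.tensor([1.0, 2.0, 5.0])
idx = 2 * row + col                          # keys [1, 1, 2]; (0, 1) appears twice
uniq, inv = idx.unique(return_inverse=True)
merged = scatter_add(value, inv, dim=0)      # tensor([3., 5.])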
@@ -418,7 +384,8 @@ class SparseStorage(object):
def copy(self):
return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
value=self._value, sparse_size=self._sparse_size,
value=self._value,
sparse_sizes=self._sparse_sizes,
rowcount=self._rowcount, colptr=self._colptr,
colcount=self._colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True)
@@ -430,6 +397,7 @@ class SparseStorage(object):
rowptr = self._rowptr
if rowptr is not None:
rowptr = rowptr.clone()
col = self._col.clone()
value = self._value
if value is not None:
value = value.clone()
@@ -448,8 +416,178 @@ class SparseStorage(object):
csc2csr = self._csc2csr
if csc2csr is not None:
csc2csr = csc2csr.clone()
return SparseStorage(row=row, rowptr=rowptr, col=self._col.clone(),
value=value, sparse_size=self._sparse_size,
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True)
def type_as(self, tensor: torch.Tensor):
value = self._value
if value is not None:
if tensor.dtype == value.dtype:
return self
else:
return self.set_value(value.type_as(tensor), layout='coo')
else:
return self
def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
if tensor.device == self._col.device:
return self
row = self._row
if row is not None:
row = row.to(tensor.device, non_blocking=non_blocking)
rowptr = self._rowptr
if rowptr is not None:
rowptr = rowptr.to(tensor.device, non_blocking=non_blocking)
col = self._col.to(tensor.device, non_blocking=non_blocking)
value = self._value
if value is not None:
value = value.to(tensor.device, non_blocking=non_blocking)
rowcount = self._rowcount
if rowcount is not None:
rowcount = rowcount.to(tensor.device, non_blocking=non_blocking)
colptr = self._colptr
if colptr is not None:
colptr = colptr.to(tensor.device, non_blocking=non_blocking)
colcount = self._colcount
if colcount is not None:
colcount = colcount.to(tensor.device, non_blocking=non_blocking)
csr2csc = self._csr2csc
if csr2csc is not None:
csr2csc = csr2csc.to(tensor.device, non_blocking=non_blocking)
csc2csr = self._csc2csr
if csc2csr is not None:
csc2csr = csc2csr.to(tensor.device, non_blocking=non_blocking)
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True)
def pin_memory(self):
row = self._row
if row is not None:
row = row.pin_memory()
rowptr = self._rowptr
if rowptr is not None:
rowptr = rowptr.pin_memory()
col = self._col.pin_memory()
value = self._value
if value is not None:
value = value.pin_memory()
rowcount = self._rowcount
if rowcount is not None:
rowcount = rowcount.pin_memory()
colptr = self._colptr
if colptr is not None:
colptr = colptr.pin_memory()
colcount = self._colcount
if colcount is not None:
colcount = colcount.pin_memory()
csr2csc = self._csr2csc
if csr2csc is not None:
csr2csc = csr2csc.pin_memory()
csc2csr = self._csc2csr
if csc2csr is not None:
csc2csr = csc2csr.pin_memory()
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True)
def is_pinned(self) -> bool:
is_pinned = True
row = self._row
if row is not None:
is_pinned = is_pinned and row.is_pinned()
rowptr = self._rowptr
if rowptr is not None:
is_pinned = is_pinned and rowptr.is_pinned()
is_pinned = is_pinned and self._col.is_pinned()
value = self._value
if value is not None:
is_pinned = is_pinned and value.is_pinned()
rowcount = self._rowcount
if rowcount is not None:
is_pinned = is_pinned and rowcount.is_pinned()
colptr = self._colptr
if colptr is not None:
is_pinned = is_pinned and colptr.is_pinned()
colcount = self._colcount
if colcount is not None:
is_pinned = is_pinned and colcount.is_pinned()
csr2csc = self._csr2csc
if csr2csc is not None:
is_pinned = is_pinned and csr2csc.is_pinned()
csc2csr = self._csc2csr
if csc2csr is not None:
is_pinned = is_pinned and csc2csr.is_pinned()
return is_pinned
@torch.jit.ignore
def share_memory_(self) -> SparseStorage:
row = self._row
if row is not None:
row.share_memory_()
rowptr = self._rowptr
if rowptr is not None:
rowptr.share_memory_()
self._col.share_memory_()
value = self._value
if value is not None:
value.share_memory_()
rowcount = self._rowcount
if rowcount is not None:
rowcount.share_memory_()
colptr = self._colptr
if colptr is not None:
colptr.share_memory_()
colcount = self._colcount
if colcount is not None:
colcount.share_memory_()
csr2csc = self._csr2csc
if csr2csc is not None:
csr2csc.share_memory_()
csc2csr = self._csc2csr
if csc2csr is not None:
csc2csr.share_memory_()
@torch.jit.ignore
def is_shared(self) -> bool:
is_shared = True
row = self._row
if row is not None:
is_shared = is_shared and row.is_shared()
rowptr = self._rowptr
if rowptr is not None:
is_shared = is_shared and rowptr.is_shared()
is_shared = is_shared and self._col.is_shared()
value = self._value
if value is not None:
is_shared = is_shared and value.is_shared()
rowcount = self._rowcount
if rowcount is not None:
is_shared = is_shared and rowcount.is_shared()
colptr = self._colptr
if colptr is not None:
is_shared = is_shared and colptr.is_shared()
colcount = self._colcount
if colcount is not None:
is_shared = is_shared and colcount.is_shared()
csr2csc = self._csr2csc
if csr2csc is not None:
is_shared = is_shared and csr2csc.is_shared()
csc2csr = self._csc2csr
if csc2csr is not None:
is_shared = is_shared and csc2csr.is_shared()
return is_shared
SparseStorage.share_memory_ = share_memory_
SparseStorage.is_shared = is_shared
from textwrap import indent
# from textwrap import indent
from typing import Optional, List, Tuple, Union
import torch
import scipy.sparse
from torch_sparse.storage import SparseStorage, get_layout
from torch_sparse.transpose import t
from torch_sparse.narrow import narrow
from torch_sparse.select import select
from torch_sparse.index_select import index_select, index_select_nnz
from torch_sparse.masked_select import masked_select, masked_select_nnz
import torch_sparse.reduce
from torch_sparse.diag import remove_diag, set_diag
from torch_sparse.matmul import matmul
from torch_sparse.add import add, add_, add_nnz, add_nnz_
from torch_sparse.mul import mul, mul_, mul_nnz, mul_nnz_
# from torch_sparse.transpose import t
# from torch_sparse.narrow import narrow
# from torch_sparse.select import select
# from torch_sparse.index_select import index_select, index_select_nnz
# from torch_sparse.masked_select import masked_select, masked_select_nnz
# import torch_sparse.reduce
# from torch_sparse.diag import remove_diag, set_diag
# from torch_sparse.matmul import matmul
# from torch_sparse.add import add, add_, add_nnz, add_nnz_
# from torch_sparse.mul import mul, mul_, mul_nnz, mul_nnz_
from torch_sparse.utils import is_scalar
@torch.jit.script
class SparseTensor(object):
def __init__(self, row=None, rowptr=None, col=None, value=None,
sparse_size=None, is_sorted=False):
storage: SparseStorage
def __init__(self, row: Optional[torch.Tensor] = None,
rowptr: Optional[torch.Tensor] = None,
col: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
sparse_sizes: Optional[List[int]] = None, is_sorted: bool = False):
self.storage = SparseStorage(row=row, rowptr=rowptr, col=col,
value=value, sparse_size=sparse_size,
value=value, sparse_sizes=sparse_sizes,
rowcount=None, colptr=None, colcount=None,
csr2csc=None, csc2csr=None,
is_sorted=is_sorted)
@classmethod
def from_storage(self, storage):
def from_storage(self, storage: SparseStorage):
self = SparseTensor.__new__(SparseTensor)
self.storage = storage
return self
@classmethod
def from_dense(self, mat):
def from_dense(self, mat: torch.Tensor):
if mat.dim() > 2:
index = mat.abs().sum([i for i in range(2, mat.dim())]).nonzero()
else:
index = mat.nonzero()
index = index.t()
row, col = index.t().contiguous()
return SparseTensor(row=row, col=col, value=mat[row, col],
sparse_size=mat.size()[:2], is_sorted=True)
row, col = index[0], index[1]
return SparseTensor(row=row, rowptr=None, col=col, value=mat[row, col],
sparse_sizes=mat.size()[:2], is_sorted=True)
@classmethod
def from_torch_sparse_coo_tensor(self, mat, is_sorted=False):
row, col = mat._indices()
return SparseTensor(row=row, col=col, value=mat._values(),
sparse_size=mat.size()[:2], is_sorted=is_sorted)
def from_torch_sparse_coo_tensor(self, mat: torch.Tensor):
mat = mat.coalesce()
index = mat._indices()
row, col = index[0], index[1]
return SparseTensor(row=row, rowptr=None, col=col, value=mat._values(),
sparse_sizes=mat.size()[:2], is_sorted=True)
@classmethod
def from_scipy(self, mat):
colptr = None
if isinstance(mat, scipy.sparse.csc_matrix):
colptr = torch.from_numpy(mat.indptr).to(torch.long)
mat = mat.tocsr() # Pre-sort.
rowptr = torch.from_numpy(mat.indptr).to(torch.long)
mat = mat.tocoo()
row = torch.from_numpy(mat.row).to(torch.long)
col = torch.from_numpy(mat.col).to(torch.long)
value = torch.from_numpy(mat.data)
sparse_size = mat.shape[:2]
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_size=sparse_size, colptr=colptr,
is_sorted=True)
def eye(self, M: int, N: Optional[int] = None,
options: Optional[torch.Tensor] = None, has_value: bool = True,
fill_cache: bool = False):
return SparseTensor.from_storage(storage)
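For illustration, a SciPy round trip through the classmethod above (a sketch; assumes scipy is installed):

import scipy.sparse
dense = torch.tensor([[1.0, 0.0], [2.0, 3.0]])
mat = SparseTensor.from_scipy(scipy.sparse.csr_matrix(dense.numpy()))
assert mat.nnz() == 3
assert mat.to_dense().tolist() == dense.tolist()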
@classmethod
def eye(self, M, N=None, device=None, dtype=None, has_value=True,
fill_cache=False):
N = M if N is None else N
row = torch.arange(min(M, N), device=device)
rowptr = torch.arange(M + 1, device=device)
if M > N:
rowptr[row.size(0) + 1:] = row.size(0)
if options is not None:
row = torch.arange(min(M, N), device=options.device)
else:
row = torch.arange(min(M, N))
col = row
value = None
rowptr = torch.arange(M + 1, dtype=torch.long, device=row.device)
if M > N:
rowptr[N + 1:] = M
value: Optional[torch.Tensor] = None
if has_value:
value = torch.ones(row.size(0), dtype=dtype, device=device)
if options is not None:
value = torch.ones(row.numel(), dtype=options.dtype,
device=row.device)
else:
value = torch.ones(row.numel(), device=row.device)
rowcount: Optional[torch.Tensor] = None
colptr: Optional[torch.Tensor] = None
colcount: Optional[torch.Tensor] = None
csr2csc: Optional[torch.Tensor] = None
csc2csr: Optional[torch.Tensor] = None
rowcount = colptr = colcount = csr2csc = csc2csr = None
if fill_cache:
rowcount = row.new_ones(M)
rowcount = torch.ones(M, dtype=torch.long, device=row.device)
if M > N:
rowcount[row.size(0):] = 0
colptr = torch.arange(N + 1, device=device)
colcount = col.new_ones(N)
rowcount[N:] = 0
colptr = torch.arange(N + 1, dtype=torch.long, device=row.device)
colcount = torch.ones(N, dtype=torch.long, device=row.device)
if N > M:
colptr[col.size(0) + 1:] = col.size(0)
colcount[col.size(0):] = 0
colptr[M + 1:] = M
colcount[M:] = 0
csr2csc = csc2csr = row
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_size=torch.Size([M, N]),
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True)
return SparseTensor.from_storage(storage)
storage: SparseStorage = SparseStorage(
row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=torch.Size([M, N]), rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc, csc2csr=csc2csr,
is_sorted=True)
def __copy__(self):
self = SparseTensor.__new__(SparseTensor)
self.storage = storage
return self
def copy(self):
return self.from_storage(self.storage)
def clone(self):
return self.from_storage(self.storage.clone())
def __deepcopy__(self, memo):
new_sparse_tensor = self.clone()
memo[id(self)] = new_sparse_tensor
return new_sparse_tensor
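The distinction between the three in brief: copy shares the underlying tensors, clone owns fresh ones, and __deepcopy__ defers to clone. A sketch, assuming the constructor from this diff:

a = SparseTensor(row=torch.tensor([0, 1]), col=torch.tensor([0, 1]), is_sorted=True)
b = a.copy()    # same storage object, same index tensors
c = a.clone()   # freshly cloned tensors
assert b.storage.col().data_ptr() == a.storage.col().data_ptr()
assert c.storage.col().data_ptr() != a.storage.col().data_ptr()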
def type_as(self, tensor: torch.Tensor):
value = self.storage._value
if value is None or tensor.dtype == value.dtype:
return self
return self.from_storage(self.storage.type_as(tensor))
def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
if tensor.device == self.device():
return self
return self.from_storage(self.storage.device_as(tensor, non_blocking))
# Formats #################################################################
def coo(self):
return self.storage.row, self.storage.col, self.storage.value
def coo(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return self.storage.row(), self.storage.col(), self.storage.value()
def csr(self):
return self.storage.rowptr, self.storage.col, self.storage.value
def csr(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return self.storage.rowptr(), self.storage.col(), self.storage.value()
def csc(self):
perm = self.storage.csr2csc # Compute `csr2csc` first.
return (self.storage.colptr, self.storage.row[perm],
self.storage.value[perm] if self.has_value() else None)
def csc(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
perm = self.storage.csr2csc()
value = self.storage.value()
if value is not None:
value = value[perm]
return self.storage.colptr(), self.storage.row()[perm], value
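The three accessors for a 2 x 2 anti-diagonal pattern, as a sketch (csr and csc trigger the cached rowptr/csr2csc conversions, which need the compiled extension):

adj = SparseTensor(row=torch.tensor([0, 1]), col=torch.tensor([1, 0]), is_sorted=True)
row, col, value = adj.coo()      # [0, 1], [1, 0], None
rowptr, col, value = adj.csr()   # rowptr == [0, 1, 2]
colptr, row, value = adj.csc()   # colptr == [0, 1, 2], row permuted to [1, 0]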
# Storage inheritance #####################################################
def has_value(self):
def has_value(self) -> bool:
return self.storage.has_value()
def set_value_(self, value, layout=None, dtype=None):
self.storage.set_value_(value, layout, dtype)
def set_value_(self, value: Optional[torch.Tensor],
layout: Optional[str] = None):
self.storage.set_value_(value, layout)
return self
def set_value(self, value, layout=None, dtype=None):
return self.from_storage(self.storage.set_value(value, layout, dtype))
def set_value(self, value: Optional[torch.Tensor],
layout: Optional[str] = None):
return self.from_storage(self.storage.set_value(value, layout))
def sparse_sizes(self) -> List[int]:
return self.storage.sparse_sizes()
def sparse_size(self, dim=None):
sparse_size = self.storage.sparse_size
return sparse_size if dim is None else sparse_size[dim]
def sparse_size(self, dim: int) -> int:
return self.storage.sparse_sizes()[dim]
def sparse_resize(self, *sizes):
return self.from_storage(self.storage.sparse_resize(*sizes))
def sparse_resize(self, sparse_sizes: List[int]):
return self.from_storage(self.storage.sparse_resize(sparse_sizes))
def is_coalesced(self):
def is_coalesced(self) -> bool:
return self.storage.is_coalesced()
def coalesce(self, reduce='add'):
def coalesce(self, reduce: str = "add"):
return self.from_storage(self.storage.coalesce(reduce))
def cached_keys(self):
return self.storage.cached_keys()
def fill_cache_(self, *args):
self.storage.fill_cache_(*args)
def fill_cache_(self):
self.storage.fill_cache_()
return self
def clear_cache_(self, *args):
self.storage.clear_cache_(*args)
def clear_cache_(self):
self.storage.clear_cache_()
return self
# Utility functions #######################################################
def dim(self):
return len(self.size())
def size(self, dim=None):
size = self.sparse_size()
size += self.storage.value.size()[1:] if self.has_value() else ()
return size if dim is None else size[dim]
@property
def shape(self):
return self.size()
def nnz(self):
return self.storage.col.numel()
def fill_value_(self, fill_value: float,
options: Optional[torch.Tensor] = None):
if options is not None:
value = torch.full((self.nnz(), ), fill_value, dtype=options.dtype,
device=self.device())
else:
value = torch.full((self.nnz(), ), fill_value,
device=self.device())
return self.set_value_(value, layout='coo')
def fill_value(self, fill_value: float,
options: Optional[torch.Tensor] = None):
if options is not None:
value = torch.full((self.nnz(), ), fill_value, dtype=options.dtype,
device=self.device())
else:
value = torch.full((self.nnz(), ), fill_value,
device=self.device())
return self.set_value(value, layout='coo')
def sizes(self) -> List[int]:
sizes = self.sparse_sizes()
value = self.storage.value()
if value is not None:
sizes += value.size()[1:]
return sizes
def size(self, dim: int) -> int:
return self.sizes()[dim]
def dim(self) -> int:
return len(self.sizes())
def nnz(self) -> int:
return self.storage.col().numel()
def numel(self) -> int:
value = self.storage.value()
if value is not None:
return value.numel()
else:
return self.nnz()
def density(self):
def density(self) -> float:
return self.nnz() / (self.sparse_size(0) * self.sparse_size(1))
def sparsity(self):
def sparsity(self) -> float:
return 1 - self.density()
def avg_row_length(self):
def avg_row_length(self) -> float:
return self.nnz() / self.sparse_size(0)
def avg_col_length(self):
def avg_col_length(self) -> float:
return self.nnz() / self.sparse_size(1)
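A quick sanity check of these statistics for a 2 x 2 pattern with three non-zeros (a sketch using the constructor from this diff):

adj = SparseTensor(row=torch.tensor([0, 1, 1]), col=torch.tensor([0, 0, 1]),
                   sparse_sizes=[2, 2], is_sorted=True)
assert adj.nnz() == 3
assert abs(adj.density() - 0.75) < 1e-9
assert adj.avg_row_length() == 1.5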
def numel(self):
return self.value.numel() if self.has_value() else self.nnz()
def is_quadratic(self):
def is_quadratic(self) -> bool:
return self.sparse_size(0) == self.sparse_size(1)
def is_symmetric(self):
if not self.is_quadratic:
def is_symmetric(self) -> bool:
if not self.is_quadratic():
return False
rowptr, col, value1 = self.csr()
@@ -207,296 +252,353 @@ class SparseTensor(object):
if (rowptr != colptr).any() or (col != row).any():
return False
if not self.has_value():
if value1 is None or value2 is None:
return True
return (value1 == value2).all().item()
else:
return bool((value1 == value2).all())
def detach_(self):
self.storage.apply_(lambda x: x.detach_())
value = self.storage.value()
if value is not None:
value.detach_()
return self
def detach(self):
return self.from_storage(self.storage.apply(lambda x: x.detach()))
@property
def requires_grad(self):
return self.storage.value.requires_grad if self.has_value() else False
value = self.storage.value()
if value is not None:
value = value.detach()
return self.set_value(value, layout='coo')
def requires_grad(self) -> bool:
value = self.storage.value()
if value is not None:
return value.requires_grad
else:
return False
def requires_grad_(self, requires_grad=True, dtype=None):
def requires_grad_(self, requires_grad: bool = True,
options: Optional[torch.Tensor] = None):
if requires_grad and not self.has_value():
self.storage.set_value_(1, dtype=dtype)
if self.has_value():
self.storage.value.requires_grad_(requires_grad)
self.fill_value_(1., options=options)
value = self.storage.value()
if value is not None:
value.requires_grad_(requires_grad)
return self
def pin_memory(self):
return self.from_storage(self.storage.apply(lambda x: x.pin_memory()))
return self.from_storage(self.storage.pin_memory())
def is_pinned(self):
return all(self.storage.map(lambda x: x.is_pinned()))
def is_pinned(self) -> bool:
return self.storage.is_pinned()
def share_memory_(self):
self.storage.apply_(lambda x: x.share_memory_())
return self
def is_shared(self):
return all(self.storage.map(lambda x: x.is_shared()))
def options(self) -> torch.Tensor:
value = self.storage.value()
if value is not None:
return value
else:
return torch.tensor(0., device=self.storage.col().device)
@property
def device(self):
return self.storage.col.device
return self.storage.col().device
def cpu(self):
return self.from_storage(self.storage.apply(lambda x: x.cpu()))
def cuda(self, device=None, non_blocking=False, **kwargs):
storage = self.storage.apply(
lambda x: x.cuda(device, non_blocking, **kwargs))
return self.from_storage(storage)
@property
def is_cuda(self):
return self.storage.col.is_cuda
@property
def dtype(self):
return self.storage.value.dtype if self.has_value() else None
def is_floating_point(self):
value = self.storage.value
return self.has_value() and torch.is_floating_point(value)
def type(self, dtype=None, non_blocking=False, **kwargs):
if dtype is None:
return self.dtype
if dtype == self.dtype:
return self
return self.device_as(torch.tensor(0.), non_blocking=False)
storage = self.storage.apply_value(
lambda x: x.type(dtype, non_blocking, **kwargs))
return self.from_storage(storage)
def to(self, *args, **kwargs):
args = list(args)
non_blocking = getattr(kwargs, 'non_blocking', False)
storage = None
if 'device' in kwargs:
device = kwargs['device']
del kwargs['device']
storage = self.storage.apply(
lambda x: x.to(device, non_blocking=non_blocking))
def cuda(self, options: Optional[torch.Tensor] = None, non_blocking: bool = False):
if options is not None:
return self.device_as(options, non_blocking)
else:
for arg in args[:]:
if isinstance(arg, str) or isinstance(arg, torch.device):
storage = self.storage.apply(
lambda x: x.to(arg, non_blocking=non_blocking))
args.remove(arg)
options = torch.tensor(0.).cuda()
return self.device_as(options, non_blocking)
storage = self.storage if storage is None else storage
def is_cuda(self) -> bool:
return self.storage.col().is_cuda
if len(args) > 0 or len(kwargs) > 0:
storage = storage.apply_value(lambda x: x.type(*args, **kwargs))
def dtype(self):
return self.options().dtype
if storage == self.storage: # Nothing has been changed...
return self
else:
return self.from_storage(storage)
def is_floating_point(self) -> bool:
return torch.is_floating_point(self.options())
def bfloat16(self):
return self.type(torch.bfloat16)
return self.type_as(torch.tensor(0, dtype=torch.bfloat16))
def bool(self):
return self.type(torch.bool)
return self.type_as(torch.tensor(0, dtype=torch.bool))
def byte(self):
return self.type(torch.byte)
return self.type_as(torch.tensor(0, dtype=torch.uint8))
def char(self):
return self.type(torch.char)
return self.type_as(torch.tensor(0, dtype=torch.int8))
def half(self):
return self.type(torch.half)
return self.type_as(torch.tensor(0, dtype=torch.half))
def float(self):
return self.type(torch.float)
return self.type_as(torch.tensor(0, dtype=torch.float))
def double(self):
return self.type(torch.double)
return self.type_as(torch.tensor(0, dtype=torch.double))
def short(self):
return self.type(torch.short)
return self.type_as(torch.tensor(0, dtype=torch.short))
def int(self):
return self.type(torch.int)
return self.type_as(torch.tensor(0, dtype=torch.int))
def long(self):
return self.type(torch.long)
return self.type_as(torch.tensor(0, dtype=torch.long))
# Conversions #############################################################
def to_dense(self, dtype=None):
dtype = dtype or self.dtype
def to_dense(self, options: Optional[torch.Tensor] = None):
row, col, value = self.coo()
mat = torch.zeros(self.size(), dtype=dtype, device=self.device)
mat[row, col] = value if self.has_value() else 1
if options is not None:
mat = torch.zeros(self.sizes(), dtype=options.dtype,
device=self.device())
else:
mat = torch.zeros(self.sizes(), device=self.device())
if value is not None:
mat[row, col] = value
else:
mat[row, col] = torch.ones(self.nnz(), dtype=mat.dtype,
device=mat.device)
return mat
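A sketch of the dense conversion: with value None, the non-zero entries materialize as ones.

adj = SparseTensor(row=torch.tensor([0, 1]), col=torch.tensor([1, 0]), is_sorted=True)
assert adj.to_dense().tolist() == [[0.0, 1.0], [1.0, 0.0]]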
def to_torch_sparse_coo_tensor(self, dtype=None, requires_grad=False):
def to_torch_sparse_coo_tensor(self, options: Optional[torch.Tensor] = None):
row, col, value = self.coo()
index = torch.stack([row, col], dim=0)
if value is None:
value = torch.ones(self.nnz(), dtype=dtype, device=self.device)
return torch.sparse_coo_tensor(index, value, self.size(),
device=self.device,
requires_grad=requires_grad)
def to_scipy(self, layout=None, dtype=None):
assert self.dim() == 2
layout = get_layout(layout)
if not self.has_value():
ones = torch.ones(self.nnz(), dtype=dtype).numpy()
if layout == 'coo':
row, col, value = self.coo()
row = row.detach().cpu().numpy()
col = col.detach().cpu().numpy()
value = value.detach().cpu().numpy() if self.has_value() else ones
return scipy.sparse.coo_matrix((value, (row, col)), self.size())
elif layout == 'csr':
rowptr, col, value = self.csr()
rowptr = rowptr.detach().cpu().numpy()
col = col.detach().cpu().numpy()
value = value.detach().cpu().numpy() if self.has_value() else ones
return scipy.sparse.csr_matrix((value, col, rowptr), self.size())
elif layout == 'csc':
colptr, row, value = self.csc()
colptr = colptr.detach().cpu().numpy()
row = row.detach().cpu().numpy()
value = value.detach().cpu().numpy() if self.has_value() else ones
return scipy.sparse.csc_matrix((value, row, colptr), self.size())
# Standard Operators ######################################################
def __getitem__(self, index):
index = list(index) if isinstance(index, tuple) else [index]
# More than one `Ellipsis` is not allowed...
if len([i for i in index if not torch.is_tensor(i) and i == ...]) > 1:
raise SyntaxError
dim = 0
out = self
while len(index) > 0:
item = index.pop(0)
if isinstance(item, int):
out = out.select(dim, item)
dim += 1
elif isinstance(item, slice):
if item.step is not None:
raise ValueError('Step parameter not yet supported.')
start = 0 if item.start is None else item.start
start = self.size(dim) + start if start < 0 else start
stop = self.size(dim) if item.stop is None else item.stop
stop = self.size(dim) + stop if stop < 0 else stop
out = out.narrow(dim, start, max(stop - start, 0))
dim += 1
elif torch.is_tensor(item):
if item.dtype == torch.bool:
out = out.masked_select(dim, item)
dim += 1
elif item.dtype == torch.long:
out = out.index_select(dim, item)
dim += 1
elif item == Ellipsis:
if self.dim() - len(index) < dim:
raise SyntaxError
dim = self.dim() - len(index)
if options is not None:
value = torch.ones(self.nnz(), dtype=options.dtype,
device=self.device())
else:
raise SyntaxError
value = torch.ones(self.nnz(), device=self.device())
return out
return torch.sparse_coo_tensor(index, value, self.sizes())
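And the corresponding conversion to a native COO tensor, as a sketch (assumes options defaults to None in the signature above):

adj = SparseTensor(row=torch.tensor([0, 1]), col=torch.tensor([1, 0]), is_sorted=True)
coo = adj.to_torch_sparse_coo_tensor()   # implicit all-ones values
assert coo.is_sparse and coo._nnz() == 2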
def __add__(self, other):
return self.add(other)
def __radd__(self, other):
return self.add(other)
# # Standard Operators ######################################################
def __iadd__(self, other):
return self.add_(other)
# def __getitem__(self, index):
# index = list(index) if isinstance(index, tuple) else [index]
# # More than one `Ellipsis` is not allowed...
# if len([i for i in index if not torch.is_tensor(i) and i == ...]) > 1:
# raise SyntaxError
def __mul__(self, other):
return self.mul(other)
# dim = 0
# out = self
# while len(index) > 0:
# item = index.pop(0)
# if isinstance(item, int):
# out = out.select(dim, item)
# dim += 1
# elif isinstance(item, slice):
# if item.step is not None:
# raise ValueError('Step parameter not yet supported.')
def __rmul__(self, other):
return self.mul(other)
# start = 0 if item.start is None else item.start
# start = self.size(dim) + start if start < 0 else start
def __imul__(self, other):
return self.mul_(other)
# stop = self.size(dim) if item.stop is None else item.stop
# stop = self.size(dim) + stop if stop < 0 else stop
def __matmul__(self, other):
return matmul(self, other, reduce='sum')
# out = out.narrow(dim, start, max(stop - start, 0))
# dim += 1
# elif torch.is_tensor(item):
# if item.dtype == torch.bool:
# out = out.masked_select(dim, item)
# dim += 1
# elif item.dtype == torch.long:
# out = out.index_select(dim, item)
# dim += 1
# elif item == Ellipsis:
# if self.dim() - len(index) < dim:
# raise SyntaxError
# dim = self.dim() - len(index)
# else:
# raise SyntaxError
# String Representation ###################################################
# return out
def __repr__(self):
i = ' ' * 6
row, col, value = self.coo()
infos = []
infos += [f'row={indent(row.__repr__(), i)[len(i):]}']
infos += [f'col={indent(col.__repr__(), i)[len(i):]}']
# def __add__(self, other):
# return self.add(other)
# def __radd__(self, other):
# return self.add(other)
# def __iadd__(self, other):
# return self.add_(other)
# def __mul__(self, other):
# return self.mul(other)
# def __rmul__(self, other):
# return self.mul(other)
if self.has_value():
infos += [f'val={indent(value.__repr__(), i)[len(i):]}']
# def __imul__(self, other):
# return self.mul_(other)
infos += [
f'size={tuple(self.size())}, '
f'nnz={self.nnz()}, '
f'density={100 * self.density():.02f}%'
]
infos = ',\n'.join(infos)
# def __matmul__(self, other):
# return matmul(self, other, reduce='sum')
i = ' ' * (len(self.__class__.__name__) + 1)
return f'{self.__class__.__name__}({indent(infos, i)[len(i):]})'
# # String Representation #################################################
# def __repr__(self):
# i = ' ' * 6
# row, col, value = self.coo()
# infos = []
# infos += [f'row={indent(row.__repr__(), i)[len(i):]}']
# infos += [f'col={indent(col.__repr__(), i)[len(i):]}']
# if self.has_value():
# infos += [f'val={indent(value.__repr__(), i)[len(i):]}']
# infos += [
# f'size={tuple(self.size())}, '
# f'nnz={self.nnz()}, '
# f'density={100 * self.density():.02f}%'
# ]
# infos = ',\n'.join(infos)
# i = ' ' * (len(self.__class__.__name__) + 1)
# return f'{self.__class__.__name__}({indent(infos, i)[len(i):]})'
# Bindings ####################################################################
SparseTensor.t = t
SparseTensor.narrow = narrow
SparseTensor.select = select
SparseTensor.index_select = index_select
SparseTensor.index_select_nnz = index_select_nnz
SparseTensor.masked_select = masked_select
SparseTensor.masked_select_nnz = masked_select_nnz
SparseTensor.reduction = torch_sparse.reduce.reduction
SparseTensor.sum = torch_sparse.reduce.sum
SparseTensor.mean = torch_sparse.reduce.mean
SparseTensor.min = torch_sparse.reduce.min
SparseTensor.max = torch_sparse.reduce.max
SparseTensor.remove_diag = remove_diag
SparseTensor.set_diag = set_diag
SparseTensor.matmul = matmul
SparseTensor.add = add
SparseTensor.add_ = add_
SparseTensor.add_nnz = add_nnz
SparseTensor.add_nnz_ = add_nnz_
SparseTensor.mul = mul
SparseTensor.mul_ = mul_
SparseTensor.mul_nnz = mul_nnz
SparseTensor.mul_nnz_ = mul_nnz_
# Fix for PyTorch<=1.3 (https://github.com/pytorch/pytorch/pull/31769):
# SparseTensor.t = t
# SparseTensor.narrow = narrow
# SparseTensor.select = select
# SparseTensor.index_select = index_select
# SparseTensor.index_select_nnz = index_select_nnz
# SparseTensor.masked_select = masked_select
# SparseTensor.masked_select_nnz = masked_select_nnz
# SparseTensor.reduction = torch_sparse.reduce.reduction
# SparseTensor.sum = torch_sparse.reduce.sum
# SparseTensor.mean = torch_sparse.reduce.mean
# SparseTensor.min = torch_sparse.reduce.min
# SparseTensor.max = torch_sparse.reduce.max
# SparseTensor.remove_diag = remove_diag
# SparseTensor.set_diag = set_diag
# SparseTensor.matmul = matmul
# SparseTensor.add = add
# SparseTensor.add_ = add_
# SparseTensor.add_nnz = add_nnz
# SparseTensor.add_nnz_ = add_nnz_
# SparseTensor.mul = mul
# SparseTensor.mul_ = mul_
# SparseTensor.mul_nnz = mul_nnz
# SparseTensor.mul_nnz_ = mul_nnz_
# Python Bindings #############################################################
Dtype = Optional[torch.dtype]
Device = Optional[Union[torch.device, str]]
@torch.jit.ignore
def share_memory_(self: SparseTensor) -> SparseTensor:
self.storage.share_memory_()
@torch.jit.ignore
def is_shared(self: SparseTensor) -> bool:
return self.storage.is_shared()
@torch.jit.ignore
def to(self, *args, **kwargs):
dtype: Dtype = kwargs.get('dtype', None)
device: Device = kwargs.get('device', None)
non_blocking: bool = kwargs.get('non_blocking', False)
for arg in args:
if isinstance(arg, str) or isinstance(arg, torch.device):
device = arg
if isinstance(arg, torch.dtype):
dtype = arg
if dtype is not None:
self = self.type_as(torch.tensor(0., dtype=dtype))
if device is not None:
self = self.device_as(torch.tensor(0., device=device), non_blocking)
return self
SparseTensor.share_memory_ = share_memory_
SparseTensor.is_shared = is_shared
SparseTensor.to = to
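Usage mirrors torch.Tensor.to, with dtype routed through type_as and device through device_as (a sketch):

adj = SparseTensor(row=torch.tensor([0, 1]), col=torch.tensor([1, 0]), is_sorted=True)
adj = adj.to(torch.double)   # no value tensor, so the dtype conversion is a no-op
assert adj.to('cpu') is adj  # already on CPU: returns self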
# Scipy Conversions ###########################################################
ScipySparseMatrix = Union[scipy.sparse.coo_matrix, scipy.sparse.csr_matrix,
scipy.sparse.csc_matrix]
@torch.jit.ignore
def from_scipy(mat: ScipySparseMatrix) -> SparseTensor:
colptr = None
if isinstance(mat, scipy.sparse.csc_matrix):
colptr = torch.from_numpy(mat.indptr).to(torch.long)
mat = mat.tocsr()
rowptr = torch.from_numpy(mat.indptr).to(torch.long)
mat = mat.tocoo()
row = torch.from_numpy(mat.row).to(torch.long)
col = torch.from_numpy(mat.col).to(torch.long)
value = torch.from_numpy(mat.data)
sparse_sizes = mat.shape[:2]
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=None,
colptr=colptr, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True)
return SparseTensor.from_storage(storage)
@torch.jit.ignore
def to_scipy(self: SparseTensor, layout: Optional[str] = None,
dtype: Optional[torch.dtype] = None) -> ScipySparseMatrix:
assert self.dim() == 2
layout = get_layout(layout)
if not self.has_value():
ones = torch.ones(self.nnz(), dtype=dtype).numpy()
if layout == 'coo':
row, col, value = self.coo()
row = row.detach().cpu().numpy()
col = col.detach().cpu().numpy()
value = value.detach().cpu().numpy() if self.has_value() else ones
return scipy.sparse.coo_matrix((value, (row, col)), self.sizes())
elif layout == 'csr':
rowptr, col, value = self.csr()
rowptr = rowptr.detach().cpu().numpy()
col = col.detach().cpu().numpy()
value = value.detach().cpu().numpy() if self.has_value() else ones
return scipy.sparse.csr_matrix((value, col, rowptr), self.sizes())
elif layout == 'csc':
colptr, row, value = self.csc()
colptr = colptr.detach().cpu().numpy()
row = row.detach().cpu().numpy()
value = value.detach().cpu().numpy() if self.has_value() else ones
return scipy.sparse.csc_matrix((value, row, colptr), self.sizes())
SparseTensor.from_scipy = from_scipy
SparseTensor.to_scipy = to_scipy
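Tying the two together, a round trip through SciPy's COO format (a sketch; assumes scipy is installed):

mat = SparseTensor(row=torch.tensor([0, 1]), col=torch.tensor([1, 0]), is_sorted=True)
sp = mat.to_scipy(layout='coo')   # implicit all-ones values
assert sp.shape == (2, 2) and sp.nnz == 2
assert SparseTensor.from_scipy(sp).nnz() == 2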
# Hacky fixes #################################################################
# Fix standard operators of `torch.Tensor` for PyTorch<=1.3.
# https://github.com/pytorch/pytorch/pull/31769
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if (TORCH_MAJOR < 1) or (TORCH_MAJOR == 1 and TORCH_MINOR < 4):
......