Commit f87afd09 authored by rusty1s's avatar rusty1s
Browse files

tensor and storage mostly jittable

parent 631eee37
...@@ -30,9 +30,10 @@ class MyCell(torch.nn.Module): ...@@ -30,9 +30,10 @@ class MyCell(torch.nn.Module):
self.linear = torch.nn.Linear(2, 4) self.linear = torch.nn.Linear(2, 4)
# def forward(self, x: torch.Tensor, ptr: torch.Tensor) -> torch.Tensor: # def forward(self, x: torch.Tensor, ptr: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, adj: SparseStorage) -> torch.Tensor: def forward(self, x: torch.Tensor, adj: SparseTensor) -> torch.Tensor:
out, _ = torch.ops.torch_sparse_cpu.spmm(adj.rowptr(), adj.col(), None, out, _ = torch.ops.torch_sparse_cpu.spmm(adj.storage.rowptr(),
x, 'sum') adj.storage.col(), None, x,
'sum')
return out return out
...@@ -67,7 +68,10 @@ def test_jit(): ...@@ -67,7 +68,10 @@ def test_jit():
rowptr = torch.tensor([0, 3, 6, 9]) rowptr = torch.tensor([0, 3, 6, 9])
col = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2]) col = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2])
adj = SparseStorage(rowptr=rowptr, col=col) adj = SparseTensor(rowptr=rowptr, col=col)
scipy = adj.to_scipy(layout='csr')
mat = SparseTensor.from_scipy(scipy)
mat.fill_value_(2.3)
# adj = {'rowptr': mat.storage.rowptr, 'col': mat.storage.col} # adj = {'rowptr': mat.storage.rowptr, 'col': mat.storage.col}
# foo = Foo(mat.storage.rowptr, mat.storage.col) # foo = Foo(mat.storage.rowptr, mat.storage.col)
......
import warnings import warnings
from typing import Optional, List, Dict, Union, Any from typing import Optional, List
import torch import torch
from torch_scatter import segment_csr, scatter_add from torch_scatter import segment_csr, scatter_add
from torch_sparse.utils import Final, is_scalar from torch_sparse.utils import Final
# __cache__ = {'enabled': True}
# def is_cache_enabled():
# return __cache__['enabled']
# def set_cache_enabled(mode):
# __cache__['enabled'] = mode
# class no_cache(object):
# def __enter__(self):
# self.prev = is_cache_enabled()
# set_cache_enabled(False)
# def __exit__(self, *args):
# set_cache_enabled(self.prev)
# return False
# def __call__(self, func):
# def decorate_no_cache(*args, **kwargs):
# with self:
# return func(*args, **kwargs)
# return decorate_no_cache
def optional(func, src):
return func(src) if src is not None else src
layouts: Final[List[str]] = ['coo', 'csr', 'csc'] layouts: Final[List[str]] = ['coo', 'csr', 'csc']
...@@ -52,7 +23,7 @@ class SparseStorage(object): ...@@ -52,7 +23,7 @@ class SparseStorage(object):
_rowptr: Optional[torch.Tensor] _rowptr: Optional[torch.Tensor]
_col: torch.Tensor _col: torch.Tensor
_value: Optional[torch.Tensor] _value: Optional[torch.Tensor]
_sparse_size: List[int] _sparse_sizes: List[int]
_rowcount: Optional[torch.Tensor] _rowcount: Optional[torch.Tensor]
_colptr: Optional[torch.Tensor] _colptr: Optional[torch.Tensor]
_colcount: Optional[torch.Tensor] _colcount: Optional[torch.Tensor]
...@@ -63,7 +34,7 @@ class SparseStorage(object): ...@@ -63,7 +34,7 @@ class SparseStorage(object):
rowptr: Optional[torch.Tensor] = None, rowptr: Optional[torch.Tensor] = None,
col: Optional[torch.Tensor] = None, col: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None,
sparse_size: Optional[List[int]] = None, sparse_sizes: Optional[List[int]] = None,
rowcount: Optional[torch.Tensor] = None, rowcount: Optional[torch.Tensor] = None,
colptr: Optional[torch.Tensor] = None, colptr: Optional[torch.Tensor] = None,
colcount: Optional[torch.Tensor] = None, colcount: Optional[torch.Tensor] = None,
...@@ -77,7 +48,7 @@ class SparseStorage(object): ...@@ -77,7 +48,7 @@ class SparseStorage(object):
assert col.dim() == 1 assert col.dim() == 1
col = col.contiguous() col = col.contiguous()
if sparse_size is None: if sparse_sizes is None:
if rowptr is not None: if rowptr is not None:
M = rowptr.numel() - 1 M = rowptr.numel() - 1
elif row is not None: elif row is not None:
...@@ -85,9 +56,9 @@ class SparseStorage(object): ...@@ -85,9 +56,9 @@ class SparseStorage(object):
else: else:
raise ValueError raise ValueError
N = col.max().item() + 1 N = col.max().item() + 1
sparse_size = torch.Size([int(M), int(N)]) sparse_sizes = torch.Size([int(M), int(N)])
else: else:
assert len(sparse_size) == 2 assert len(sparse_sizes) == 2
if row is not None: if row is not None:
assert row.dtype == torch.long assert row.dtype == torch.long
...@@ -100,7 +71,7 @@ class SparseStorage(object): ...@@ -100,7 +71,7 @@ class SparseStorage(object):
assert rowptr.dtype == torch.long assert rowptr.dtype == torch.long
assert rowptr.device == col.device assert rowptr.device == col.device
assert rowptr.dim() == 1 assert rowptr.dim() == 1
assert rowptr.numel() - 1 == sparse_size[0] assert rowptr.numel() - 1 == sparse_sizes[0]
rowptr = rowptr.contiguous() rowptr = rowptr.contiguous()
if value is not None: if value is not None:
...@@ -112,21 +83,21 @@ class SparseStorage(object): ...@@ -112,21 +83,21 @@ class SparseStorage(object):
assert rowcount.dtype == torch.long assert rowcount.dtype == torch.long
assert rowcount.device == col.device assert rowcount.device == col.device
assert rowcount.dim() == 1 assert rowcount.dim() == 1
assert rowcount.numel() == sparse_size[0] assert rowcount.numel() == sparse_sizes[0]
rowcount = rowcount.contiguous() rowcount = rowcount.contiguous()
if colptr is not None: if colptr is not None:
assert colptr.dtype == torch.long assert colptr.dtype == torch.long
assert colptr.device == col.device assert colptr.device == col.device
assert colptr.dim() == 1 assert colptr.dim() == 1
assert colptr.numel() - 1 == sparse_size[1] assert colptr.numel() - 1 == sparse_sizes[1]
colptr = colptr.contiguous() colptr = colptr.contiguous()
if colcount is not None: if colcount is not None:
assert colcount.dtype == torch.long assert colcount.dtype == torch.long
assert colcount.device == col.device assert colcount.device == col.device
assert colcount.dim() == 1 assert colcount.dim() == 1
assert colcount.numel() == sparse_size[1] assert colcount.numel() == sparse_sizes[1]
colcount = colcount.contiguous() colcount = colcount.contiguous()
if csr2csc is not None: if csr2csc is not None:
...@@ -147,7 +118,7 @@ class SparseStorage(object): ...@@ -147,7 +118,7 @@ class SparseStorage(object):
self._rowptr = rowptr self._rowptr = rowptr
self._col = col self._col = col
self._value = value self._value = value
self._sparse_size = sparse_size self._sparse_sizes = sparse_sizes
self._rowcount = rowcount self._rowcount = rowcount
self._colptr = colptr self._colptr = colptr
self._colcount = colcount self._colcount = colcount
...@@ -156,7 +127,7 @@ class SparseStorage(object): ...@@ -156,7 +127,7 @@ class SparseStorage(object):
if not is_sorted: if not is_sorted:
idx = col.new_zeros(col.numel() + 1) idx = col.new_zeros(col.numel() + 1)
idx[1:] = sparse_size[1] * self.row() + col idx[1:] = sparse_sizes[1] * self.row() + col
if (idx[1:] < idx[:-1]).any(): if (idx[1:] < idx[:-1]).any():
perm = idx[1:].argsort() perm = idx[1:].argsort()
self._row = self.row()[perm] self._row = self.row()[perm]
...@@ -203,10 +174,10 @@ class SparseStorage(object): ...@@ -203,10 +174,10 @@ class SparseStorage(object):
if row is not None: if row is not None:
if row.is_cuda: if row.is_cuda:
rowptr = torch.ops.torch_sparse_cuda.ind2ptr( rowptr = torch.ops.torch_sparse_cuda.ind2ptr(
row, self._sparse_size[0]) row, self._sparse_sizes[0])
else: else:
rowptr = torch.ops.torch_sparse_cpu.ind2ptr( rowptr = torch.ops.torch_sparse_cpu.ind2ptr(
row, self._sparse_size[0]) row, self._sparse_sizes[0])
self._rowptr = rowptr self._rowptr = rowptr
return rowptr return rowptr
...@@ -243,27 +214,22 @@ class SparseStorage(object): ...@@ -243,27 +214,22 @@ class SparseStorage(object):
assert value.size(0) == self._col.numel() assert value.size(0) == self._col.numel()
return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col, return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
value=value, sparse_size=self._sparse_size, value=value, sparse_sizes=self._sparse_sizes,
rowcount=self._rowcount, colptr=self._colptr, rowcount=self._rowcount, colptr=self._colptr,
colcount=self._colcount, csr2csc=self._csr2csc, colcount=self._colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True) csc2csr=self._csc2csr, is_sorted=True)
def fill_value_(self, fill_value: float, dtype=Optional[torch.dtype]): def sparse_sizes(self) -> List[int]:
value = torch.empty(self._col.numel(), dtype, device=self._col.device) return self._sparse_sizes
return self.set_value_(value.fill_(fill_value), layout='csr')
def fill_value(self, fill_value: float, dtype=Optional[torch.dtype]): def sparse_size(self, dim: int) -> int:
value = torch.empty(self._col.numel(), dtype, device=self._col.device) return self._sparse_sizes[dim]
return self.set_value(value.fill_(fill_value), layout='csr')
def sparse_size(self) -> List[int]: def sparse_resize(self, sparse_sizes: List[int]):
return self._sparse_size assert len(sparse_sizes) == 2
old_sparse_sizes, nnz = self._sparse_sizes, self._col.numel()
def sparse_resize(self, sparse_size: List[int]): diff_0 = sparse_sizes[0] - old_sparse_sizes[0]
assert len(sparse_size) == 2
old_sparse_size, nnz = self._sparse_size, self._col.numel()
diff_0 = sparse_size[0] - old_sparse_size[0]
rowcount, rowptr = self._rowcount, self._rowptr rowcount, rowptr = self._rowcount, self._rowptr
if diff_0 > 0: if diff_0 > 0:
if rowptr is not None: if rowptr is not None:
...@@ -276,7 +242,7 @@ class SparseStorage(object): ...@@ -276,7 +242,7 @@ class SparseStorage(object):
if rowcount is not None: if rowcount is not None:
rowcount = rowcount[:-diff_0] rowcount = rowcount[:-diff_0]
diff_1 = sparse_size[1] - old_sparse_size[1] diff_1 = sparse_sizes[1] - old_sparse_sizes[1]
colcount, colptr = self._colcount, self._colptr colcount, colptr = self._colcount, self._colptr
if diff_1 > 0: if diff_1 > 0:
if colptr is not None: if colptr is not None:
...@@ -290,7 +256,7 @@ class SparseStorage(object): ...@@ -290,7 +256,7 @@ class SparseStorage(object):
colcount = colcount[:-diff_1] colcount = colcount[:-diff_1]
return SparseStorage(row=self._row, rowptr=rowptr, col=self._col, return SparseStorage(row=self._row, rowptr=rowptr, col=self._col,
value=self._value, sparse_size=sparse_size, value=self._value, sparse_sizes=sparse_sizes,
rowcount=rowcount, colptr=colptr, rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=self._csr2csc, colcount=colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True) csc2csr=self._csc2csr, is_sorted=True)
...@@ -319,9 +285,9 @@ class SparseStorage(object): ...@@ -319,9 +285,9 @@ class SparseStorage(object):
csr2csc = self._csr2csc csr2csc = self._csr2csc
if csr2csc is not None: if csr2csc is not None:
colptr = torch.ops.torch_sparse_cpu.ind2ptr( colptr = torch.ops.torch_sparse_cpu.ind2ptr(
self._col[csr2csc], self._sparse_size[1]) self._col[csr2csc], self._sparse_sizes[1])
else: else:
colptr = self._col.new_zeros(self._sparse_size[1] + 1) colptr = self._col.new_zeros(self._sparse_sizes[1] + 1)
torch.cumsum(self.colcount(), dim=0, out=colptr[1:]) torch.cumsum(self.colcount(), dim=0, out=colptr[1:])
self._colptr = colptr self._colptr = colptr
return colptr return colptr
...@@ -340,7 +306,7 @@ class SparseStorage(object): ...@@ -340,7 +306,7 @@ class SparseStorage(object):
else: else:
raise NotImplementedError raise NotImplementedError
# colcount = scatter_add(torch.ones_like(self._col), self._col, # colcount = scatter_add(torch.ones_like(self._col), self._col,
# dim_size=self._sparse_size[1]) # dim_size=self._sparse_sizes[1])
self._colcount = colcount self._colcount = colcount
return colcount return colcount
...@@ -352,7 +318,7 @@ class SparseStorage(object): ...@@ -352,7 +318,7 @@ class SparseStorage(object):
if csr2csc is not None: if csr2csc is not None:
return csr2csc return csr2csc
idx = self._sparse_size[0] * self._col + self.row() idx = self._sparse_sizes[0] * self._col + self.row()
csr2csc = idx.argsort() csr2csc = idx.argsort()
self._csr2csc = csr2csc self._csr2csc = csr2csc
return csr2csc return csr2csc
...@@ -371,12 +337,12 @@ class SparseStorage(object): ...@@ -371,12 +337,12 @@ class SparseStorage(object):
def is_coalesced(self) -> bool: def is_coalesced(self) -> bool:
idx = self._col.new_full((self._col.numel() + 1, ), -1) idx = self._col.new_full((self._col.numel() + 1, ), -1)
idx[1:] = self._sparse_size[1] * self.row() + self._col idx[1:] = self._sparse_sizes[1] * self.row() + self._col
return bool((idx[1:] > idx[:-1]).all()) return bool((idx[1:] > idx[:-1]).all())
def coalesce(self, reduce: str = "add"): def coalesce(self, reduce: str = "add"):
idx = self._col.new_full((self._col.numel() + 1, ), -1) idx = self._col.new_full((self._col.numel() + 1, ), -1)
idx[1:] = self._sparse_size[1] * self.row() + self._col idx[1:] = self._sparse_sizes[1] * self.row() + self._col
mask = idx[1:] > idx[:-1] mask = idx[1:] > idx[:-1]
if mask.all(): # Skip if indices are already coalesced. if mask.all(): # Skip if indices are already coalesced.
...@@ -394,7 +360,7 @@ class SparseStorage(object): ...@@ -394,7 +360,7 @@ class SparseStorage(object):
value = value[0] if isinstance(value, tuple) else value value = value[0] if isinstance(value, tuple) else value
return SparseStorage(row=row, rowptr=None, col=col, value=value, return SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_size=self._sparse_size, rowcount=None, sparse_sizes=self._sparse_sizes, rowcount=None,
colptr=None, colcount=None, csr2csc=None, colptr=None, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True) csc2csr=None, is_sorted=True)
...@@ -418,7 +384,8 @@ class SparseStorage(object): ...@@ -418,7 +384,8 @@ class SparseStorage(object):
def copy(self): def copy(self):
return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col, return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
value=self._value, sparse_size=self._sparse_size, value=self._value,
sparse_sizes=self._sparse_sizes,
rowcount=self._rowcount, colptr=self._colptr, rowcount=self._rowcount, colptr=self._colptr,
colcount=self._colcount, csr2csc=self._csr2csc, colcount=self._colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True) csc2csr=self._csc2csr, is_sorted=True)
...@@ -430,6 +397,7 @@ class SparseStorage(object): ...@@ -430,6 +397,7 @@ class SparseStorage(object):
rowptr = self._rowptr rowptr = self._rowptr
if rowptr is not None: if rowptr is not None:
rowptr = rowptr.clone() rowptr = rowptr.clone()
col = self._col.clone()
value = self._value value = self._value
if value is not None: if value is not None:
value = value.clone() value = value.clone()
...@@ -448,8 +416,178 @@ class SparseStorage(object): ...@@ -448,8 +416,178 @@ class SparseStorage(object):
csc2csr = self._csc2csr csc2csr = self._csc2csr
if csc2csr is not None: if csc2csr is not None:
csc2csr = csc2csr.clone() csc2csr = csc2csr.clone()
return SparseStorage(row=row, rowptr=rowptr, col=self._col.clone(), return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
value=value, sparse_size=self._sparse_size, sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True)
def type_as(self, tensor: torch.Tensor):
    """Return a storage whose value dtype matches that of ``tensor``.

    Fix: the original signature read ``tensor=torch.Tensor`` — the class
    object as a *default value* rather than a type annotation — so calling
    it without an argument compared ``value.dtype`` against an attribute
    descriptor. ``tensor`` is now a required, annotated parameter.

    Returns ``self`` unchanged when there is no value or the dtypes already
    agree; otherwise builds a new storage via ``set_value`` (layout ``'coo'``
    keeps the value aligned with the existing (row, col) ordering).
    """
    value = self._value
    if value is None:
        return self
    if tensor.dtype == value.dtype:
        return self
    # `Tensor.type_as` casts the value to `tensor`'s dtype (and device).
    return self.set_value(value.type_as(tensor), layout='coo')
def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
    """Return this storage with every cached tensor on ``tensor``'s device.

    A no-op (returns ``self``) when the storage already lives on that
    device. Otherwise each non-``None`` cache is transferred and repackaged
    into a new ``SparseStorage``; a pure device move preserves ordering,
    hence ``is_sorted=True``.
    """
    target = tensor.device
    if target == self._col.device:
        # Already co-located with `tensor` -- nothing to transfer.
        return self
    row, rowptr = self._row, self._rowptr
    if row is not None:
        row = row.to(target, non_blocking=non_blocking)
    if rowptr is not None:
        rowptr = rowptr.to(target, non_blocking=non_blocking)
    value = self._value
    if value is not None:
        value = value.to(target, non_blocking=non_blocking)
    rowcount, colptr = self._rowcount, self._colptr
    if rowcount is not None:
        rowcount = rowcount.to(target, non_blocking=non_blocking)
    if colptr is not None:
        colptr = colptr.to(target, non_blocking=non_blocking)
    colcount = self._colcount
    if colcount is not None:
        colcount = colcount.to(target, non_blocking=non_blocking)
    csr2csc, csc2csr = self._csr2csc, self._csc2csr
    if csr2csc is not None:
        csr2csc = csr2csc.to(target, non_blocking=non_blocking)
    if csc2csr is not None:
        csc2csr = csc2csr.to(target, non_blocking=non_blocking)
    return SparseStorage(row=row, rowptr=rowptr,
                         col=self._col.to(target,
                                          non_blocking=non_blocking),
                         value=value, sparse_sizes=self._sparse_sizes,
                         rowcount=rowcount, colptr=colptr,
                         colcount=colcount, csr2csc=csr2csc,
                         csc2csr=csc2csr, is_sorted=True)
def pin_memory(self):
    """Return a copy of this storage whose cached tensors live in pinned
    (page-locked) host memory; ``None`` caches are passed through as-is.
    """
    row, rowptr = self._row, self._rowptr
    row = row if row is None else row.pin_memory()
    rowptr = rowptr if rowptr is None else rowptr.pin_memory()
    value = self._value
    value = value if value is None else value.pin_memory()
    rowcount, colptr = self._rowcount, self._colptr
    rowcount = rowcount if rowcount is None else rowcount.pin_memory()
    colptr = colptr if colptr is None else colptr.pin_memory()
    colcount = self._colcount
    colcount = colcount if colcount is None else colcount.pin_memory()
    csr2csc, csc2csr = self._csr2csc, self._csc2csr
    csr2csc = csr2csc if csr2csc is None else csr2csc.pin_memory()
    csc2csr = csc2csr if csc2csr is None else csc2csr.pin_memory()
    # Pinning does not reorder anything, so the result stays sorted.
    return SparseStorage(row=row, rowptr=rowptr,
                         col=self._col.pin_memory(), value=value,
                         sparse_sizes=self._sparse_sizes,
                         rowcount=rowcount, colptr=colptr,
                         colcount=colcount, csr2csc=csr2csc,
                         csc2csr=csc2csr, is_sorted=True)
def is_pinned(self) -> bool:
    """Return ``True`` iff every cached tensor lives in pinned host memory.

    ``None`` caches are skipped. Fix: the original reset the accumulator at
    the ``col`` check (``is_pinned = self._col.is_pinned()``), silently
    discarding the results already gathered for ``row`` and ``rowptr``; the
    sibling ``is_shared`` accumulates correctly, and this now matches it.
    """
    is_pinned = True
    row = self._row
    if row is not None:
        is_pinned = is_pinned and row.is_pinned()
    rowptr = self._rowptr
    if rowptr is not None:
        is_pinned = is_pinned and rowptr.is_pinned()
    # Keep accumulating -- `col` is always present, never reset the flag.
    is_pinned = is_pinned and self._col.is_pinned()
    value = self._value
    if value is not None:
        is_pinned = is_pinned and value.is_pinned()
    rowcount = self._rowcount
    if rowcount is not None:
        is_pinned = is_pinned and rowcount.is_pinned()
    colptr = self._colptr
    if colptr is not None:
        is_pinned = is_pinned and colptr.is_pinned()
    colcount = self._colcount
    if colcount is not None:
        is_pinned = is_pinned and colcount.is_pinned()
    csr2csc = self._csr2csc
    if csr2csc is not None:
        is_pinned = is_pinned and csr2csc.is_pinned()
    csc2csr = self._csc2csr
    if csc2csr is not None:
        is_pinned = is_pinned and csc2csr.is_pinned()
    return is_pinned
@torch.jit.ignore
def share_memory_(self) -> 'SparseStorage':
    """Move every cached tensor into shared memory, in place.

    ``None`` caches are skipped. Two fixes: (1) the return annotation
    referenced ``SparseStorage`` unquoted, which is evaluated at definition
    time and raises ``NameError`` when the name is not yet bound — it is
    now a string forward reference; (2) the function promised a
    ``SparseStorage`` but fell off the end returning ``None`` — it now
    returns ``self``, matching the ``torch.Tensor.share_memory_``
    convention for in-place operations.
    """
    row = self._row
    if row is not None:
        row.share_memory_()
    rowptr = self._rowptr
    if rowptr is not None:
        rowptr.share_memory_()
    self._col.share_memory_()
    value = self._value
    if value is not None:
        value.share_memory_()
    rowcount = self._rowcount
    if rowcount is not None:
        rowcount.share_memory_()
    colptr = self._colptr
    if colptr is not None:
        colptr.share_memory_()
    colcount = self._colcount
    if colcount is not None:
        colcount.share_memory_()
    csr2csc = self._csr2csc
    if csr2csc is not None:
        csr2csc.share_memory_()
    csc2csr = self._csc2csr
    if csc2csr is not None:
        csc2csr.share_memory_()
    return self
@torch.jit.ignore
def is_shared(self) -> bool:
    """Return ``True`` iff every cached tensor of this storage lives in
    shared memory; ``None`` caches are skipped."""
    # `col` is the only mandatory tensor; all others may be absent.
    shared = self._col.is_shared()
    optional_caches = [self._row, self._rowptr, self._value,
                       self._rowcount, self._colptr, self._colcount,
                       self._csr2csc, self._csc2csr]
    for cache in optional_caches:
        if cache is not None:
            shared = shared and cache.is_shared()
    return shared
# Bind the `@torch.jit.ignore`-decorated helpers onto the class after its
# creation — presumably so TorchScript compilation of `SparseStorage`
# skips their bodies (both are marked jit-ignored above); confirm against
# how the class is scripted.
SparseStorage.share_memory_ = share_memory_
SparseStorage.is_shared = is_shared
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment