Commit 64b0ae30 authored by rusty1s's avatar rusty1s
Browse files

storage done

parent f59fe649
import warnings import warnings
from typing import Optional, List, Dict, Any from typing import Optional, List, Dict, Union, Any
import torch import torch
from torch_scatter import segment_csr, scatter_add from torch_scatter import segment_csr, scatter_add
from torch_sparse.utils import Final from torch_sparse.utils import Final, is_scalar
__cache__ = {'enabled': True} # __cache__ = {'enabled': True}
# def is_cache_enabled():
# return __cache__['enabled']
def is_cache_enabled(): # def set_cache_enabled(mode):
return __cache__['enabled'] # __cache__['enabled'] = mode
# class no_cache(object):
# def __enter__(self):
# self.prev = is_cache_enabled()
# set_cache_enabled(False)
def set_cache_enabled(mode): # def __exit__(self, *args):
__cache__['enabled'] = mode # set_cache_enabled(self.prev)
# return False
# def __call__(self, func):
# def decorate_no_cache(*args, **kwargs):
# with self:
# return func(*args, **kwargs)
class no_cache(object): # return decorate_no_cache
def __enter__(self):
self.prev = is_cache_enabled()
set_cache_enabled(False)
def __exit__(self, *args):
set_cache_enabled(self.prev)
return False
def __call__(self, func):
def decorate_no_cache(*args, **kwargs):
with self:
return func(*args, **kwargs)
return decorate_no_cache
# class cached_property(object):
# def __init__(self, func):
# self.func = func
# def __get__(self, obj, cls):
# value = getattr(obj, f'_{self.func.__name__}', None)
# if value is None:
# value = self.func(obj)
# if is_cache_enabled():
# setattr(obj, f'_{self.func.__name__}', value)
# return value
def optional(func, src): def optional(func, src):
...@@ -53,12 +37,12 @@ def optional(func, src): ...@@ -53,12 +37,12 @@ def optional(func, src):
layouts: Final[List[str]] = ['coo', 'csr', 'csc'] layouts: Final[List[str]] = ['coo', 'csr', 'csc']
def get_layout(layout=None): def get_layout(layout: Optional[str] = None) -> str:
if layout is None: if layout is None:
layout = 'coo' layout = 'coo'
warnings.warn('`layout` argument unset, using default layout ' warnings.warn('`layout` argument unset, using default layout '
'"coo". This may lead to unexpected behaviour.') '"coo". This may lead to unexpected behaviour.')
assert layout in layouts assert layout == 'coo' or layout == 'csr' or layout == 'csc'
return layout return layout
...@@ -237,78 +221,79 @@ class SparseStorage(object): ...@@ -237,78 +221,79 @@ class SparseStorage(object):
def value(self) -> Optional[torch.Tensor]: def value(self) -> Optional[torch.Tensor]:
return self._value return self._value
# def set_value_(self, value, layout=None, dtype=None): def set_value_(self, value: Optional[torch.Tensor],
# if isinstance(value, int) or isinstance(value, float): layout: Optional[str] = None):
# value = torch.full((self.col.numel(), ), dtype=dtype, if value is not None:
# device=self.col.device) if get_layout(layout) == 'csc2csr':
value = value[self.csc2csr()]
# elif torch.is_tensor(value) and get_layout(layout) == 'csc': value = value.contiguous()
# value = value[self.csc2csr] assert value.device == self._col.device
assert value.size(0) == self._col.numel()
# if torch.is_tensor(value):
# value = value if dtype is None else value.to(dtype)
# assert value.device == self.col.device
# assert value.size(0) == self.col.numel()
# self._value = value self._value = value
# return self return self
# def set_value(self, value, layout=None, dtype=None): def set_value(self, value: Optional[torch.Tensor],
# if isinstance(value, int) or isinstance(value, float): layout: Optional[str] = None):
# value = torch.full((self.col.numel(), ), dtype=dtype, if value is not None:
# device=self.col.device) if get_layout(layout) == 'csc2csr':
value = value[self.csc2csr()]
value = value.contiguous()
assert value.device == self._col.device
assert value.size(0) == self._col.numel()
# elif torch.is_tensor(value) and get_layout(layout) == 'csc': return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
# value = value[self.csc2csr] value=value, sparse_size=self._sparse_size,
rowcount=self._rowcount, colptr=self._colptr,
colcount=self._colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True)
# if torch.is_tensor(value): def fill_value_(self, fill_value: float, dtype=Optional[torch.dtype]):
# value = value if dtype is None else value.to(dtype) value = torch.empty(self._col.numel(), dtype, device=self._col.device)
# assert value.device == self.col.device return self.set_value_(value.fill_(fill_value), layout='csr')
# assert value.size(0) == self.col.numel()
# return self.__class__(row=self._row, rowptr=self._rowptr, col=self.col, def fill_value(self, fill_value: float, dtype=Optional[torch.dtype]):
# value=value, sparse_size=self._sparse_size, value = torch.empty(self._col.numel(), dtype, device=self._col.device)
# rowcount=self._rowcount, colptr=self._colptr, return self.set_value(value.fill_(fill_value), layout='csr')
# colcount=self._colcount, csr2csc=self._csr2csc,
# csc2csr=self._csc2csr, is_sorted=True)
def sparse_size(self) -> List[int]: def sparse_size(self) -> List[int]:
return self._sparse_size return self._sparse_size
# def sparse_resize(self, *sizes): def sparse_resize(self, sparse_size: List[int]):
# old_sparse_size, nnz = self.sparse_size, self.col.numel() assert len(sparse_size) == 2
old_sparse_size, nnz = self._sparse_size, self._col.numel()
# diff_0 = sizes[0] - old_sparse_size[0]
# rowcount, rowptr = self._rowcount, self._rowptr diff_0 = sparse_size[0] - old_sparse_size[0]
# if diff_0 > 0: rowcount, rowptr = self._rowcount, self._rowptr
# if rowptr is not None: if diff_0 > 0:
# rowptr = torch.cat([rowptr, rowptr.new_full((diff_0, ), nnz)]) if rowptr is not None:
# if rowcount is not None: rowptr = torch.cat([rowptr, rowptr.new_full((diff_0, ), nnz)])
# rowcount = torch.cat([rowcount, rowcount.new_zeros(diff_0)]) if rowcount is not None:
# else: rowcount = torch.cat([rowcount, rowcount.new_zeros(diff_0)])
# if rowptr is not None: else:
# rowptr = rowptr[:-diff_0] if rowptr is not None:
# if rowcount is not None: rowptr = rowptr[:-diff_0]
# rowcount = rowcount[:-diff_0] if rowcount is not None:
rowcount = rowcount[:-diff_0]
# diff_1 = sizes[1] - old_sparse_size[1]
# colcount, colptr = self._colcount, self._colptr diff_1 = sparse_size[1] - old_sparse_size[1]
# if diff_1 > 0: colcount, colptr = self._colcount, self._colptr
# if colptr is not None: if diff_1 > 0:
# colptr = torch.cat([colptr, colptr.new_full((diff_1, ), nnz)]) if colptr is not None:
# if colcount is not None: colptr = torch.cat([colptr, colptr.new_full((diff_1, ), nnz)])
# colcount = torch.cat([colcount, colcount.new_zeros(diff_1)]) if colcount is not None:
# else: colcount = torch.cat([colcount, colcount.new_zeros(diff_1)])
# if colptr is not None: else:
# colptr = colptr[:-diff_1] if colptr is not None:
# if colcount is not None: colptr = colptr[:-diff_1]
# colcount = colcount[:-diff_1] if colcount is not None:
colcount = colcount[:-diff_1]
# return self.__class__(row=self._row, rowptr=rowptr, col=self.col,
# value=self.value, sparse_size=sizes, return SparseStorage(row=self._row, rowptr=rowptr, col=self._col,
# rowcount=rowcount, colptr=colptr, value=self._value, sparse_size=sparse_size,
# colcount=colcount, csr2csc=self._csr2csc, rowcount=rowcount, colptr=colptr,
# csc2csr=self._csc2csr, is_sorted=True) colcount=colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True)
def has_rowcount(self) -> bool: def has_rowcount(self) -> bool:
return self._rowcount is not None return self._rowcount is not None
...@@ -431,7 +416,7 @@ class SparseStorage(object): ...@@ -431,7 +416,7 @@ class SparseStorage(object):
self._csc2csr = None self._csc2csr = None
return self return self
def __copy__(self): def copy(self):
return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col, return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
value=self._value, sparse_size=self._sparse_size, value=self._value, sparse_size=self._sparse_size,
rowcount=self._rowcount, colptr=self._colptr, rowcount=self._rowcount, colptr=self._colptr,
...@@ -468,6 +453,3 @@ class SparseStorage(object): ...@@ -468,6 +453,3 @@ class SparseStorage(object):
rowcount=rowcount, colptr=colptr, rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc, colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True) csc2csr=csc2csr, is_sorted=True)
def __deepcopy__(self, memo: Dict[str, Any]):
return self.clone()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment