Commit 64b0ae30 authored by rusty1s

storage done

parent f59fe649
 import warnings
-from typing import Optional, List, Dict, Any
+from typing import Optional, List, Dict, Union, Any

 import torch
 from torch_scatter import segment_csr, scatter_add

-from torch_sparse.utils import Final
+from torch_sparse.utils import Final, is_scalar

-# __cache__ = {'enabled': True}
+__cache__ = {'enabled': True}

-# def is_cache_enabled():
-#     return __cache__['enabled']
+def is_cache_enabled():
+    return __cache__['enabled']
-# def set_cache_enabled(mode):
-#     __cache__['enabled'] = mode
+def set_cache_enabled(mode):
+    __cache__['enabled'] = mode


-# class no_cache(object):
-#     def __enter__(self):
-#         self.prev = is_cache_enabled()
-#         set_cache_enabled(False)
-#     def __exit__(self, *args):
-#         set_cache_enabled(self.prev)
-#         return False
-#     def __call__(self, func):
-#         def decorate_no_cache(*args, **kwargs):
-#             with self:
-#                 return func(*args, **kwargs)
-#         return decorate_no_cache
+class no_cache(object):
+    def __enter__(self):
+        self.prev = is_cache_enabled()
+        set_cache_enabled(False)
+
+    def __exit__(self, *args):
+        set_cache_enabled(self.prev)
+        return False
+
+    def __call__(self, func):
+        def decorate_no_cache(*args, **kwargs):
+            with self:
+                return func(*args, **kwargs)
+        return decorate_no_cache
-# class cached_property(object):
-#     def __init__(self, func):
-#         self.func = func
-#     def __get__(self, obj, cls):
-#         value = getattr(obj, f'_{self.func.__name__}', None)
-#         if value is None:
-#             value = self.func(obj)
-#             if is_cache_enabled():
-#                 setattr(obj, f'_{self.func.__name__}', value)
-#         return value
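Reviewer note: `no_cache` doubles as a context manager and as a decorator. A minimal usage sketch, assuming `SparseStorage` and `no_cache` are importable from this module (`torch_sparse.storage`); the constructor keywords mirror the ones used throughout this diff:

import torch
from torch_sparse.storage import SparseStorage, no_cache

row, col = torch.tensor([0, 1]), torch.tensor([1, 0])
storage = SparseStorage(row=row, col=col, sparse_size=[2, 2])

# As a context manager: derived tensors built inside are not memoized.
with no_cache():
    storage.fill_value(1.0, dtype=torch.float)

# As a decorator: the wrapped call always runs with caching disabled.
@no_cache()
def build(s):
    return s.fill_value(0.0, dtype=torch.float)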
 def optional(func, src):
@@ -53,12 +37,12 @@ def optional(func, src):
 layouts: Final[List[str]] = ['coo', 'csr', 'csc']


-def get_layout(layout=None):
+def get_layout(layout: Optional[str] = None) -> str:
     if layout is None:
         layout = 'coo'
         warnings.warn('`layout` argument unset, using default layout '
                       '"coo". This may lead to unexpected behaviour.')
-    assert layout in layouts
+    assert layout == 'coo' or layout == 'csr' or layout == 'csc'
     return layout
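The chained equality check replaces the old `layout in layouts` membership test, presumably to keep the function scriptable; behaviour is unchanged:

assert get_layout('csc') == 'csc'
assert get_layout() == 'coo'  # emits the '`layout` argument unset' warning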
@@ -237,78 +221,79 @@ class SparseStorage(object):
     def value(self) -> Optional[torch.Tensor]:
         return self._value
-#     def set_value_(self, value, layout=None, dtype=None):
-#         if isinstance(value, int) or isinstance(value, float):
-#             value = torch.full((self.col.numel(), ), value, dtype=dtype,
-#                                device=self.col.device)
-#         elif torch.is_tensor(value) and get_layout(layout) == 'csc':
-#             value = value[self.csc2csr]
-#         if torch.is_tensor(value):
-#             value = value if dtype is None else value.to(dtype)
-#             assert value.device == self.col.device
-#             assert value.size(0) == self.col.numel()
-#         self._value = value
-#         return self
+    def set_value_(self, value: Optional[torch.Tensor],
+                   layout: Optional[str] = None):
+        if value is not None:
+            if get_layout(layout) == 'csc':
+                value = value[self.csc2csr()]
+            value = value.contiguous()
+            assert value.device == self._col.device
+            assert value.size(0) == self._col.numel()
+        self._value = value
+        return self
-#     def set_value(self, value, layout=None, dtype=None):
-#         if isinstance(value, int) or isinstance(value, float):
-#             value = torch.full((self.col.numel(), ), value, dtype=dtype,
-#                                device=self.col.device)
-#         elif torch.is_tensor(value) and get_layout(layout) == 'csc':
-#             value = value[self.csc2csr]
-#         if torch.is_tensor(value):
-#             value = value if dtype is None else value.to(dtype)
-#             assert value.device == self.col.device
-#             assert value.size(0) == self.col.numel()
-#         return self.__class__(row=self._row, rowptr=self._rowptr, col=self.col,
-#                               value=value, sparse_size=self._sparse_size,
-#                               rowcount=self._rowcount, colptr=self._colptr,
-#                               colcount=self._colcount, csr2csc=self._csr2csc,
-#                               csc2csr=self._csc2csr, is_sorted=True)
+    def set_value(self, value: Optional[torch.Tensor],
+                  layout: Optional[str] = None):
+        if value is not None:
+            if get_layout(layout) == 'csc':
+                value = value[self.csc2csr()]
+            value = value.contiguous()
+            assert value.device == self._col.device
+            assert value.size(0) == self._col.numel()
+
+        return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
+                             value=value, sparse_size=self._sparse_size,
+                             rowcount=self._rowcount, colptr=self._colptr,
+                             colcount=self._colcount, csr2csc=self._csr2csc,
+                             csc2csr=self._csc2csr, is_sorted=True)
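A short sketch of the `layout` semantics, with `storage` and `nnz` (its number of non-zero entries) assumed from the sketch above: values handed over in CSC order are permuted into CSR order through `csc2csr` before being stored.

value = torch.randn(nnz)

s_csr = storage.set_value(value, layout='csr')  # taken as-is (row-major order)
s_csc = storage.set_value(value, layout='csc')  # reordered via storage.csc2csr()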
+    def fill_value_(self, fill_value: float,
+                    dtype: Optional[torch.dtype] = None):
+        value = torch.empty(self._col.numel(), dtype=dtype,
+                            device=self._col.device)
+        return self.set_value_(value.fill_(fill_value), layout='csr')
+    def fill_value(self, fill_value: float,
+                   dtype: Optional[torch.dtype] = None):
+        value = torch.empty(self._col.numel(), dtype=dtype,
+                            device=self._col.device)
+        return self.set_value(value.fill_(fill_value), layout='csr')
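`fill_value`/`fill_value_` are thin convenience wrappers that materialize a constant value tensor of length nnz; continuing the sketch above:

ones = storage.fill_value(1.0, dtype=torch.float)
val = ones.value()
assert val is not None and bool((val == 1.0).all())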
     def sparse_size(self) -> List[int]:
         return self._sparse_size
-#     def sparse_resize(self, *sizes):
-#         old_sparse_size, nnz = self.sparse_size, self.col.numel()
-#         diff_0 = sizes[0] - old_sparse_size[0]
-#         rowcount, rowptr = self._rowcount, self._rowptr
-#         if diff_0 > 0:
-#             if rowptr is not None:
-#                 rowptr = torch.cat([rowptr, rowptr.new_full((diff_0, ), nnz)])
-#             if rowcount is not None:
-#                 rowcount = torch.cat([rowcount, rowcount.new_zeros(diff_0)])
-#         else:
-#             if rowptr is not None:
-#                 rowptr = rowptr[:-diff_0]
-#             if rowcount is not None:
-#                 rowcount = rowcount[:-diff_0]
-#         diff_1 = sizes[1] - old_sparse_size[1]
-#         colcount, colptr = self._colcount, self._colptr
-#         if diff_1 > 0:
-#             if colptr is not None:
-#                 colptr = torch.cat([colptr, colptr.new_full((diff_1, ), nnz)])
-#             if colcount is not None:
-#                 colcount = torch.cat([colcount, colcount.new_zeros(diff_1)])
-#         else:
-#             if colptr is not None:
-#                 colptr = colptr[:-diff_1]
-#             if colcount is not None:
-#                 colcount = colcount[:-diff_1]
-#         return self.__class__(row=self._row, rowptr=rowptr, col=self.col,
-#                               value=self.value, sparse_size=sizes,
-#                               rowcount=rowcount, colptr=colptr,
-#                               colcount=colcount, csr2csc=self._csr2csc,
-#                               csc2csr=self._csc2csr, is_sorted=True)
+    def sparse_resize(self, sparse_size: List[int]):
+        assert len(sparse_size) == 2
+        old_sparse_size, nnz = self._sparse_size, self._col.numel()
+
+        diff_0 = sparse_size[0] - old_sparse_size[0]
+        rowcount, rowptr = self._rowcount, self._rowptr
+        if diff_0 > 0:
+            # Growing: appended rows are empty, so `rowptr` is padded with
+            # `nnz` and `rowcount` with zeros.
+            if rowptr is not None:
+                rowptr = torch.cat([rowptr, rowptr.new_full((diff_0, ), nnz)])
+            if rowcount is not None:
+                rowcount = torch.cat([rowcount, rowcount.new_zeros(diff_0)])
+        elif diff_0 < 0:
+            # Shrinking: drop the trailing entries (`diff_0` is negative).
+            if rowptr is not None:
+                rowptr = rowptr[:diff_0]
+            if rowcount is not None:
+                rowcount = rowcount[:diff_0]
+
+        diff_1 = sparse_size[1] - old_sparse_size[1]
+        colcount, colptr = self._colcount, self._colptr
+        if diff_1 > 0:
+            if colptr is not None:
+                colptr = torch.cat([colptr, colptr.new_full((diff_1, ), nnz)])
+            if colcount is not None:
+                colcount = torch.cat([colcount, colcount.new_zeros(diff_1)])
+        elif diff_1 < 0:
+            if colptr is not None:
+                colptr = colptr[:diff_1]
+            if colcount is not None:
+                colcount = colcount[:diff_1]
+
+        return SparseStorage(row=self._row, rowptr=rowptr, col=self._col,
+                             value=self._value, sparse_size=sparse_size,
+                             rowcount=rowcount, colptr=colptr,
+                             colcount=colcount, csr2csc=self._csr2csc,
+                             csc2csr=self._csc2csr, is_sorted=True)
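To illustrate the padding logic with a concrete sketch (same assumed constructor and imports as the sketch above): growing a (3, 3) storage to (5, 3) appends two `nnz` entries to `rowptr` and two zeros to `rowcount`, since the new rows are empty; shrinking slices them back off.

row, col = torch.tensor([0, 1, 2]), torch.tensor([2, 1, 0])
storage = SparseStorage(row=row, col=col, sparse_size=[3, 3])

bigger = storage.sparse_resize([5, 3])   # rows 3 and 4 hold no entries
assert bigger.sparse_size() == [5, 3]

smaller = bigger.sparse_resize([3, 3])   # trailing pointers/counts are sliced off
assert smaller.sparse_size() == [3, 3]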
     def has_rowcount(self) -> bool:
         return self._rowcount is not None
@@ -431,7 +416,7 @@ class SparseStorage(object):
         self._csc2csr = None
         return self
-    def __copy__(self):
+    def copy(self):
         return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
                              value=self._value, sparse_size=self._sparse_size,
                              rowcount=self._rowcount, colptr=self._colptr,
@@ -468,6 +453,3 @@ class SparseStorage(object):
                              rowcount=rowcount, colptr=colptr,
                              colcount=colcount, csr2csc=csr2csc,
                              csc2csr=csc2csr, is_sorted=True)

-    def __deepcopy__(self, memo: Dict[str, Any]):
-        return self.clone()
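With the dunders gone (presumably because TorchScript cannot compile them), copying becomes an explicit method call; `clone()` already exists on the class and takes over the deep-copy role, as the removed `__deepcopy__` shows:

shallow = storage.copy()   # new wrapper over the same underlying tensors
deep = storage.clone()     # deep copy, as the removed __deepcopy__ did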