Commit 76bf1e8a authored by rusty1s's avatar rusty1s
Browse files

sparse tensor update

parent b5624cb8
......@@ -21,10 +21,17 @@ class cached_property(object):
class SparseStorage(object):
layouts = ['coo', 'csr', 'csc']
cache_keys = ['rowptr', 'colptr', 'csr_to_csc', 'csc_to_csr']
def __init__(self, index, value=None, sparse_size=None, rowptr=None,
colptr=None, csr_to_csc=None, csc_to_csr=None,
def __init__(self,
index,
value=None,
sparse_size=None,
rowptr=None,
colptr=None,
csr_to_csc=None,
csc_to_csr=None,
is_sorted=False):
assert index.dtype == torch.long
......@@ -90,7 +97,6 @@ class SparseStorage(object):
def col(self):
return self._index[1]
@property
def has_value(self):
return self._value is not None
......@@ -103,7 +109,7 @@ class SparseStorage(object):
layout = 'coo'
warnings.warn('`layout` argument unset, using default layout '
'"coo". This may lead to unexpected behaviour.')
assert layout in ['coo', 'csr', 'csc']
assert layout in self.layouts
assert value.device == self._index.device
assert value.size(0) == self._index.size(1)
if value is not None and layout == 'csc':
......@@ -115,7 +121,7 @@ class SparseStorage(object):
layout = 'coo'
warnings.warn('`layout` argument unset, using default layout '
'"coo". This may lead to unexpected behaviour.')
assert layout in ['coo', 'csr', 'csc']
assert layout in self.layouts
assert value.device == self._index.device
assert value.size(0) == self._index.size(1)
if value is not None and layout == 'csc':
......@@ -153,7 +159,13 @@ class SparseStorage(object):
def csc_to_csr(self):
return self.csr_to_csc.argsort()
def compute_cache_(self, *args):
def is_coalesced(self):
raise NotImplementedError
def coalesce(self):
raise NotImplementedError
def fill_cache_(self, *args):
for arg in args or self.cache_keys:
getattr(self, arg)
return self
......@@ -163,18 +175,15 @@ class SparseStorage(object):
setattr(self, f'_{arg}', None)
return self
def __copy__(self):
return self.apply(lambda x: x)
def clone(self):
return self.apply(lambda x: x.clone())
def __copy__(self):
return self.clone()
def __deepcopy__(self, memo):
memo = memo.setdefault('SparseStorage', {})
if self._cdata in memo:
return memo[self._cdata]
new_storage = self.clone()
memo[self._cdata] = new_storage
memo[id(self)] = new_storage
return new_storage
def apply_value_(self, func):
......@@ -198,6 +207,7 @@ class SparseStorage(object):
self._value = optional(func, self._value)
for key in self.cache_keys:
setattr(self, f'_{key}', optional(func, getattr(self, f'_{key}')))
return self
def apply(self, func):
return self.__class__(
......@@ -211,6 +221,16 @@ class SparseStorage(object):
is_sorted=True,
)
def map(self, func):
data = [func(self.index)]
if self.has_value():
data += [func(self.value)]
data += [
func(getattr(self, f'_{key}')) for key in self.cache_keys
if getattr(self, f'_{key}')
]
return data
if __name__ == '__main__':
from torch_geometric.datasets import Reddit, Planetoid # noqa
......@@ -225,18 +245,19 @@ if __name__ == '__main__':
storage = SparseStorage(edge_index, is_sorted=True)
t = time.perf_counter()
storage.compute_cache_()
storage.fill_cache_()
print(time.perf_counter() - t)
t = time.perf_counter()
storage.clear_cache_()
storage.compute_cache_()
storage.fill_cache_()
print(time.perf_counter() - t)
print(storage)
storage = storage.clone()
print(storage)
# storage = copy.copy(storage)
# storage = storage.clone()
# print(storage)
# storage = copy.deepcopy(storage)
# print(storage)
storage.compute_cache_()
storage = copy.copy(storage)
print(storage)
print(id(storage))
storage = copy.deepcopy(storage)
print(storage)
storage.fill_cache_()
storage.clear_cache_()
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment