Commit 76bf1e8a authored by rusty1s's avatar rusty1s
Browse files

sparse tensor update

parent b5624cb8
...@@ -21,10 +21,17 @@ class cached_property(object): ...@@ -21,10 +21,17 @@ class cached_property(object):
class SparseStorage(object): class SparseStorage(object):
layouts = ['coo', 'csr', 'csc']
cache_keys = ['rowptr', 'colptr', 'csr_to_csc', 'csc_to_csr'] cache_keys = ['rowptr', 'colptr', 'csr_to_csc', 'csc_to_csr']
def __init__(self, index, value=None, sparse_size=None, rowptr=None, def __init__(self,
colptr=None, csr_to_csc=None, csc_to_csr=None, index,
value=None,
sparse_size=None,
rowptr=None,
colptr=None,
csr_to_csc=None,
csc_to_csr=None,
is_sorted=False): is_sorted=False):
assert index.dtype == torch.long assert index.dtype == torch.long
...@@ -90,7 +97,6 @@ class SparseStorage(object): ...@@ -90,7 +97,6 @@ class SparseStorage(object):
def col(self): def col(self):
return self._index[1] return self._index[1]
@property
def has_value(self): def has_value(self):
return self._value is not None return self._value is not None
...@@ -103,7 +109,7 @@ class SparseStorage(object): ...@@ -103,7 +109,7 @@ class SparseStorage(object):
layout = 'coo' layout = 'coo'
warnings.warn('`layout` argument unset, using default layout ' warnings.warn('`layout` argument unset, using default layout '
'"coo". This may lead to unexpected behaviour.') '"coo". This may lead to unexpected behaviour.')
assert layout in ['coo', 'csr', 'csc'] assert layout in self.layouts
assert value.device == self._index.device assert value.device == self._index.device
assert value.size(0) == self._index.size(1) assert value.size(0) == self._index.size(1)
if value is not None and layout == 'csc': if value is not None and layout == 'csc':
...@@ -115,7 +121,7 @@ class SparseStorage(object): ...@@ -115,7 +121,7 @@ class SparseStorage(object):
layout = 'coo' layout = 'coo'
warnings.warn('`layout` argument unset, using default layout ' warnings.warn('`layout` argument unset, using default layout '
'"coo". This may lead to unexpected behaviour.') '"coo". This may lead to unexpected behaviour.')
assert layout in ['coo', 'csr', 'csc'] assert layout in self.layouts
assert value.device == self._index.device assert value.device == self._index.device
assert value.size(0) == self._index.size(1) assert value.size(0) == self._index.size(1)
if value is not None and layout == 'csc': if value is not None and layout == 'csc':
...@@ -153,7 +159,13 @@ class SparseStorage(object): ...@@ -153,7 +159,13 @@ class SparseStorage(object):
def csc_to_csr(self): def csc_to_csr(self):
return self.csr_to_csc.argsort() return self.csr_to_csc.argsort()
def compute_cache_(self, *args): def is_coalesced(self):
raise NotImplementedError
def coalesce(self):
raise NotImplementedError
def fill_cache_(self, *args):
for arg in args or self.cache_keys: for arg in args or self.cache_keys:
getattr(self, arg) getattr(self, arg)
return self return self
...@@ -163,18 +175,15 @@ class SparseStorage(object): ...@@ -163,18 +175,15 @@ class SparseStorage(object):
setattr(self, f'_{arg}', None) setattr(self, f'_{arg}', None)
return self return self
def __copy__(self):
return self.apply(lambda x: x)
def clone(self): def clone(self):
return self.apply(lambda x: x.clone()) return self.apply(lambda x: x.clone())
def __copy__(self):
return self.clone()
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
memo = memo.setdefault('SparseStorage', {})
if self._cdata in memo:
return memo[self._cdata]
new_storage = self.clone() new_storage = self.clone()
memo[self._cdata] = new_storage memo[id(self)] = new_storage
return new_storage return new_storage
def apply_value_(self, func): def apply_value_(self, func):
...@@ -198,6 +207,7 @@ class SparseStorage(object): ...@@ -198,6 +207,7 @@ class SparseStorage(object):
self._value = optional(func, self._value) self._value = optional(func, self._value)
for key in self.cache_keys: for key in self.cache_keys:
setattr(self, f'_{key}', optional(func, getattr(self, f'_{key}'))) setattr(self, f'_{key}', optional(func, getattr(self, f'_{key}')))
return self
def apply(self, func): def apply(self, func):
return self.__class__( return self.__class__(
...@@ -211,6 +221,16 @@ class SparseStorage(object): ...@@ -211,6 +221,16 @@ class SparseStorage(object):
is_sorted=True, is_sorted=True,
) )
def map(self, func):
data = [func(self.index)]
if self.has_value():
data += [func(self.value)]
data += [
func(getattr(self, f'_{key}')) for key in self.cache_keys
if getattr(self, f'_{key}')
]
return data
if __name__ == '__main__': if __name__ == '__main__':
from torch_geometric.datasets import Reddit, Planetoid # noqa from torch_geometric.datasets import Reddit, Planetoid # noqa
...@@ -225,18 +245,19 @@ if __name__ == '__main__': ...@@ -225,18 +245,19 @@ if __name__ == '__main__':
storage = SparseStorage(edge_index, is_sorted=True) storage = SparseStorage(edge_index, is_sorted=True)
t = time.perf_counter() t = time.perf_counter()
storage.compute_cache_() storage.fill_cache_()
print(time.perf_counter() - t) print(time.perf_counter() - t)
t = time.perf_counter() t = time.perf_counter()
storage.clear_cache_() storage.clear_cache_()
storage.compute_cache_() storage.fill_cache_()
print(time.perf_counter() - t) print(time.perf_counter() - t)
print(storage) print(storage)
storage = storage.clone() # storage = storage.clone()
print(storage)
# storage = copy.copy(storage)
# print(storage) # print(storage)
# storage = copy.deepcopy(storage) storage = copy.copy(storage)
# print(storage) print(storage)
storage.compute_cache_() print(id(storage))
storage = copy.deepcopy(storage)
print(storage)
storage.fill_cache_()
storage.clear_cache_() storage.clear_cache_()
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment