"tests/python/vscode:/vscode.git/clone" did not exist on "4663cb0c91103be590c1bf3f389e5b4fd60bfa31"
Commit b5624cb8 authored by rusty1s's avatar rusty1s
Browse files

test

parent 9971227c
...@@ -163,6 +163,20 @@ class SparseStorage(object): ...@@ -163,6 +163,20 @@ class SparseStorage(object):
setattr(self, f'_{arg}', None) setattr(self, f'_{arg}', None)
return self return self
def clone(self):
return self.apply(lambda x: x.clone())
def __copy__(self):
return self.clone()
def __deepcopy__(self, memo):
memo = memo.setdefault('SparseStorage', {})
if self._cdata in memo:
return memo[self._cdata]
new_storage = self.clone()
memo[self._cdata] = new_storage
return new_storage
def apply_value_(self, func): def apply_value_(self, func):
self._value = optional(func, self._value) self._value = optional(func, self._value)
return self return self
...@@ -199,11 +213,13 @@ class SparseStorage(object): ...@@ -199,11 +213,13 @@ class SparseStorage(object):
if __name__ == '__main__': if __name__ == '__main__':
from torch_geometric.datasets import Reddit # noqa from torch_geometric.datasets import Reddit, Planetoid # noqa
import time # noqa import time # noqa
import copy # noqa
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = 'cuda' if torch.cuda.is_available() else 'cpu'
dataset = Reddit('/tmp/Reddit') # dataset = Reddit('/tmp/Reddit')
dataset = Planetoid('/tmp/Cora', 'Cora')
data = dataset[0].to(device) data = dataset[0].to(device)
edge_index = data.edge_index edge_index = data.edge_index
...@@ -212,5 +228,15 @@ if __name__ == '__main__': ...@@ -212,5 +228,15 @@ if __name__ == '__main__':
storage.compute_cache_() storage.compute_cache_()
print(time.perf_counter() - t) print(time.perf_counter() - t)
t = time.perf_counter() t = time.perf_counter()
storage.clear_cache_()
storage.compute_cache_() storage.compute_cache_()
print(time.perf_counter() - t) print(time.perf_counter() - t)
print(storage)
storage = storage.clone()
print(storage)
# storage = copy.copy(storage)
# print(storage)
# storage = copy.deepcopy(storage)
# print(storage)
storage.compute_cache_()
storage.clear_cache_()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment