Commit 04593307 authored by rusty1s's avatar rusty1s
Browse files

diable caching

parent a3afa1dd
......@@ -3,7 +3,7 @@ from itertools import product
import pytest
import torch
from torch_sparse.storage import SparseStorage
from torch_sparse.storage import SparseStorage, no_cache
from .utils import dtypes, devices, tensor
......@@ -77,6 +77,17 @@ def test_caching(dtype, device):
assert storage._csr2csc is None
assert storage.cached_keys() == []
with no_cache():
storage.fill_cache_()
assert storage.cached_keys() == []
@no_cache()
def do_something(storage):
return storage.fill_cache_()
storage = do_something(storage)
assert storage.cached_keys() == []
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_utility(dtype, device):
......
......@@ -4,9 +4,32 @@ import torch
import torch_scatter
from torch_scatter import scatter_add, segment_add
__cache_flag__ = {'enabled': True}
def optional(func, src):
return func(src) if src is not None else src
def is_cache_enabled():
return __cache_flag__['enabled']
def set_cache_enabled(mode):
__cache_flag__['enabled'] = mode
class no_cache(object):
def __enter__(self):
self.prev = is_cache_enabled()
set_cache_enabled(False)
def __exit__(self, *args):
set_cache_enabled(self.prev)
return False
def __call__(self, func):
def decorate_no_cache(*args, **kwargs):
with self:
return func(*args, **kwargs)
return decorate_no_cache
class cached_property(object):
......@@ -17,10 +40,15 @@ class cached_property(object):
value = getattr(obj, f'_{self.func.__name__}', None)
if value is None:
value = self.func(obj)
if __cache_flag__['enabled']:
setattr(obj, f'_{self.func.__name__}', value)
return value
def optional(func, src):
return func(src) if src is not None else src
layouts = ['coo', 'csr', 'csc']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment