Commit 904b1d48 authored by rusty1s's avatar rusty1s
Browse files

cached keys

parent 6e9cd9d6
......@@ -18,3 +18,7 @@ def partition_kway(
partptr = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)
return out, partptr, perm
SparseTensor.partition_kway = lambda self, num_parts: partition_kway(
self, num_parts)
......@@ -379,19 +379,22 @@ class SparseStorage(object):
self._csc2csr = None
return self
def num_cached_keys(self) -> int:
count = 0
def cached_keys(self) -> List[str]:
keys: List[str] = []
if self.has_rowcount():
count += 1
keys.append('rowcount')
if self.has_colptr():
count += 1
keys.append('colptr')
if self.has_colcount():
count += 1
keys.append('colcount')
if self.has_csr2csc():
count += 1
keys.append('csr2csc')
if self.has_csc2csr():
count += 1
return count
keys.append('csc2csr')
return keys
def num_cached_keys(self) -> int:
return len(self.cached_keys())
def copy(self):
return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment