Commit bcc8106e authored by rusty1s's avatar rusty1s
Browse files

avg_bandwidth

parent 220672ec
...@@ -260,6 +260,10 @@ class SparseTensor(object): ...@@ -260,6 +260,10 @@ class SparseTensor(object):
row, col, _ = self.coo() row, col, _ = self.coo()
return int((row - col).abs_().max()) return int((row - col).abs_().max())
def avg_bandwidth(self) -> float:
row, col, _ = self.coo()
return float((row - col).abs_().to(torch.float).mean())
def bandwidth_proportion(self, bandwidth: int) -> float: def bandwidth_proportion(self, bandwidth: int) -> float:
row, col, _ = self.coo() row, col, _ = self.coo()
tmp = (row - col).abs_() tmp = (row - col).abs_()
...@@ -537,8 +541,8 @@ SparseTensor.__repr__ = __repr__ ...@@ -537,8 +541,8 @@ SparseTensor.__repr__ = __repr__
# Scipy Conversions ########################################################### # Scipy Conversions ###########################################################
ScipySparseMatrix = Union[scipy.sparse.coo_matrix, scipy.sparse. ScipySparseMatrix = Union[scipy.sparse.coo_matrix, scipy.sparse.csr_matrix,
csr_matrix, scipy.sparse.csc_matrix] scipy.sparse.csc_matrix]
@torch.jit.ignore @torch.jit.ignore
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment