Commit bcc8106e authored by rusty1s's avatar rusty1s
Browse files

avg_bandwidth

parent 220672ec
......@@ -260,6 +260,10 @@ class SparseTensor(object):
row, col, _ = self.coo()
return int((row - col).abs_().max())
def avg_bandwidth(self) -> float:
row, col, _ = self.coo()
return float((row - col).abs_().to(torch.float).mean())
def bandwidth_proportion(self, bandwidth: int) -> float:
row, col, _ = self.coo()
tmp = (row - col).abs_()
......@@ -537,8 +541,8 @@ SparseTensor.__repr__ = __repr__
# Scipy Conversions ###########################################################
ScipySparseMatrix = Union[scipy.sparse.coo_matrix, scipy.sparse.
csr_matrix, scipy.sparse.csc_matrix]
ScipySparseMatrix = Union[scipy.sparse.coo_matrix, scipy.sparse.csr_matrix,
scipy.sparse.csc_matrix]
@torch.jit.ignore
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment