Commit 258dbf0f authored by rusty1s's avatar rusty1s
Browse files

from_dense with optional no values

parent 6b013fbe
...@@ -30,15 +30,21 @@ class SparseTensor(object): ...@@ -30,15 +30,21 @@ class SparseTensor(object):
return self return self
@classmethod @classmethod
def from_dense(self, mat: torch.Tensor): def from_dense(self, mat: torch.Tensor, has_value: bool = True):
if mat.dim() > 2: if mat.dim() > 2:
index = mat.abs().sum([i for i in range(2, mat.dim())]).nonzero() index = mat.abs().sum([i for i in range(2, mat.dim())]).nonzero()
else: else:
index = mat.nonzero() index = mat.nonzero()
index = index.t() index = index.t()
row, col = index[0], index[1] row = index[0]
return SparseTensor(row=row, rowptr=None, col=col, value=mat[row, col], col = index[1]
value: Optional[torch.Tensor] = None
if has_value:
value = mat[row, col]
return SparseTensor(row=row, rowptr=None, col=col, value=value,
sparse_sizes=mat.size()[:2], is_sorted=True) sparse_sizes=mat.size()[:2], is_sorted=True)
@classmethod @classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment