"vscode:/vscode.git/clone" did not exist on "018df054f25349e68c14654ea4b01a2db042df8c"
Commit 24c599ea authored by rusty1s's avatar rusty1s
Browse files

has_value

parent 9d4fae0b
...@@ -48,11 +48,17 @@ class SparseTensor(object): ...@@ -48,11 +48,17 @@ class SparseTensor(object):
sparse_sizes=mat.size()[:2], is_sorted=True) sparse_sizes=mat.size()[:2], is_sorted=True)
@classmethod @classmethod
def from_torch_sparse_coo_tensor(self, mat: torch.Tensor): def from_torch_sparse_coo_tensor(self, mat: torch.Tensor,
has_value: bool = True):
mat = mat.coalesce() mat = mat.coalesce()
index = mat._indices() index = mat._indices()
row, col = index[0], index[1] row, col = index[0], index[1]
return SparseTensor(row=row, rowptr=None, col=col, value=mat._values(),
value: Optional[torch.Tensor] = None
if has_value:
value = mat._values()
return SparseTensor(row=row, rowptr=None, col=col, value=value,
sparse_sizes=mat.size()[:2], is_sorted=True) sparse_sizes=mat.size()[:2], is_sorted=True)
@classmethod @classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment