Commit 87e30c53 authored by rusty1s's avatar rusty1s
Browse files

reset

parent 8741b4f7
...@@ -65,7 +65,7 @@ def masked_select(src: SparseTensor, dim: int, ...@@ -65,7 +65,7 @@ def masked_select(src: SparseTensor, dim: int,
else: else:
value = src.storage.value() value = src.storage.value()
if value is not None: if value is not None:
idx = mask.nonzero(as_tuple=False).flatten() idx = mask.nonzero().flatten()
return src.set_value(value.index_select(dim - 1, idx), return src.set_value(value.index_select(dim - 1, idx),
layout='coo') layout='coo')
else: else:
......
...@@ -379,7 +379,7 @@ class SparseStorage(object): ...@@ -379,7 +379,7 @@ class SparseStorage(object):
value = self._value value = self._value
if value is not None: if value is not None:
ptr = mask.nonzero(as_tuple=False).flatten() ptr = mask.nonzero().flatten()
ptr = torch.cat([ptr, ptr.new_full((1, ), value.size(0))]) ptr = torch.cat([ptr, ptr.new_full((1, ), value.size(0))])
value = segment_csr(value, ptr, reduce=reduce) value = segment_csr(value, ptr, reduce=reduce)
value = value[0] if isinstance(value, tuple) else value value = value[0] if isinstance(value, tuple) else value
......
...@@ -41,10 +41,9 @@ class SparseTensor(object): ...@@ -41,10 +41,9 @@ class SparseTensor(object):
@classmethod @classmethod
def from_dense(self, mat: torch.Tensor, has_value: bool = True): def from_dense(self, mat: torch.Tensor, has_value: bool = True):
if mat.dim() > 2: if mat.dim() > 2:
index = mat.abs().sum([i for i in range(2, mat.dim()) index = mat.abs().sum([i for i in range(2, mat.dim())]).nonzero()
]).nonzero(as_tuple=False)
else: else:
index = mat.nonzero(as_tuple=False) index = mat.nonzero()
index = index.t() index = index.t()
row = index[0] row = index[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment