Commit 8741b4f7 authored by rusty1s's avatar rusty1s
Browse files

nonzero warnings

parent a1021ebd
......@@ -65,7 +65,7 @@ def masked_select(src: SparseTensor, dim: int,
else:
value = src.storage.value()
if value is not None:
idx = mask.nonzero().flatten()
idx = mask.nonzero(as_tuple=False).flatten()
return src.set_value(value.index_select(dim - 1, idx),
layout='coo')
else:
......
......@@ -379,7 +379,7 @@ class SparseStorage(object):
value = self._value
if value is not None:
ptr = mask.nonzero().flatten()
ptr = mask.nonzero(as_tuple=False).flatten()
ptr = torch.cat([ptr, ptr.new_full((1, ), value.size(0))])
value = segment_csr(value, ptr, reduce=reduce)
value = value[0] if isinstance(value, tuple) else value
......
......@@ -41,9 +41,10 @@ class SparseTensor(object):
@classmethod
def from_dense(self, mat: torch.Tensor, has_value: bool = True):
if mat.dim() > 2:
index = mat.abs().sum([i for i in range(2, mat.dim())]).nonzero()
index = mat.abs().sum([i for i in range(2, mat.dim())
]).nonzero(as_tuple=False)
else:
index = mat.nonzero()
index = mat.nonzero(as_tuple=False)
index = index.t()
row = index[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment