Commit 50ac1233 authored by rusty1s's avatar rusty1s
Browse files

typos

parent 57d37233
......@@ -4,12 +4,13 @@ from torch_scatter import segment_csr
def __reduce__(src, dim=None, reduce='add', deterministic=False):
assert reduce in ['add', 'mean', 'min', 'max']
if dim is None and src.has_value():
func = getattr(torch, 'sum' if reduce == 'add' else reduce)
return func(src.storage.value)
if dim is None and not src.has_value():
assert reduce in ['add', 'mean', 'min', 'max']
value = src.nnz() if reduce == 'add' else 1
return torch.tensor(value, device=src.device)
......@@ -26,7 +27,6 @@ def __reduce__(src, dim=None, reduce='add', deterministic=False):
return func(value, dim=(0, ) + dense_dims)
if len(sparse_dims) == 2 and not src.has_value():
assert reduce in ['add', 'mean', 'min', 'max']
value = src.nnz() if reduce == 'add' else 1
return torch.tensor(value, device=src.device)
......@@ -50,15 +50,15 @@ def __reduce__(src, dim=None, reduce='add', deterministic=False):
return out
if sparse_dims[0] == 1 and not src.has_value():
assert reduce in ['add', 'mean', 'min', 'max']
if reduce == 'add':
return src.storage.rowcount.to(torch.get_default_dtype())
elif reduce == 'min' or 'max':
# Return an additional `None` arg(min|max) tensor for consistency.
return torch.ones(src.size(0), device=src.device), None
else:
return torch.ones(src.size(0), device=src.device)
deterministic = src.storage._csr2csc is not None or deterministic
deterministic = src.storage.has_csr2csc() or deterministic
if sparse_dims[0] == 0 and deterministic and src.has_value():
csr2csc = src.storage.csr2csc
......@@ -73,10 +73,10 @@ def __reduce__(src, dim=None, reduce='add', deterministic=False):
return out
if sparse_dims[0] == 0 and not src.has_value():
assert reduce in ['add', 'mean', 'min', 'max']
if reduce == 'add':
return src.storage.colcount.to(torch.get_default_dtype())
elif reduce == 'min' or 'max':
# Return an additional `None` arg(min|max) tensor for consistency.
return torch.ones(src.size(1), device=src.device), None
else:
return torch.ones(src.size(1), device=src.device)
......
......@@ -212,7 +212,7 @@ class SparseStorage(object):
@cached_property
def colcount(self):
if self._colptr is not None:
if self.has_colptr():
colptr = self.colptr
return colptr[1:] - colptr[:-1]
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment