Commit 50ac1233 authored by rusty1s's avatar rusty1s
Browse files

typos

parent 57d37233
...@@ -4,12 +4,13 @@ from torch_scatter import segment_csr ...@@ -4,12 +4,13 @@ from torch_scatter import segment_csr
def __reduce__(src, dim=None, reduce='add', deterministic=False): def __reduce__(src, dim=None, reduce='add', deterministic=False):
assert reduce in ['add', 'mean', 'min', 'max']
if dim is None and src.has_value(): if dim is None and src.has_value():
func = getattr(torch, 'sum' if reduce == 'add' else reduce) func = getattr(torch, 'sum' if reduce == 'add' else reduce)
return func(src.storage.value) return func(src.storage.value)
if dim is None and not src.has_value(): if dim is None and not src.has_value():
assert reduce in ['add', 'mean', 'min', 'max']
value = src.nnz() if reduce == 'add' else 1 value = src.nnz() if reduce == 'add' else 1
return torch.tensor(value, device=src.device) return torch.tensor(value, device=src.device)
...@@ -26,7 +27,6 @@ def __reduce__(src, dim=None, reduce='add', deterministic=False): ...@@ -26,7 +27,6 @@ def __reduce__(src, dim=None, reduce='add', deterministic=False):
return func(value, dim=(0, ) + dense_dims) return func(value, dim=(0, ) + dense_dims)
if len(sparse_dims) == 2 and not src.has_value(): if len(sparse_dims) == 2 and not src.has_value():
assert reduce in ['add', 'mean', 'min', 'max']
value = src.nnz() if reduce == 'add' else 1 value = src.nnz() if reduce == 'add' else 1
return torch.tensor(value, device=src.device) return torch.tensor(value, device=src.device)
...@@ -50,15 +50,15 @@ def __reduce__(src, dim=None, reduce='add', deterministic=False): ...@@ -50,15 +50,15 @@ def __reduce__(src, dim=None, reduce='add', deterministic=False):
return out return out
if sparse_dims[0] == 1 and not src.has_value(): if sparse_dims[0] == 1 and not src.has_value():
assert reduce in ['add', 'mean', 'min', 'max']
if reduce == 'add': if reduce == 'add':
return src.storage.rowcount.to(torch.get_default_dtype()) return src.storage.rowcount.to(torch.get_default_dtype())
elif reduce == 'min' or 'max': elif reduce == 'min' or 'max':
# Return an additional `None` arg(min|max) tensor for consistency.
return torch.ones(src.size(0), device=src.device), None return torch.ones(src.size(0), device=src.device), None
else: else:
return torch.ones(src.size(0), device=src.device) return torch.ones(src.size(0), device=src.device)
deterministic = src.storage._csr2csc is not None or deterministic deterministic = src.storage.has_csr2csc() or deterministic
if sparse_dims[0] == 0 and deterministic and src.has_value(): if sparse_dims[0] == 0 and deterministic and src.has_value():
csr2csc = src.storage.csr2csc csr2csc = src.storage.csr2csc
...@@ -73,10 +73,10 @@ def __reduce__(src, dim=None, reduce='add', deterministic=False): ...@@ -73,10 +73,10 @@ def __reduce__(src, dim=None, reduce='add', deterministic=False):
return out return out
if sparse_dims[0] == 0 and not src.has_value(): if sparse_dims[0] == 0 and not src.has_value():
assert reduce in ['add', 'mean', 'min', 'max']
if reduce == 'add': if reduce == 'add':
return src.storage.colcount.to(torch.get_default_dtype()) return src.storage.colcount.to(torch.get_default_dtype())
elif reduce == 'min' or 'max': elif reduce == 'min' or 'max':
# Return an additional `None` arg(min|max) tensor for consistency.
return torch.ones(src.size(1), device=src.device), None return torch.ones(src.size(1), device=src.device), None
else: else:
return torch.ones(src.size(1), device=src.device) return torch.ones(src.size(1), device=src.device)
......
...@@ -212,7 +212,7 @@ class SparseStorage(object): ...@@ -212,7 +212,7 @@ class SparseStorage(object):
@cached_property @cached_property
def colcount(self): def colcount(self):
if self._colptr is not None: if self.has_colptr():
colptr = self.colptr colptr = self.colptr
return colptr[1:] - colptr[:-1] return colptr[1:] - colptr[:-1]
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment