Commit b50a4861 authored by rusty1s

typos

parent fcf88282
```diff
@@ -4,13 +4,13 @@ from torch_scatter import segment_csr
 def reduction(src, dim=None, reduce='sum', deterministic=False):
-    assert reduce in ['sum', 'mean', 'min', 'max']
+    assert reduce in ['sum', 'add', 'mean', 'min', 'max']

     if dim is None and src.has_value():
         return getattr(torch, reduce)(src.storage.value)

     if dim is None and not src.has_value():
-        value = src.nnz() if reduce == 'sum' else 1
+        value = src.nnz() if reduce in ['sum', 'add'] else 1
         return torch.tensor(value, device=src.device)

     dims = [dim] if isinstance(dim, int) else dim
@@ -26,7 +26,7 @@ def reduction(src, dim=None, reduce='sum', deterministic=False):
         return getattr(torch, reduce)(value, dim=(0, ) + dense_dims)

     if len(sparse_dims) == 2 and not src.has_value():
-        value = src.nnz() if reduce == 'sum' else 1
+        value = src.nnz() if reduce in ['sum', 'add'] else 1
         return torch.tensor(value, device=src.device)

     if len(dense_dims) > 0 and len(sparse_dims) == 0:  # src.has_value()
@@ -47,7 +47,7 @@ def reduction(src, dim=None, reduce='sum', deterministic=False):
         return out

     if sparse_dims[0] == 1 and not src.has_value():
-        if reduce == 'sum':
+        if reduce in ['sum', 'add']:
             return src.storage.rowcount.to(torch.get_default_dtype())
         elif reduce == 'min' or 'max':
             # Return an additional `None` arg(min|max) tensor for consistency.
@@ -71,7 +71,7 @@ def reduction(src, dim=None, reduce='sum', deterministic=False):
         return out

     if sparse_dims[0] == 0 and not src.has_value():
-        if reduce == 'sum':
+        if reduce in ['sum', 'add']:
             return src.storage.colcount.to(torch.get_default_dtype())
         elif reduce == 'min' or 'max':
             # Return an additional `None` arg(min|max) tensor for consistency.
```
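For orientation, a hedged usage sketch of the relaxed assert: after this commit, `reduce='add'` is accepted as an alias of `'sum'` in the branches that explicitly test `reduce in ['sum', 'add']`. Note that the alias is only honored there; the `getattr(torch, reduce)` fallback would resolve `'add'` to `torch.add`, which needs two operands. Also note the unchanged context line `elif reduce == 'min' or 'max':` is always truthy (`'max'` is a non-empty string), so it acts as a catch-all for the remaining reductions. The import path for `reduction` below is an assumption; it is the helper patched in the hunks above.

```python
import torch
from torch_sparse import SparseTensor
from torch_sparse.reduce import reduction  # module path assumed, not shown in the diff

# A 3 x 3 sparsity pattern without values, so has_value() is False.
row = torch.tensor([0, 0, 1, 2])
col = torch.tensor([0, 1, 1, 2])
src = SparseTensor(row=row, col=col)  # `value` defaults to None

# Both spellings now hit the `reduce in ['sum', 'add']` branch and
# return the number of non-zeros as a tensor:
print(reduction(src, reduce='sum'))  # tensor(4)
print(reduction(src, reduce='add'))  # tensor(4) -- accepted after this commit
```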
```diff
@@ -433,8 +433,17 @@ class SparseTensor(object):
     def __iadd__(self, other):
         return self.add_(other)

-    def __matmul__(a, b):
-        return matmul(a, b, reduce='sum')
+    def __mul__(self, other):
+        return self.mul(other)
+
+    def __rmul__(self, other):
+        return self.mul(other)
+
+    def __imul__(self, other):
+        return self.mul_(other)
+
+    def __matmul__(self, other):
+        return matmul(self, other, reduce='sum')

     # String Reputation #######################################################
@@ -479,27 +488,3 @@ SparseTensor.add = add
 SparseTensor.add_ = add_
 SparseTensor.add_nnz = add_nnz
 SparseTensor.add_nnz_ = add_nnz_
-
-# def __add__(self, other):
-#     return self.add(other)
-
-# def __radd__(self, other):
-#     return self.add(other)
-
-# def sub(self, layout=None):
-#     raise NotImplementedError
-
-# def sub_(self, layout=None):
-#     raise NotImplementedError
-
-# def mul(self, layout=None):
-#     raise NotImplementedError
-
-# def mul_(self, layout=None):
-#     raise NotImplementedError
-
-# def div(self, layout=None):
-#     raise NotImplementedError
-
-# def div_(self, layout=None):
-#     raise NotImplementedError
```