Commit b50a4861 authored by rusty1s's avatar rusty1s
Browse files

typos

parent fcf88282
......@@ -4,13 +4,13 @@ from torch_scatter import segment_csr
def reduction(src, dim=None, reduce='sum', deterministic=False):
assert reduce in ['sum', 'mean', 'min', 'max']
assert reduce in ['sum', 'add', 'mean', 'min', 'max']
if dim is None and src.has_value():
return getattr(torch, reduce)(src.storage.value)
if dim is None and not src.has_value():
value = src.nnz() if reduce == 'sum' else 1
value = src.nnz() if reduce in ['sum', 'add'] else 1
return torch.tensor(value, device=src.device)
dims = [dim] if isinstance(dim, int) else dim
......@@ -26,7 +26,7 @@ def reduction(src, dim=None, reduce='sum', deterministic=False):
return getattr(torch, reduce)(value, dim=(0, ) + dense_dims)
if len(sparse_dims) == 2 and not src.has_value():
value = src.nnz() if reduce == 'sum' else 1
value = src.nnz() if reduce in ['sum', 'add'] else 1
return torch.tensor(value, device=src.device)
if len(dense_dims) > 0 and len(sparse_dims) == 0: # src.has_value()
......@@ -47,7 +47,7 @@ def reduction(src, dim=None, reduce='sum', deterministic=False):
return out
if sparse_dims[0] == 1 and not src.has_value():
if reduce == 'sum':
if reduce in ['sum', 'add']:
return src.storage.rowcount.to(torch.get_default_dtype())
elif reduce == 'min' or 'max':
# Return an additional `None` arg(min|max) tensor for consistency.
......@@ -71,7 +71,7 @@ def reduction(src, dim=None, reduce='sum', deterministic=False):
return out
if sparse_dims[0] == 0 and not src.has_value():
if reduce == 'sum':
if reduce in ['sum', 'add']:
return src.storage.colcount.to(torch.get_default_dtype())
elif reduce == 'min' or 'max':
# Return an additional `None` arg(min|max) tensor for consistency.
......
......@@ -433,8 +433,17 @@ class SparseTensor(object):
def __iadd__(self, other):
return self.add_(other)
def __matmul__(a, b):
return matmul(a, b, reduce='sum')
def __mul__(self, other):
return self.mul(other)
def __rmul__(self, other):
return self.mul(other)
def __imul__(self, other):
return self.mul_(other)
def __matmul__(self, other):
return matmul(self, other, reduce='sum')
# String Reputation #######################################################
......@@ -479,27 +488,3 @@ SparseTensor.add = add
SparseTensor.add_ = add_
SparseTensor.add_nnz = add_nnz
SparseTensor.add_nnz_ = add_nnz_
# def __add__(self, other):
# return self.add(other)
# def __radd__(self, other):
# return self.add(other)
# def sub(self, layout=None):
# raise NotImplementedError
# def sub_(self, layout=None):
# raise NotImplementedError
# def mul(self, layout=None):
# raise NotImplementedError
# def mul_(self, layout=None):
# raise NotImplementedError
# def div(self, layout=None):
# raise NotImplementedError
# def div_(self, layout=None):
# raise NotImplementedError
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment