Commit 8819288a authored by rusty1s's avatar rusty1s
Browse files

overload fix

parent 73146b9b
import torch
from torch_sparse.tensor import SparseTensor
def test_overload():
row = torch.tensor([0, 1, 1, 2, 2])
col = torch.tensor([1, 0, 2, 1, 2])
mat = SparseTensor(row=row, col=col)
other = torch.tensor([1, 2, 3]).view(3, 1)
other + mat
mat + other
other * mat
mat * other
other = torch.tensor([1, 2, 3]).view(1, 3)
other + mat
mat + other
other * mat
mat * other
......@@ -2,6 +2,10 @@ import torch
from torch_scatter import gather_csr
def is_scalar(other):
return isinstance(other, int) or isinstance(other, float)
def sparse_add(matA, matB):
nnzA, nnzB = matA.nnz(), matB.nnz()
valA = torch.full((nnzA, ), 1, dtype=torch.uint8, device=matA.device)
......@@ -38,7 +42,7 @@ def sparse_add(matA, matB):
def add(src, other):
if isinstance(other, int) or isinstance(other, float):
if is_scalar(other):
return add_nnz(src, other)
elif torch.is_tensor(other):
......@@ -65,7 +69,7 @@ def add(src, other):
def add_(src, other):
if isinstance(other, int) or isinstance(other, float):
if is_scalar(other):
return add_nnz_(src, other)
elif torch.is_tensor(other):
......@@ -98,7 +102,7 @@ def add_(src, other):
def add_nnz(src, other, layout=None):
if isinstance(other, int) or isinstance(other, float):
if is_scalar(other):
if src.has_value():
value = src.storage.value + other
else:
......@@ -117,7 +121,7 @@ def add_nnz(src, other, layout=None):
def add_nnz_(src, other, layout=None):
if isinstance(other, int) or isinstance(other, float):
if is_scalar(other):
if src.has_value():
value = src.storage.value.add_(other)
else:
......
import torch
from torch_scatter import gather_csr
def is_scalar(other):
return isinstance(other, int) or isinstance(other, float)
def mul(src, other):
if is_scalar(other):
return mul_nnz(src, other)
elif torch.is_tensor(other):
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
other = gather_csr(other.squeeze(1), rowptr)
if src.has_value():
value = other.mul_(src.storage.value)
else:
value = other
return src.set_value(value, layout='csr')
if other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
other = other.squeeze(0)[col]
if src.has_value():
value = other.mul_(src.storage.value)
else:
value = other
return src.set_value(value, layout='coo')
raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
f' ...) or (1, {src.size(1)}, ...), but got size '
f'{other.size()}.')
elif isinstance(other, src.__class__):
raise NotImplementedError
raise ValueError('Argument `other` needs to be of type `int`, `float`, '
'`torch.tensor` or `torch_sparse.SparseTensor`.')
def mul_(src, other):
if is_scalar(other):
return mul_nnz_(src, other)
elif torch.is_tensor(other):
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
other = gather_csr(other.squeeze(1), rowptr)
if src.has_value():
value = src.storage.value.mul_(other)
else:
value = other
return src.set_value_(value, layout='csr')
if other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
other = other.squeeze(0)[col]
if src.has_value():
value = src.storage.value.mul_(other)
else:
value = other
return src.set_value_(value, layout='coo')
raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
f' ...) or (1, {src.size(1)}, ...), but got size '
f'{other.size()}.')
elif isinstance(other, src.__class__):
raise NotImplementedError
raise ValueError('Argument `other` needs to be of type `int`, `float`, '
'`torch.tensor` or `torch_sparse.SparseTensor`.')
def mul_nnz(src, other, layout=None):
if torch.is_tensor(other) or is_scalar(other):
if src.has_value():
value = src.storage.value * other
else:
value = other
return src.set_value(value, layout='coo')
raise ValueError('Argument `other` needs to be of type `int`, `float` or '
'`torch.tensor`.')
def mul_nnz_(src, other, layout=None):
if torch.is_tensor(other) or is_scalar(other):
if src.has_value():
value = src.storage.value.mul_(other)
else:
value = other
return src.set_value_(value, layout='coo')
raise ValueError('Argument `other` needs to be of type `int`, `float` or '
'`torch.tensor`.')
......@@ -14,6 +14,7 @@ import torch_sparse.reduce
from torch_sparse.diag import remove_diag, set_diag
from torch_sparse.matmul import matmul
from torch_sparse.add import add, add_, add_nnz, add_nnz_
from torch_sparse.mul import mul, mul_, mul_nnz, mul_nnz_
class SparseTensor(object):
......@@ -455,7 +456,7 @@ class SparseTensor(object):
infos += [f'col={indent(col.__repr__(), i)[len(i):]}']
if self.has_value():
infos += [f'value={indent(value.__repr__(), i)[len(i):]}']
infos += [f'val={indent(value.__repr__(), i)[len(i):]}']
infos += [
f'size={tuple(self.size())}, '
......@@ -482,10 +483,28 @@ SparseTensor.sum = torch_sparse.reduce.sum
SparseTensor.mean = torch_sparse.reduce.mean
SparseTensor.min = torch_sparse.reduce.min
SparseTensor.max = torch_sparse.reduce.max
SparseTensor.remove_diag = remove_diag #TODO
SparseTensor.set_diag = set_diag #TODO
SparseTensor.matmul = matmul # TODO
SparseTensor.remove_diag = remove_diag
SparseTensor.set_diag = set_diag
SparseTensor.matmul = matmul
SparseTensor.add = add
SparseTensor.add_ = add_
SparseTensor.add_nnz = add_nnz
SparseTensor.add_nnz_ = add_nnz_
SparseTensor.mul = mul
SparseTensor.mul_ = mul_
SparseTensor.mul_nnz = mul_nnz
SparseTensor.mul_nnz_ = mul_nnz_
# Fix for PyTorch<=1.3 (https://github.com/pytorch/pytorch/pull/31769):
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if (TORCH_MAJOR <= 1) or (TORCH_MAJOR == 1 and TORCH_MINOR < 4):
def add(self, other):
return self.add(other) if torch.is_tensor(other) else NotImplemented
def mul(self, other):
return self.mul(other) if torch.is_tensor(other) else NotImplemented
torch.Tensor.__add__ = add
torch.Tensor.__mul__ = add
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment