Unverified Commit 7c15c3cd authored by Feng Shi's avatar Feng Shi Committed by GitHub
Browse files

Add function for the addition of two matrices (#177)



* Create spadd.py

Hi,
Maybe it's trivial to have this function, but I still think it'll be helpful and it looks neat when applying matrix addition, i.e., C = A + B.
Thanks

* update

* update

* fix jit
Co-authored-by: default avatarrusty1s <matthias.fey@tu-dortmund.de>
parent 28f12953
from itertools import product
import pytest
import torch
from torch_sparse import SparseTensor, add
from .utils import dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_add(dtype, device):
rowA = torch.tensor([0, 0, 1, 2, 2], device=device)
colA = torch.tensor([0, 2, 1, 0, 1], device=device)
valueA = tensor([1, 2, 4, 1, 3], dtype, device)
A = SparseTensor(row=rowA, col=colA, value=valueA)
rowB = torch.tensor([0, 0, 1, 2, 2], device=device)
colB = torch.tensor([1, 2, 2, 1, 2], device=device)
valueB = tensor([2, 3, 1, 2, 4], dtype, device)
B = SparseTensor(row=rowB, col=colB, value=valueB)
C = A + B
rowC, colC, valueC = C.coo()
assert rowC.tolist() == [0, 0, 0, 1, 1, 2, 2, 2]
assert colC.tolist() == [0, 1, 2, 1, 2, 0, 1, 2]
assert valueC.tolist() == [1, 2, 5, 4, 1, 1, 5, 4]
@torch.jit.script
def jit_add(A: SparseTensor, B: SparseTensor) -> SparseTensor:
return add(A, B)
jit_add(A, B)
......@@ -65,6 +65,7 @@ from .transpose import transpose # noqa
from .eye import eye # noqa
from .spmm import spmm # noqa
from .spspmm import spspmm # noqa
from .spadd import spadd # noqa
__all__ = [
'SparseStorage',
......@@ -111,5 +112,6 @@ __all__ = [
'eye',
'spmm',
'spspmm',
'spadd',
'__version__',
]
from typing import Optional
import torch
from torch import Tensor
from torch_scatter import gather_csr
from torch_sparse.tensor import SparseTensor
def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
@torch.jit._overload # noqa: F811
def add(src, other): # noqa: F811
# type: (SparseTensor, Tensor) -> SparseTensor
pass
@torch.jit._overload # noqa: F811
def add(src, other): # noqa: F811
# type: (SparseTensor, SparseTensor) -> SparseTensor
pass
def add(src, other): # noqa: F811
if isinstance(other, Tensor):
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise.
other = gather_csr(other.squeeze(1), rowptr)
pass
elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise.
other = other.squeeze(0)[col]
else:
raise ValueError(
......@@ -22,13 +35,35 @@ def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
value = other.add_(1)
return src.set_value(value, layout='coo')
elif isinstance(other, SparseTensor):
rowA, colA, valueA = src.coo()
rowB, colB, valueB = other.coo()
row = torch.cat([rowA, rowB], dim=0)
col = torch.cat([colA, colB], dim=0)
value: Optional[Tensor] = None
if valueA is not None and valueB is not None:
value = torch.cat([valueA, valueB], dim=0)
M = max(src.size(0), other.size(0))
N = max(src.size(1), other.size(1))
sparse_sizes = (M, N)
out = SparseTensor(row=row, col=col, value=value,
sparse_sizes=sparse_sizes)
out = out.coalesce(reduce='sum')
return out
else:
raise NotImplementedError
def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise.
other = gather_csr(other.squeeze(1), rowptr)
pass
elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise.
other = other.squeeze(0)[col]
else:
raise ValueError(
......
import torch
from torch_sparse import coalesce
def spadd(indexA, valueA, indexB, valueB, m, n):
"""Matrix addition of two sparse matrices.
Args:
indexA (:class:`LongTensor`): The index tensor of first sparse matrix.
valueA (:class:`Tensor`): The value tensor of first sparse matrix.
indexB (:class:`LongTensor`): The index tensor of second sparse matrix.
valueB (:class:`Tensor`): The value tensor of second sparse matrix.
m (int): The first dimension of the sparse matrices.
n (int): The second dimension of the sparse matrices.
"""
index = torch.cat([indexA, indexB], dim=-1)
value = torch.cat([valueA, valueB], dim=0)
return coalesce(index=index, value=value, m=m, n=n, op='add')
......@@ -292,7 +292,7 @@ class SparseStorage(object):
idx = self.sparse_size(1) * self.row() + self.col()
row = idx // num_cols
row = torch.div(idx, num_cols, rounding_mode='floor')
col = idx % num_cols
assert row.dtype == torch.long and col.dtype == torch.long
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment