Unverified Commit b57c56d9 authored by czkkkkkk, committed by GitHub

[Sparse] Support broadcasting operators (#5544)

parent 7c465d20
...@@ -148,3 +148,15 @@ Non-linear activation functions
    :toctree: ../../generated/

    softmax

Broadcast operators
```````````````````

.. autosummary::
    :toctree: ../../generated/

    sp_broadcast_v
    sp_add_v
    sp_sub_v
    sp_mul_v
    sp_div_v
\ No newline at end of file
...@@ -5,6 +5,7 @@ import sys
import torch
from .._ffi import libinfo
from .broadcast import *
from .elementwise_op import *
from .elementwise_op_sp import *
from .matmul import *
...
"""DGL broadcast operator module."""
import operator
import torch
from .sparse_matrix import SparseMatrix, val_like
def sp_broadcast_v(A: SparseMatrix, v: torch.Tensor, op: str) -> SparseMatrix:
"""Broadcast operator for sparse matrix and vector.
:attr:`v` is broadcasted to the shape of :attr:`A` and then the operator is
applied on the non-zero values of :attr:`A`.
There are two cases regarding the shape of v:
1. :attr:`v` is a vector of shape ``(1, A.shape[1])`` or ``(A.shape[1])``.
In this case, :attr:`v` is broadcasted on the row dimension of :attr:`A`.
2. :attr:`v` is a vector of shape ``(A.shape[0], 1)``. In this case,
:attr:`v` is broadcasted on the column dimension of :attr:`A`.
If ``A.val`` takes shape ``(nnz, D)``, then :attr:`v` will be broadcasted on
the ``D`` dimension.
Parameters
----------
A: SparseMatrix
Sparse matrix
v: torch.Tensor
Vector
op: str
Operator in ["add", "sub", "mul", "truediv"]
Returns
-------
SparseMatrix
Sparse matrix
Examples
--------
>>> indices = torch.tensor([[1, 0, 2], [0, 3, 2]])
>>> val = torch.tensor([10, 20, 30])
>>> A = dglsp.spmatrix(indices, val, shape=(3, 4))
>>> v = torch.tensor([1, 2, 3, 4])
>>> dglsp.sp_broadcast_v(A, v, "add")
SparseMatrix(indices=tensor([[1, 0, 2],
[0, 3, 2]]),
values=tensor([11, 24, 33]),
shape=(3, 4), nnz=3)
>>> v = torch.tensor([1, 2, 3]).view(-1, 1)
>>> dglsp.sp_broadcast_v(A, v, "add")
SparseMatrix(indices=tensor([[1, 0, 2],
[0, 3, 2]]),
values=tensor([12, 21, 33]),
shape=(3, 4), nnz=3)
>>> indices = torch.tensor([[1, 0, 2], [0, 3, 2]])
>>> val = torch.tensor([[10, 20], [30, 40], [50, 60]])
>>> A = dglsp.spmatrix(indices, val, shape=(3, 4))
>>> v = torch.tensor([1, 2, 3]).view(-1, 1)
>>> dglsp.sp_broadcast_v(A, v, "sub")
SparseMatrix(indices=tensor([[1, 0, 2],
[0, 3, 2]]),
values=tensor([[ 8, 18],
[29, 39],
[47, 57]]),
shape=(3, 4), nnz=3, val_size=(2,))
"""
    op = getattr(operator, op)
    if v.dim() == 1:
        v = v.view(1, -1)

    shape_error_message = (
        f"Dimension mismatch for broadcasting. Got A.shape = {A.shape} and "
        f"v.shape = {v.shape}."
    )
    assert v.dim() <= 2 and (1 in v.shape), shape_error_message

    # v can be broadcast to A if exactly one dimension of v is 1 and the
    # other matches the corresponding dimension of A.
    broadcast_dim = None
    for d, (dim1, dim2) in enumerate(zip(A.shape, v.shape)):
        assert dim2 in (1, dim1), shape_error_message
        if dim1 != dim2:
            assert broadcast_dim is None, shape_error_message
            broadcast_dim = d
    # A and v have the same shape of (1, *) or (*, 1).
    if broadcast_dim is None:
        broadcast_dim = 0 if A.shape[0] == 1 else 1

    # Gather the vector entry that matches each non-zero element of A.
    if broadcast_dim == 0:
        v = v.view(-1)[A.col]
    else:
        v = v.view(-1)[A.row]
    # Align v with A.val of shape (nnz, D) so it broadcasts over D.
    if A.val.dim() > 1:
        v = v.view(-1, 1)
    ret_val = op(A.val, v)
    return val_like(A, ret_val)


def sp_add_v(A: SparseMatrix, v: torch.Tensor) -> SparseMatrix:
    """Broadcast addition for a sparse matrix and a vector.

    See the definition of :func:`sp_broadcast_v` for details.
    """
    return sp_broadcast_v(A, v, "add")


def sp_sub_v(A: SparseMatrix, v: torch.Tensor) -> SparseMatrix:
    """Broadcast subtraction for a sparse matrix and a vector.

    See the definition of :func:`sp_broadcast_v` for details.
    """
    return sp_broadcast_v(A, v, "sub")


def sp_mul_v(A: SparseMatrix, v: torch.Tensor) -> SparseMatrix:
    """Broadcast multiplication for a sparse matrix and a vector.

    See the definition of :func:`sp_broadcast_v` for details.
    """
    return sp_broadcast_v(A, v, "mul")


def sp_div_v(A: SparseMatrix, v: torch.Tensor) -> SparseMatrix:
    """Broadcast division for a sparse matrix and a vector.

    See the definition of :func:`sp_broadcast_v` for details.
    """
    return sp_broadcast_v(A, v, "truediv")
import operator

import backend as F
import pytest
import torch

from dgl.sparse import sp_broadcast_v

from .utils import rand_coo


@pytest.mark.parametrize("shape", [(3, 4), (1, 5), (5, 1)])
@pytest.mark.parametrize("nnz", [1, 4])
@pytest.mark.parametrize("nz_dim", [None, 2])
@pytest.mark.parametrize("op", ["add", "sub", "mul", "truediv"])
def test_sp_broadcast_v(shape, nnz, nz_dim, op):
    dev = F.ctx()
    A = rand_coo(shape, nnz, dev, nz_dim)

    # Case 1: v of shape (A.shape[1],) broadcasts along the rows of A.
    v = torch.randn(A.shape[1], device=dev)
    res1 = sp_broadcast_v(A, v, op)
    if A.val.dim() == 1:
        rhs = v[A.col]
    else:
        rhs = v[A.col].view(-1, 1)
    res2 = getattr(operator, op)(A.val, rhs)
    assert torch.allclose(res1.val, res2)

    # Case 2: v of shape (1, A.shape[1]) also broadcasts along the rows.
    v = torch.randn(1, A.shape[1], device=dev)
    res1 = sp_broadcast_v(A, v, op)
    if A.val.dim() == 1:
        rhs = v.view(-1)[A.col]
    else:
        rhs = v.view(-1)[A.col].view(-1, 1)
    res2 = getattr(operator, op)(A.val, rhs)
    assert torch.allclose(res1.val, res2)

    # Case 3: v of shape (A.shape[0], 1) broadcasts along the columns.
    v = torch.randn(A.shape[0], 1, device=dev)
    res1 = sp_broadcast_v(A, v, op)
    if A.val.dim() == 1:
        rhs = v.view(-1)[A.row]
    else:
        rhs = v.view(-1)[A.row].view(-1, 1)
    res2 = getattr(operator, op)(A.val, rhs)
    assert torch.allclose(res1.val, res2)
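The rand_coo helper comes from the test suite's utils module and is not shown in this diff. A hypothetical stand-in, consistent with how the test calls it (shape, nnz, device, optional trailing value dimension), might look like the sketch below; the real helper may differ.

import torch
import dgl.sparse as dglsp

def rand_coo(shape, nnz, dev, nz_dim=None):
    """Hypothetical stand-in for the test utils helper (not in this diff)."""
    row = torch.randint(shape[0], (nnz,), device=dev)
    col = torch.randint(shape[1], (nnz,), device=dev)
    # A trailing non-zero dimension turns scalar values into value vectors.
    val_shape = (nnz,) if nz_dim is None else (nnz, nz_dim)
    val = torch.randn(val_shape, device=dev)
    return dglsp.spmatrix(torch.stack([row, col]), val, shape=shape)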