Unverified Commit b5c5c860 authored by czkkkkkk, committed by GitHub

[Sparse] Refactor matmul interface. (#5162)

* [Sparse] Refactor matmul interface.

* Update
parent 9334421d
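The user-visible change in this commit is that `dgl.sparse.mm` becomes `dgl.sparse.matmul`, which additionally accepts two dense tensors. A minimal before/after sketch of a call site (the `dglsp` alias, imports, and shapes here are illustrative, not part of the diff):

import torch
import dgl.sparse as dglsp  # assumes a DGL build that ships the dgl.sparse package

A = dglsp.diag(torch.randn(3))   # 3 x 3 diagonal matrix
B = torch.randn(3, 2)            # dense matrix

# Before this commit:
# Y = dglsp.mm(A, B)
# After this commit:
Y = dglsp.matmul(A, B)           # dense result of shape (3, 2)
Y = A @ B                        # __matmul__ is also wired to matmul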
@@ -171,10 +171,10 @@ Matrix Multiplication
 .. autosummary::
     :toctree: ../../generated/

+    matmul
     spmm
     bspmm
     spspmm
-    mm
     sddmm
     bsddmm
......
@@ -8,11 +8,11 @@ from .diag_matrix import diag, DiagMatrix
 from .sparse_matrix import SparseMatrix, val_like

-__all__ = ["spmm", "bspmm", "spspmm", "mm"]
+__all__ = ["spmm", "bspmm", "spspmm", "matmul"]


 def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
-    """Multiply a sparse matrix by a dense matrix.
+    """Multiply a sparse matrix by a dense matrix, equivalent to ``A @ X``.

     Parameters
     ----------
@@ -54,7 +54,8 @@ def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:

 def bspmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
-    """Multiply a sparse matrix by a dense matrix by batches.
+    """Multiply a sparse matrix by a dense matrix by batches, equivalent to
+    ``A @ X``.

     Parameters
     ----------
@@ -91,14 +92,14 @@ def bspmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
     return spmm(A, X)


-def _diag_diag_mm(A1: DiagMatrix, A2: DiagMatrix) -> DiagMatrix:
+def _diag_diag_mm(A: DiagMatrix, B: DiagMatrix) -> DiagMatrix:
     """Internal function for multiplying a diagonal matrix by a diagonal matrix

     Parameters
     ----------
-    A1 : DiagMatrix
+    A : DiagMatrix
         Matrix of shape (N, M), with values of shape (nnz1)
-    A2 : DiagMatrix
+    B : DiagMatrix
         Matrix of shape (M, P), with values of shape (nnz2)

     Returns
@@ -106,15 +107,15 @@ def _diag_diag_mm(A1: DiagMatrix, A2: DiagMatrix) -> DiagMatrix:
     DiagMatrix
         The result of multiplication.
     """
-    M, N = A1.shape
-    N, P = A2.shape
+    M, N = A.shape
+    N, P = B.shape
     common_diag_len = min(M, N, P)
     new_diag_len = min(M, P)
     diag_val = torch.zeros(new_diag_len)
     diag_val[:common_diag_len] = (
-        A1.val[:common_diag_len] * A2.val[:common_diag_len]
+        A.val[:common_diag_len] * B.val[:common_diag_len]
     )
-    return diag(diag_val.to(A1.device), (M, P))
+    return diag(diag_val.to(A.device), (M, P))


 def _sparse_diag_mm(A, D):
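For context on the shape handling in `_diag_diag_mm` above: the product of an (M, N) diagonal matrix and an (N, P) diagonal matrix is an (M, P) diagonal matrix whose first min(M, N, P) diagonal entries are elementwise products of the inputs and whose remaining entries are zero. A small dense check of that rule using plain torch only (not DGL code):

import torch

M, N, P = 3, 4, 2
a = torch.randn(min(M, N))   # diagonal values of the (M, N) matrix
b = torch.randn(min(N, P))   # diagonal values of the (N, P) matrix

# Densify both diagonals and multiply with an ordinary matmul.
A_dense = torch.zeros(M, N)
A_dense[torch.arange(len(a)), torch.arange(len(a))] = a
B_dense = torch.zeros(N, P)
B_dense[torch.arange(len(b)), torch.arange(len(b))] = b

# The same computation as _diag_diag_mm, expressed on plain tensors.
common = min(M, N, P)
expected = torch.zeros(min(M, P))
expected[:common] = a[:common] * b[:common]

assert torch.allclose(torch.diagonal(A_dense @ B_dense), expected)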
@@ -174,16 +175,17 @@ def _diag_sparse_mm(D, A):

 def spspmm(
-    A1: Union[SparseMatrix, DiagMatrix], A2: Union[SparseMatrix, DiagMatrix]
+    A: Union[SparseMatrix, DiagMatrix], B: Union[SparseMatrix, DiagMatrix]
 ) -> Union[SparseMatrix, DiagMatrix]:
-    """Multiply a sparse matrix by a sparse matrix. The non-zero values of the
-    two sparse matrices must be 1D.
+    """Multiply a sparse matrix by a sparse matrix, equivalent to ``A @ B``.
+
+    The non-zero values of the two sparse matrices must be 1D.

     Parameters
     ----------
-    A1 : SparseMatrix or DiagMatrix
+    A : SparseMatrix or DiagMatrix
         Sparse matrix of shape (N, M) with values of shape (nnz)
-    A2 : SparseMatrix or DiagMatrix
+    B : SparseMatrix or DiagMatrix
         Sparse matrix of shape (M, P) with values of shape (nnz)

     Returns
@@ -198,13 +200,13 @@ def spspmm(
     >>> row1 = torch.tensor([0, 1, 1])
     >>> col1 = torch.tensor([1, 0, 1])
     >>> val1 = torch.ones(len(row1))
-    >>> A1 = from_coo(row1, col1, val1)
+    >>> A = from_coo(row1, col1, val1)
     >>> row2 = torch.tensor([0, 1, 1])
     >>> col2 = torch.tensor([0, 2, 1])
     >>> val2 = torch.ones(len(row2))
-    >>> A2 = from_coo(row2, col2, val2)
-    >>> result = dgl.sparse.spspmm(A1, A2)
+    >>> B = from_coo(row2, col2, val2)
+    >>> result = dgl.sparse.spspmm(A, B)
     >>> print(result)
     SparseMatrix(indices=tensor([[0, 0, 1, 1, 1],
                                  [1, 2, 0, 1, 2]]),
@@ -212,73 +214,134 @@ def spspmm(
         shape=(2, 3), nnz=5)
     """
     assert isinstance(
-        A1, (SparseMatrix, DiagMatrix)
-    ), f"Expect A1 to be a SparseMatrix or DiagMatrix object, got {type(A1)}"
+        A, (SparseMatrix, DiagMatrix)
+    ), f"Expect A1 to be a SparseMatrix or DiagMatrix object, got {type(A)}"
     assert isinstance(
-        A2, (SparseMatrix, DiagMatrix)
-    ), f"Expect A2 to be a SparseMatrix or DiagMatrix object, got {type(A2)}"
-    if isinstance(A1, DiagMatrix) and isinstance(A2, DiagMatrix):
-        return _diag_diag_mm(A1, A2)
-    if isinstance(A1, DiagMatrix):
-        return _diag_sparse_mm(A1, A2)
-    if isinstance(A2, DiagMatrix):
-        return _sparse_diag_mm(A1, A2)
+        B, (SparseMatrix, DiagMatrix)
+    ), f"Expect A2 to be a SparseMatrix or DiagMatrix object, got {type(B)}"
+    if isinstance(A, DiagMatrix) and isinstance(B, DiagMatrix):
+        return _diag_diag_mm(A, B)
+    if isinstance(A, DiagMatrix):
+        return _diag_sparse_mm(A, B)
+    if isinstance(B, DiagMatrix):
+        return _sparse_diag_mm(A, B)
     return SparseMatrix(
-        torch.ops.dgl_sparse.spspmm(A1.c_sparse_matrix, A2.c_sparse_matrix)
+        torch.ops.dgl_sparse.spspmm(A.c_sparse_matrix, B.c_sparse_matrix)
     )
-def mm(
-    A1: Union[SparseMatrix, DiagMatrix],
-    A2: Union[torch.Tensor, SparseMatrix, DiagMatrix],
+def matmul(
+    A: Union[torch.Tensor, SparseMatrix, DiagMatrix],
+    B: Union[torch.Tensor, SparseMatrix, DiagMatrix],
 ) -> Union[torch.Tensor, SparseMatrix, DiagMatrix]:
-    """Multiply a sparse/diagonal matrix by a dense/sparse/diagonal matrix.
+    """Multiply two dense/sparse/diagonal matrices, equivalent to ``A @ B``.
+
+    If an input is a SparseMatrix or DiagMatrix, its non-zero values should
+    be 1-D. The supported combinations are shown as follows.
+
+    +--------------+--------+------------+--------------+
+    | A \\ B       | Tensor | DiagMatrix | SparseMatrix |
+    +--------------+--------+------------+--------------+
+    | Tensor       | ✅     | 🚫         | 🚫           |
+    +--------------+--------+------------+--------------+
+    | SparseMatrix | ✅     | ✅         | ✅           |
+    +--------------+--------+------------+--------------+
+    | DiagMatrix   | ✅     | ✅         | ✅           |
+    +--------------+--------+------------+--------------+
+
+    * If both matrices are torch.Tensor, it calls \
+      :func:`torch.matmul()`. The result is a dense matrix.
+    * If both matrices are sparse or diagonal, it calls \
+      :func:`dgl.sparse.spspmm`. The result is a sparse matrix.
+    * If :attr:`A` is sparse or diagonal while :attr:`B` is dense, it \
+      calls :func:`dgl.sparse.spmm`. The result is a dense matrix.
+    * The operator supports batched sparse-dense matrix multiplication. In \
+      this case, the sparse or diagonal matrix :attr:`A` should have shape \
+      :math:`(L, M)`, where the non-zero values have a batch dimension \
+      :math:`K`. The dense matrix :attr:`B` should have shape \
+      :math:`(M, N, K)`. The output is a dense matrix of shape \
+      :math:`(L, N, K)`.
+    * Sparse-sparse matrix multiplication does not support batched computation.

     Parameters
     ----------
-    A1 : SparseMatrix or DiagMatrix
-        Matrix of shape (N, M), with values of shape (nnz1)
-    A2 : torch.Tensor, SparseMatrix, or DiagMatrix
-        Matrix of shape (M, P). If it is a SparseMatrix or DiagMatrix,
-        it should have values of shape (nnz2).
+    A : torch.Tensor, SparseMatrix or DiagMatrix
+        The first matrix.
+    B : torch.Tensor, SparseMatrix, or DiagMatrix
+        The second matrix.

     Returns
     -------
-    torch.Tensor or DiagMatrix or SparseMatrix
-        The result of multiplication of shape (N, P)
-        * It is a dense torch tensor if :attr:`A2` is so.
-        * It is a DiagMatrix object if both :attr:`A1` and :attr:`A2` are so.
-        * It is a SparseMatrix object otherwise.
+    torch.Tensor, SparseMatrix or DiagMatrix
+        The result matrix

     Examples
     --------
+    Multiply a diagonal matrix with a dense matrix.
+
     >>> val = torch.randn(3)
-    >>> A1 = diag(val)
-    >>> A2 = torch.randn(3, 2)
-    >>> result = dgl.sparse.mm(A1, A2)
+    >>> A = diag(val)
+    >>> B = torch.randn(3, 2)
+    >>> result = dgl.sparse.matmul(A, B)
     >>> print(type(result))
     <class 'torch.Tensor'>
     >>> print(result.shape)
     torch.Size([3, 2])
+
+    Multiply a sparse matrix with a dense matrix.
+
+    >>> row = torch.tensor([0, 1, 1])
+    >>> col = torch.tensor([1, 0, 1])
+    >>> val = torch.randn(len(row))
+    >>> A = from_coo(row, col, val)
+    >>> X = torch.randn(2, 3)
+    >>> result = dgl.sparse.matmul(A, X)
+    >>> print(type(result))
+    <class 'torch.Tensor'>
+    >>> print(result.shape)
+    torch.Size([2, 3])
+
+    Multiply a sparse matrix with a sparse matrix.
+
+    >>> row1 = torch.tensor([0, 1, 1])
+    >>> col1 = torch.tensor([1, 0, 1])
+    >>> val1 = torch.ones(len(row1))
+    >>> A = from_coo(row1, col1, val1)
+    >>> row2 = torch.tensor([0, 1, 1])
+    >>> col2 = torch.tensor([0, 2, 1])
+    >>> val2 = torch.ones(len(row2))
+    >>> B = from_coo(row2, col2, val2)
+    >>> result = dgl.sparse.matmul(A, B)
+    >>> print(type(result))
+    <class 'dgl.sparse.sparse_matrix.SparseMatrix'>
+    >>> print(result.shape)
+    (2, 3)
     """
-    assert isinstance(
-        A1, (SparseMatrix, DiagMatrix)
-    ), f"Expect arg1 to be a SparseMatrix, or DiagMatrix object, got {type(A1)}."
-    assert isinstance(A2, (torch.Tensor, SparseMatrix, DiagMatrix)), (
+    assert isinstance(A, (torch.Tensor, SparseMatrix, DiagMatrix)), (
+        f"Expect arg1 to be a torch.Tensor, SparseMatrix, or DiagMatrix object,"
+        f"got {type(A)}."
+    )
+    assert isinstance(B, (torch.Tensor, SparseMatrix, DiagMatrix)), (
         f"Expect arg2 to be a torch Tensor, SparseMatrix, or DiagMatrix"
-        f"object, got {type(A2)}."
+        f"object, got {type(B)}."
     )
-    if isinstance(A2, torch.Tensor):
-        return spmm(A1, A2)
-    if isinstance(A1, DiagMatrix) and isinstance(A2, DiagMatrix):
-        return _diag_diag_mm(A1, A2)
-    return spspmm(A1, A2)
+    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
+        return torch.matmul(A, B)
+    assert not isinstance(A, torch.Tensor), (
+        f"Expect arg2 to be a torch Tensor if arg 1 is torch Tensor, "
+        f"got {type(B)}."
+    )
+    if isinstance(B, torch.Tensor):
+        return spmm(A, B)
+    if isinstance(A, DiagMatrix) and isinstance(B, DiagMatrix):
+        return _diag_diag_mm(A, B)
+    return spspmm(A, B)


-SparseMatrix.__matmul__ = mm
-DiagMatrix.__matmul__ = mm
+SparseMatrix.__matmul__ = matmul
+DiagMatrix.__matmul__ = matmul
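To summarize the dispatch that the new `matmul` performs (per the table and the branches above), a short usage sketch; the `dglsp` alias and the shapes are illustrative, and it assumes a DGL build that ships this `dgl.sparse` package:

import torch
import dgl.sparse as dglsp

row = torch.tensor([0, 1, 1])
col = torch.tensor([1, 0, 1])
A = dglsp.from_coo(row, col, torch.ones(3))   # 2 x 2 SparseMatrix
D = dglsp.diag(torch.randn(2))                # 2 x 2 DiagMatrix
X = torch.randn(2, 4)                         # dense tensor

Y1 = dglsp.matmul(X, X.T)   # dense  @ dense  -> torch.Tensor, via torch.matmul
Y2 = dglsp.matmul(A, X)     # sparse @ dense  -> torch.Tensor, via spmm
Y3 = dglsp.matmul(A, D)     # sparse @ diag   -> SparseMatrix, via spspmm
Y4 = dglsp.matmul(D, D)     # diag   @ diag   -> DiagMatrix, via _diag_diag_mm
Y5 = A @ X                  # the @ operator routes through matmul as well
# dglsp.matmul(X, A)        # dense @ sparse is rejected (🚫 in the table)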
@@ -4,7 +4,8 @@ import backend as F
 import pytest
 import torch

-from dgl.sparse import bspmm, diag, from_coo, mm, val_like
+from dgl.sparse import bspmm, diag, from_coo, val_like
+from dgl.sparse.matmul import matmul

 from .utils import (
     clone_detach_and_grad,
@@ -33,7 +34,7 @@ def test_spmm(create_func, shape, nnz, out_dim):
     else:
         X = torch.randn(shape[1], requires_grad=True, device=dev)

-    sparse_result = A @ X
+    sparse_result = matmul(A, X)
     grad = torch.randn_like(sparse_result)
     sparse_result.backward(grad)
@@ -60,7 +61,7 @@ def test_bspmm(create_func, shape, nnz):
     A = create_func(shape, nnz, dev, 2)
     X = torch.randn(shape[1], 10, 2, requires_grad=True, device=dev)

-    sparse_result = bspmm(A, X)
+    sparse_result = matmul(A, X)
     grad = torch.randn_like(sparse_result)
     sparse_result.backward(grad)
@@ -92,7 +93,7 @@ def test_spspmm(create_func1, create_func2, shape_n_m, shape_k, nnz1, nnz2):
     shape2 = (shape_n_m[1], shape_k)
     A1 = create_func1(shape1, nnz1, dev)
     A2 = create_func2(shape2, nnz2, dev)
-    A3 = A1 @ A2
+    A3 = matmul(A1, A2)
     grad = torch.randn_like(A3.val)
     A3.val.backward(grad)
@@ -132,14 +133,14 @@ def test_spspmm_duplicate():
     A2 = from_coo(row, col, val, shape)

     try:
-        A1 @ A2
+        matmul(A1, A2)
     except:
         pass
     else:
         assert False, "Should raise error."

     try:
-        A2 @ A1
+        matmul(A2, A1)
     except:
         pass
     else:
@@ -155,8 +156,7 @@ def test_sparse_diag_mm(create_func, sparse_shape, nnz):
     A = create_func(sparse_shape, nnz, dev)
     diag_val = torch.randn(sparse_shape[1], device=dev, requires_grad=True)
     D = diag(diag_val, diag_shape)
-    # (TODO) Need to use dgl.sparse.matmul after rename mm to matmul
-    B = mm(A, D)
+    B = matmul(A, D)
     grad = torch.randn_like(B.val)
     B.val.backward(grad)
@@ -189,8 +189,7 @@ def test_diag_sparse_mm(create_func, sparse_shape, nnz):
     A = create_func(sparse_shape, nnz, dev)
     diag_val = torch.randn(sparse_shape[0], device=dev, requires_grad=True)
     D = diag(diag_val, diag_shape)
-    # (TODO) Need to use dgl.sparse.matmul after rename mm to matmul
-    B = mm(D, A)
+    B = matmul(D, A)
     grad = torch.randn_like(B.val)
     B.val.backward(grad)
......
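The test_bspmm change above now exercises the batched case through `matmul` as well. A sketch of the batched sparse-dense layout described in the new docstring (an (L, M) sparse matrix whose values carry a batch dimension K, multiplied by an (M, N, K) dense tensor); the names and shapes here are illustrative:

import torch
import dgl.sparse as dglsp

L, M, N, K = 4, 3, 10, 2
row = torch.tensor([0, 2, 3])
col = torch.tensor([1, 0, 2])
val = torch.randn(len(row), K)            # one value per non-zero per batch
A = dglsp.from_coo(row, col, val, (L, M)) # (L, M) sparse matrix, batched values
X = torch.randn(M, N, K)                  # batched dense matrix

Y = dglsp.matmul(A, X)                    # same path as bspmm(A, X)
assert Y.shape == (L, N, K)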