"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3cb7b8628cbade13fe0c76aa9ff203d0844da454"
Unverified Commit 419fb815 authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Sparse] Use X to represent dense tensors in sddmm (#5174)



* [Sparse] Use X to represent dense tensors in sddmm

* indent

* pylint: disable=invalid-name
Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent acc567aa
...@@ -6,31 +6,29 @@ from .sparse_matrix import SparseMatrix ...@@ -6,31 +6,29 @@ from .sparse_matrix import SparseMatrix
__all__ = ["sddmm", "bsddmm"] __all__ = ["sddmm", "bsddmm"]
def sddmm( # pylint: disable=invalid-name
A: SparseMatrix, mat1: torch.Tensor, mat2: torch.Tensor def sddmm(A: SparseMatrix, X1: torch.Tensor, X2: torch.Tensor) -> SparseMatrix:
) -> SparseMatrix:
r"""Sampled-Dense-Dense Matrix Multiplication (SDDMM). r"""Sampled-Dense-Dense Matrix Multiplication (SDDMM).
``sddmm`` matrix-multiplies two dense matrices :attr:`mat1` and :attr:`mat2` ``sddmm`` matrix-multiplies two dense matrices :attr:`X1` and :attr:`X2`,
, then elementwise-multiplies the result with sparse matrix :attr:`A` at the then elementwise-multiplies the result with sparse matrix :attr:`A` at the
nonzero locations. nonzero locations.
Mathematically ``sddmm`` is formulated as: Mathematically ``sddmm`` is formulated as:
.. math:: .. math::
out = (mat1 @ mat2) * A out = (X1 @ X2) * A
In particular, :attr:`mat1` and :attr:`mat2` can be 1-D, then ``mat1 @ In particular, :attr:`X1` and :attr:`X2` can be 1-D, then ``X1 @ X2``
mat2`` becomes the out-product of the two vector (which results in a becomes the out-product of the two vector (which results in a matrix).
matrix).
Parameters Parameters
---------- ----------
A : SparseMatrix A : SparseMatrix
Sparse matrix of shape ``(M, N)``. Sparse matrix of shape ``(M, N)``.
mat1 : Tensor X1 : Tensor
Dense matrix of shape ``(M, K)`` or ``(M,)`` Dense matrix of shape ``(M, K)`` or ``(M,)``
mat2 : Tensor X2 : Tensor
Dense matrix of shape ``(K, N)`` or ``(N,)`` Dense matrix of shape ``(K, N)`` or ``(N,)``
Returns Returns
...@@ -45,32 +43,29 @@ def sddmm( ...@@ -45,32 +43,29 @@ def sddmm(
>>> col = torch.tensor([2, 3, 3]) >>> col = torch.tensor([2, 3, 3])
>>> val = torch.arange(1, 4).float() >>> val = torch.arange(1, 4).float()
>>> A = from_coo(row, col, val, (3, 4)) >>> A = from_coo(row, col, val, (3, 4))
>>> mat1 = torch.randn(3, 5) >>> X1 = torch.randn(3, 5)
>>> mat2 = torch.randn(5, 4) >>> X2 = torch.randn(5, 4)
>>> dgl.sparse.sddmm(A, mat1, mat2) >>> dgl.sparse.sddmm(A, X1, X2)
SparseMatrix(indices=tensor([[1, 1, 2], SparseMatrix(indices=tensor([[1, 1, 2],
[2, 3, 3]]), [2, 3, 3]]),
values=tensor([ 1.3097, -1.0977, 1.6953]), values=tensor([ 1.3097, -1.0977, 1.6953]),
shape=(3, 4), nnz=3) shape=(3, 4), nnz=3)
""" """
return SparseMatrix( return SparseMatrix(torch.ops.dgl_sparse.sddmm(A.c_sparse_matrix, X1, X2))
torch.ops.dgl_sparse.sddmm(A.c_sparse_matrix, mat1, mat2)
)
def bsddmm( # pylint: disable=invalid-name
A: SparseMatrix, mat1: torch.Tensor, mat2: torch.Tensor def bsddmm(A: SparseMatrix, X1: torch.Tensor, X2: torch.Tensor) -> SparseMatrix:
) -> SparseMatrix:
r"""Sampled-Dense-Dense Matrix Multiplication (SDDMM) by batches. r"""Sampled-Dense-Dense Matrix Multiplication (SDDMM) by batches.
``sddmm`` multiplies two dense matrices :attr:`mat1` and :attr:`mat2` ``sddmm`` multiplies two dense matrices :attr:`X1` and :attr:`X2` at the
at the nonzero locations of sparse matrix :attr:`A`. Values of :attr:`A` nonzero locations of sparse matrix :attr:`A`. Values of :attr:`A` is not
is not considered during the computation. considered during the computation.
Mathematically ``sddmm`` is formulated as: Mathematically ``sddmm`` is formulated as:
.. math:: .. math::
out = (mat1 @ mat2) * A out = (X1 @ X2) * A
The batch dimension is the last dimension for input matrices. In particular, The batch dimension is the last dimension for input matrices. In particular,
if the sparse matrix has scalar non-zero values, it will be broadcasted if the sparse matrix has scalar non-zero values, it will be broadcasted
...@@ -80,9 +75,9 @@ def bsddmm( ...@@ -80,9 +75,9 @@ def bsddmm(
---------- ----------
A : SparseMatrix A : SparseMatrix
Sparse matrix of shape ``(M, N)`` or ``(M, N, B)``. Sparse matrix of shape ``(M, N)`` or ``(M, N, B)``.
mat1 : Tensor X1 : Tensor
Dense matrix of shape ``(M, K, B)`` Dense matrix of shape ``(M, K, B)``
mat2 : Tensor X2 : Tensor
Dense matrix of shape ``(K, N, B)`` Dense matrix of shape ``(K, N, B)``
Returns Returns
...@@ -97,9 +92,9 @@ def bsddmm( ...@@ -97,9 +92,9 @@ def bsddmm(
>>> col = torch.tensor([2, 3, 3]) >>> col = torch.tensor([2, 3, 3])
>>> val = torch.arange(1, 4).float() >>> val = torch.arange(1, 4).float()
>>> A = from_coo(row, col, val, (3, 4)) >>> A = from_coo(row, col, val, (3, 4))
>>> mat1 = torch.arange(0, 3 * 5 * 2).view(3, 5, 2).float() >>> X1 = torch.arange(0, 3 * 5 * 2).view(3, 5, 2).float()
>>> mat2 = torch.arange(0, 5 * 4 * 2).view(5, 4, 2).float() >>> X2 = torch.arange(0, 5 * 4 * 2).view(5, 4, 2).float()
>>> dgl.sparse.bsddmm(A, mat1, mat2) >>> dgl.sparse.bsddmm(A, X1, X2)
SparseMatrix(indices=tensor([[1, 1, 2], SparseMatrix(indices=tensor([[1, 1, 2],
[2, 3, 3]]), [2, 3, 3]]),
values=tensor([[1560., 1735.], values=tensor([[1560., 1735.],
...@@ -107,4 +102,4 @@ def bsddmm( ...@@ -107,4 +102,4 @@ def bsddmm(
[8400., 9105.]]), [8400., 9105.]]),
shape=(3, 4), nnz=3) shape=(3, 4), nnz=3)
""" """
return sddmm(A, mat1, mat2) return sddmm(A, X1, X2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment