Unverified Commit b3aec7ae authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Sparse] Add mock_sddmm in mock_sparse (#5059)

parent 56ce60b0
"""Sampled Dense-Dense Matrix Multiplication (SDDMM) operator module.""" """Sampled Dense-Dense Matrix Multiplication (SDDMM) operator module."""
import torch import torch
from .sp_matrix import SparseMatrix from .sp_matrix import create_from_coo, SparseMatrix
__all__ = ["sddmm"] __all__ = ["sddmm", "mock_bsddmm"]
def sddmm( def sddmm(
...@@ -56,3 +56,56 @@ def sddmm( ...@@ -56,3 +56,56 @@ def sddmm(
# PyTorch's sddmm operator only supports CSR format. # PyTorch's sddmm operator only supports CSR format.
res = torch.sparse.sampled_addmm(A.adj.to_sparse_csr(), mat1, mat2) res = torch.sparse.sampled_addmm(A.adj.to_sparse_csr(), mat1, mat2)
return SparseMatrix(A.row, A.col, res.values(), A.adj.shape) return SparseMatrix(A.row, A.col, res.values(), A.adj.shape)
def mock_bsddmm(
A: SparseMatrix, mat1: torch.Tensor, mat2: torch.Tensor
) -> SparseMatrix:
r"""Batched Sampled-Dense-Dense Matrix Multiplication (SDDMM).
``bsddmm`` conducts `sddmm` for each batch of the two dense matrices
independently.
In particular, :attr:``mat1`` and :attr:``mat2`` can be 2-D, which will be
reshape as `(B, M, 1)` and `(B, 1, K)` in the computation.
Parameters
----------
A : SparseMatrix
Sparse matrix of shape `(M, N)`.
mat1 : Tensor
Dense matrix of shape `(B, M, K)` or `(B, M,)`
mat2 : Tensor
Dense matrix of shape `(B, K, N)` or `(B, K,)`
Returns
-------
SparseMatrix
Sparse matrix of shape `(M, N)` with non-zero values of `B` dimension.
Examples
--------
>>> row = torch.tensor([1, 1, 2])
>>> col = torch.tensor([2, 3, 3])
>>> val = torch.arange(1, 4).float()
>>> A = create_from_coo(row, col, val, (3, 4))
>>> mat1 = torch.randn(2, 3, 5)
>>> mat2 = torch.randn(2, 5, 4)
>>> dgl.mock_sparse.mock_bsddmm(A, mat1, mat2)
SparseMatrix(indices=tensor([[1, 1, 2],
[2, 3, 3]]),
values=tensor([[-0.6765, -0.4017],
[ 3.3290, 6.9016],
[ 4.8184, 5.8882]]),
shape=(3, 4), nnz=3)
"""
batch_mat1 = [mat1[i, ...] for i in range(mat1.shape[0])]
batch_mat2 = [mat2[i, ...] for i in range(mat2.shape[0])]
batch_ret = [sddmm(A, lhs, rhs) for lhs, rhs in zip(batch_mat1, batch_mat2)]
return create_from_coo(
row=A.row,
col=A.col,
val=torch.stack([sp_mat.val for sp_mat in batch_ret], dim=-1),
shape=A.shape,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment