"src/vscode:/vscode.git/clone" did not exist on "0d687968dd29bb403c5cac7004d3a0332a127ac6"
Unverified Commit fa05ccb9 authored by Israt Nisa's avatar Israt Nisa Committed by GitHub
Browse files

Add `sddmm` operator in mock_sparse library (#4579)



* sddmm init

* SDDMM with N-D nonzero values

* drop support for vector shaped non zero elements

* address comments

* skip cpu test

* skip GPU test too
Co-authored-by: default avatarIsrat Nisa <nisisrat@amazon.com>
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent ec5e7bc5
......@@ -48,3 +48,4 @@ Operators
spspmm
bspmm
bspspmm
sddmm
......@@ -2,6 +2,7 @@
from .diag_matrix import *
from .sp_matrix import *
from .elementwise_op_sp import *
from .sddmm import *
from .reduction import * # pylint: disable=W0622
from .unary_diag import *
from .unary_sp import *
......
"""Sampled Dense-Dense Matrix Multiplication (SDDMM) operator module."""
import torch
from .sp_matrix import SparseMatrix
__all__ = ["sddmm"]
def sddmm(
    A: SparseMatrix, mat1: torch.Tensor, mat2: torch.Tensor
) -> SparseMatrix:
    r"""Sampled-Dense-Dense Matrix Multiplication (SDDMM).

    ``sddmm`` multiplies two dense matrices :attr:`mat1` and :attr:`mat2`
    at the nonzero locations of sparse matrix :attr:`A`. Values of :attr:`A`
    are added to the resulting matrix.

    Mathematically ``sddmm`` is formulated as:

    .. math::
        out = (mat1 @ mat2) * spy(A) + A

    Parameters
    ----------
    A : SparseMatrix
        Sparse matrix of shape `(M, N)`.
    mat1 : Tensor
        Dense matrix of shape `(M, K)`
    mat2 : Tensor
        Dense matrix of shape `(K, N)`

    Returns
    -------
    SparseMatrix
        Sparse matrix of shape `(M, N)`.

    Examples
    --------

    >>> row = torch.Tensor([1, 1, 2])
    >>> col = torch.Tensor([2, 3, 3])
    >>> val = torch.arange(1, 4).float()
    >>> A = SparseMatrix(row, col, val, (3, 4))
    >>> mat1 = torch.randn(3, 5)
    >>> mat2 = torch.randn(5, 4)
    >>> dgl.mock_sparse.sddmm(A, mat1, mat2)
    SparseMatrix(indices=tensor([[1, 1, 2],
            [2, 3, 3]]),
    values=tensor([1.8035, 2.3375, 3.1255]),
    shape=(3, 4), nnz=3)
    """
    # Only scalar (1-D) nonzero values are supported; N-D values would
    # require a batched SDDMM. Report the full shape on failure — the
    # previous message indexed ``shape[1]``, which itself raised
    # ``IndexError`` for 0-D values and printed only one dimension
    # otherwise.
    assert A.val.dim() == 1, (
        f"Nonzero elements have values of shape {tuple(A.val.shape)}. "
        "Expects scalar values. "
    )
    # PyTorch's sddmm operator only supports CSR format.
    res = torch.sparse.sampled_addmm(A.adj.to_sparse_csr(), mat1, mat2)
    # NOTE(review): ``res.values()`` is in CSR (row-sorted) order; this
    # assumes ``A.row``/``A.col`` are stored in the same order — confirm
    # against SparseMatrix's constructor.
    return SparseMatrix(A.row, A.col, res.values(), A.adj.shape)
import unittest
import backend as F
import dgl
import pytest
import torch
from dgl.mock_sparse import SparseMatrix
# Reusable pytest marks: run each test over both index dtypes and both
# floating-point value dtypes exercised by the mock_sparse library.
parametrize_idtype = pytest.mark.parametrize(
    "idtype", [torch.int32, torch.int64]
)
parametrize_dtype = pytest.mark.parametrize(
    "dtype", [torch.float32, torch.float64]
)
def all_close_sparse(A, B):
    """Assert that two sparse COO tensors agree in indices, values and shape.

    Raises an ``AssertionError`` on the first mismatching component.
    """
    # Compare the coordinate and value tensors component-wise.
    for extract in (torch.Tensor.indices, torch.Tensor.values):
        assert torch.allclose(extract(A), extract(B))
    assert A.shape == B.shape
# TODO (Israt): Implement sddmm. Do not rely on PyTorch.
@unittest.skipIf(
F._default_context_str == "cpu",
reason="sddmm uses sampled_addmm from pytorch which supports only CUDA",
)
@unittest.skipIf(
F._default_context_str == "gpu",
reason="sddmm uses sampled_addmm from pytorch which requires pytorch "
"1.12 or higher. Current CI doesn't support that.",
)
@parametrize_idtype
@parametrize_dtype
def test_sddmm(idtype, dtype):
    """Compare mock_sparse.sddmm against torch.sparse.sampled_addmm."""
    rows = torch.tensor([1, 0, 2, 9, 1])
    cols = torch.tensor([0, 49, 2, 1, 7])
    nonzeros = torch.arange(1, 6).float()
    sparse_mat = SparseMatrix(rows, cols, nonzeros, shape=(10, 50))
    dense_lhs = torch.rand(10, 5)
    dense_rhs = torch.rand(5, 50)
    dgl_result = dgl.mock_sparse.sddmm(sparse_mat, dense_lhs, dense_rhs)
    # Reference result computed directly via PyTorch's CSR-based SDDMM.
    th_result = torch.sparse.sampled_addmm(
        sparse_mat.adj.to_sparse_csr(), dense_lhs, dense_rhs
    )
    all_close_sparse(dgl_result.adj, th_result.to_sparse_coo())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment