"vscode:/vscode.git/clone" did not exist on "42e8abc34b2409be691e473b80be3ab65f9926a2"
Unverified Commit bff32a09 authored by czkkkkkk, committed by GitHub

[Sparse] Use efficient implementation for Diag @ Sparse and Sparse @ Diag. (#5147)

parent 21c4c29a
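The change below replaces the generic sparse-sparse kernel for products with a diagonal operand by a direct rescaling of the sparse matrix's stored values. A minimal dense-tensor sketch of the identity this relies on (plain PyTorch, not DGL code; variable names are illustrative only):

import torch

# Right-multiplying by a diagonal matrix scales the columns of A, and
# left-multiplying scales its rows, so the sparsity pattern of A is unchanged.
A = torch.randn(4, 5)
d = torch.randn(5)                               # diagonal of a 5 x 5 matrix
assert torch.allclose(A @ torch.diag(d), A * d)  # column j scaled by d[j]

e = torch.randn(4)                               # diagonal of a 4 x 4 matrix
assert torch.allclose(torch.diag(e) @ A, e.unsqueeze(1) * A)  # row i scaled by e[i]

For a matrix stored in COO form this means the product keeps the indices of A and only multiplies each stored value by the diagonal entry at its column (for A @ D) or row (for D @ A) index, which is what the new helpers below do.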
@@ -6,7 +6,7 @@ import torch
from .diag_matrix import diag, DiagMatrix
-from .sparse_matrix import SparseMatrix
+from .sparse_matrix import SparseMatrix, val_like

__all__ = ["spmm", "bspmm", "spspmm", "mm"]
@@ -117,6 +117,62 @@ def _diag_diag_mm(A1: DiagMatrix, A2: DiagMatrix) -> DiagMatrix:
    return diag(diag_val.to(A1.device), (M, P))

def _sparse_diag_mm(A, D):
    """Internal function for multiplying a sparse matrix by a diagonal matrix.

    Parameters
    ----------
    A : SparseMatrix
        Matrix of shape (N, M), with values of shape (nnz1)
    D : DiagMatrix
        Matrix of shape (M, P), with values of shape (nnz2)

    Returns
    -------
    SparseMatrix
        SparseMatrix with shape (N, P)
    """
    assert (
        A.shape[1] == D.shape[0]
    ), f"The second dimension of SparseMatrix should be equal to the first \
        dimension of DiagMatrix in matmul(SparseMatrix, DiagMatrix), but the \
        shapes of SparseMatrix and DiagMatrix are {A.shape} and {D.shape} \
        respectively."
    assert (
        D.shape[0] == D.shape[1]
    ), f"The DiagMatrix should be a square in matmul(SparseMatrix, DiagMatrix) \
        but got {D.shape}"
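    # A @ D keeps A's sparsity pattern: entry (i, j) becomes A[i, j] * D[j, j],
    # so each stored value of A is simply scaled by the diagonal entry at its
    # column index.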
    return val_like(A, D.val[A.col] * A.val)

def _diag_sparse_mm(D, A):
    """Internal function for multiplying a diag matrix by a sparse matrix.

    Parameters
    ----------
    D : DiagMatrix
        Matrix of shape (N, M), with values of shape (nnz1)
    A : SparseMatrix
        Matrix of shape (M, P), with values of shape (nnz2)

    Returns
    -------
    SparseMatrix
        SparseMatrix with shape (N, P)
    """
    assert (
        D.shape[1] == A.shape[0]
    ), f"The second dimension of DiagMatrix should be equal to the first \
        dimension of SparseMatrix in matmul(DiagMatrix, SparseMatrix), but the \
        shapes of DiagMatrix and SparseMatrix are {D.shape} and {A.shape} \
        respectively."
    assert (
        D.shape[0] == D.shape[1]
    ), f"The DiagMatrix should be a square in matmul(DiagMatrix, SparseMatrix) \
        but got {D.shape}"
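    # D @ A keeps A's sparsity pattern: entry (i, j) becomes D[i, i] * A[i, j],
    # so each stored value of A is simply scaled by the diagonal entry at its
    # row index.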
    return val_like(A, D.val[A.row] * A.val)

def spspmm(
    A1: Union[SparseMatrix, DiagMatrix], A2: Union[SparseMatrix, DiagMatrix]
) -> Union[SparseMatrix, DiagMatrix]:
@@ -165,9 +221,9 @@ def spspmm(
    if isinstance(A1, DiagMatrix) and isinstance(A2, DiagMatrix):
        return _diag_diag_mm(A1, A2)
    if isinstance(A1, DiagMatrix):
-       A1 = A1.as_sparse()
+       return _diag_sparse_mm(A1, A2)
    if isinstance(A2, DiagMatrix):
-       A2 = A2.as_sparse()
+       return _sparse_diag_mm(A1, A2)
    return SparseMatrix(
        torch.ops.dgl_sparse.spspmm(A1.c_sparse_matrix, A2.c_sparse_matrix)
    )
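With this dispatch in place, a user-level product between a SparseMatrix and a DiagMatrix should go through the specialized helpers rather than the generic sparse-sparse kernel. A hypothetical usage sketch, assuming only the dgl.sparse functions already imported by the tests below (from_coo, diag, mm); the shapes and values are illustrative:

import torch
import dgl.sparse as dglsp

# A hypothetical 2 x 3 sparse matrix in COO form and a 3 x 3 diagonal matrix.
row = torch.tensor([0, 1, 1])
col = torch.tensor([1, 0, 2])
val = torch.tensor([1.0, 2.0, 3.0])
A = dglsp.from_coo(row, col, val, shape=(2, 3))
D = dglsp.diag(torch.tensor([10.0, 20.0, 30.0]), (3, 3))

B = dglsp.mm(A, D)  # expected to take the _sparse_diag_mm path: same indices as A
C = dglsp.mm(dglsp.diag(torch.tensor([4.0, 5.0]), (2, 2)), A)  # _diag_sparse_mm path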
@@ -4,7 +4,7 @@ import backend as F
import pytest
import torch
-from dgl.sparse import bspmm, from_coo, val_like
+from dgl.sparse import bspmm, diag, from_coo, mm, val_like

from .utils import (
    clone_detach_and_grad,
@@ -144,3 +144,71 @@ def test_spspmm_duplicate():
        pass
    else:
        assert False, "Should raise error."
@pytest.mark.parametrize("create_func", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("sparse_shape", [(5, 5), (5, 6)])
@pytest.mark.parametrize("nnz", [1, 10])
def test_sparse_diag_mm(create_func, sparse_shape, nnz):
dev = F.ctx()
diag_shape = sparse_shape[1], sparse_shape[1]
A = create_func(sparse_shape, nnz, dev)
diag_val = torch.randn(sparse_shape[1], device=dev, requires_grad=True)
D = diag(diag_val, diag_shape)
# (TODO) Need to use dgl.sparse.matmul after rename mm to matmul
B = mm(A, D)
grad = torch.randn_like(B.val)
B.val.backward(grad)
torch_A = sparse_matrix_to_torch_sparse(A)
torch_D = sparse_matrix_to_torch_sparse(D.as_sparse())
torch_B = torch.sparse.mm(torch_A, torch_D)
torch_B_grad = sparse_matrix_to_torch_sparse(B, grad)
torch_B.backward(torch_B_grad)
with torch.no_grad():
assert torch.allclose(B.dense(), torch_B.to_dense(), atol=1e-05)
assert torch.allclose(
val_like(A, A.val.grad).dense(),
torch_A.grad.to_dense(),
atol=1e-05,
)
assert torch.allclose(
diag(D.val.grad, D.shape).dense(),
torch_D.grad.to_dense(),
atol=1e-05,
)
@pytest.mark.parametrize("create_func", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("sparse_shape", [(5, 5), (5, 6)])
@pytest.mark.parametrize("nnz", [1, 10])
def test_diag_sparse_mm(create_func, sparse_shape, nnz):
dev = F.ctx()
diag_shape = sparse_shape[0], sparse_shape[0]
A = create_func(sparse_shape, nnz, dev)
diag_val = torch.randn(sparse_shape[0], device=dev, requires_grad=True)
D = diag(diag_val, diag_shape)
# (TODO) Need to use dgl.sparse.matmul after rename mm to matmul
B = mm(D, A)
grad = torch.randn_like(B.val)
B.val.backward(grad)
torch_A = sparse_matrix_to_torch_sparse(A)
torch_D = sparse_matrix_to_torch_sparse(D.as_sparse())
torch_B = torch.sparse.mm(torch_D, torch_A)
torch_B_grad = sparse_matrix_to_torch_sparse(B, grad)
torch_B.backward(torch_B_grad)
with torch.no_grad():
assert torch.allclose(B.dense(), torch_B.to_dense(), atol=1e-05)
assert torch.allclose(
val_like(A, A.val.grad).dense(),
torch_A.grad.to_dense(),
atol=1e-05,
)
assert torch.allclose(
diag(D.val.grad, D.shape).dense(),
torch_D.grad.to_dense(),
atol=1e-05,
)