Unverified Commit 89309888 authored by czkkkkkk, committed by GitHub

[Sparse] Support spspmul (#5543)

parent 4864a9f9
@@ -11,10 +11,8 @@
namespace dgl {
namespace sparse {
-// TODO(zhenkun): support addition of matrices with different sparsity.
/**
- * @brief Adds two sparse matrices. Currently does not support two matrices
- * with different sparsity.
 * @brief Adds two sparse matrices possibly with different sparsities.
 *
 * @param A SparseMatrix
 * @param B SparseMatrix
@@ -25,6 +23,18 @@ c10::intrusive_ptr<SparseMatrix> SpSpAdd(
    const c10::intrusive_ptr<SparseMatrix>& A,
    const c10::intrusive_ptr<SparseMatrix>& B);
/**
 * @brief Multiplies two sparse matrices possibly with different sparsities.
 *
 * @param lhs_mat SparseMatrix
 * @param rhs_mat SparseMatrix
 *
 * @return SparseMatrix
 */
c10::intrusive_ptr<SparseMatrix> SpSpMul(
    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
    const c10::intrusive_ptr<SparseMatrix>& rhs_mat);
} // namespace sparse
} // namespace dgl
...
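The semantics being added: elementwise multiplication of two sparse matrices can only produce nonzeros where both operands are nonzero, so the result's sparsity pattern is the intersection of the two input patterns. A tiny dense PyTorch check of that contract (an illustration added here, not part of the diff):

import torch

# Dense reference for the SpSpMul contract: the product is zero wherever
# either operand is zero.
A = torch.tensor([[0.0, 2.0], [3.0, 0.0]])
B = torch.tensor([[5.0, 4.0], [0.0, 1.0]])
C = A * B
assert (C != 0).sum() == 1 and C[0, 1] == 8.0  # only (0, 1) survives

The new matrix_ops.h header below declares the COO-intersection primitive this relies on.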
/**
* Copyright (c) 2023 by Contributors
* @file sparse/matrix_ops.h
* @brief DGL C++ sparse matrix operators.
*/
#ifndef SPARSE_MATRIX_OPS_H_
#define SPARSE_MATRIX_OPS_H_
#include <sparse/sparse_format.h>
#include <tuple>
namespace dgl {
namespace sparse {
/**
* @brief Compute the intersection of two COO matrices. Return the intersection
* matrix, and the indices of the intersection in the left-hand-side and
* right-hand-side matrices.
*
* @param lhs The left-hand-side COO matrix.
* @param rhs The right-hand-side COO matrix.
*
* @return A tuple of COO matrix, lhs indices, and rhs indices.
*/
std::tuple<std::shared_ptr<COO>, torch::Tensor, torch::Tensor> COOIntersection(
const std::shared_ptr<COO>& lhs, const std::shared_ptr<COO>& rhs);
} // namespace sparse
} // namespace dgl
#endif // SPARSE_MATRIX_OPS_H_
@@ -3,11 +3,9 @@
 * @file elementwise_op.cc
 * @brief DGL C++ sparse elementwise operator implementation.
 */
-// clang-format off
-#include <sparse/dgl_headers.h>
-// clang-format on
#include <sparse/elementwise_op.h>
#include <sparse/matrix_ops.h>
#include <sparse/sparse_matrix.h>
#include <torch/script.h>
@@ -18,6 +16,8 @@
namespace dgl {
namespace sparse {
using namespace torch::autograd;
c10::intrusive_ptr<SparseMatrix> SpSpAdd(
    const c10::intrusive_ptr<SparseMatrix>& A,
    const c10::intrusive_ptr<SparseMatrix>& B) {
@@ -32,5 +32,90 @@ c10::intrusive_ptr<SparseMatrix> SpSpAdd(
  return SparseMatrix::FromCOO(sum.indices(), sum.values(), A->shape());
}
class SpSpMulAutoGrad : public Function<SpSpMulAutoGrad> {
 public:
  static variable_list forward(
      AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> lhs_mat,
      torch::Tensor lhs_val, c10::intrusive_ptr<SparseMatrix> rhs_mat,
      torch::Tensor rhs_val);

  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs);
};
variable_list SpSpMulAutoGrad::forward(
    AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> lhs_mat,
    torch::Tensor lhs_val, c10::intrusive_ptr<SparseMatrix> rhs_mat,
    torch::Tensor rhs_val) {
  // Only entries present in both matrices can be non-zero in the product.
  std::shared_ptr<COO> intersection;
  torch::Tensor lhs_indices, rhs_indices;
  std::tie(intersection, lhs_indices, rhs_indices) =
      COOIntersection(lhs_mat->COOPtr(), rhs_mat->COOPtr());
  auto lhs_intersect_val = lhs_val.index_select(0, lhs_indices);
  auto rhs_intersect_val = rhs_val.index_select(0, rhs_indices);
  auto ret_val = lhs_intersect_val * rhs_intersect_val;
  auto ret_mat =
      SparseMatrix::FromCOOPointer(intersection, ret_val, lhs_mat->shape());
  // Save only what backward needs: each side's gradient is the other side's
  // intersected values, scattered back to its own nonzero positions.
  ctx->saved_data["lhs_require_grad"] = lhs_val.requires_grad();
  ctx->saved_data["rhs_require_grad"] = rhs_val.requires_grad();
  if (lhs_val.requires_grad()) {
    ctx->saved_data["lhs_val_shape"] = lhs_val.sizes().vec();
    ctx->saved_data["rhs_intersect_lhs"] =
        SparseMatrix::ValLike(ret_mat, rhs_intersect_val);
    ctx->saved_data["lhs_indices"] = lhs_indices;
  }
  if (rhs_val.requires_grad()) {
    ctx->saved_data["rhs_val_shape"] = rhs_val.sizes().vec();
    ctx->saved_data["lhs_intersect_rhs"] =
        SparseMatrix::ValLike(ret_mat, lhs_intersect_val);
    ctx->saved_data["rhs_indices"] = rhs_indices;
  }
  return {intersection->indices, ret_val};
}
tensor_list SpSpMulAutoGrad::backward(
    AutogradContext* ctx, tensor_list grad_outputs) {
  torch::Tensor lhs_val_grad, rhs_val_grad;
  // grad_outputs[0] corresponds to the non-differentiable indices output;
  // grad_outputs[1] is the gradient w.r.t. the product values.
  auto output_grad = grad_outputs[1];
  if (ctx->saved_data["lhs_require_grad"].toBool()) {
    // d(lhs * rhs)/d(lhs) = rhs on the intersected entries, scattered back
    // to lhs's nonzero positions; all other entries receive zero.
    auto rhs_intersect_lhs =
        ctx->saved_data["rhs_intersect_lhs"].toCustomClass<SparseMatrix>();
    const auto& lhs_val_shape = ctx->saved_data["lhs_val_shape"].toIntVector();
    auto lhs_indices = ctx->saved_data["lhs_indices"].toTensor();
    lhs_val_grad = torch::zeros(lhs_val_shape, output_grad.options());
    auto intersect_grad = rhs_intersect_lhs->value() * output_grad;
    lhs_val_grad.index_put_({lhs_indices}, intersect_grad);
  }
  if (ctx->saved_data["rhs_require_grad"].toBool()) {
    auto lhs_intersect_rhs =
        ctx->saved_data["lhs_intersect_rhs"].toCustomClass<SparseMatrix>();
    const auto& rhs_val_shape = ctx->saved_data["rhs_val_shape"].toIntVector();
    auto rhs_indices = ctx->saved_data["rhs_indices"].toTensor();
    rhs_val_grad = torch::zeros(rhs_val_shape, output_grad.options());
    auto intersect_grad = lhs_intersect_rhs->value() * output_grad;
    rhs_val_grad.index_put_({rhs_indices}, intersect_grad);
  }
  return {torch::Tensor(), lhs_val_grad, torch::Tensor(), rhs_val_grad};
}
c10::intrusive_ptr<SparseMatrix> SpSpMul(
    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
    const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {
  ElementwiseOpSanityCheck(lhs_mat, rhs_mat);
  // Fast path: two diagonal matrices multiply value-by-value on the diagonal.
  if (lhs_mat->HasDiag() && rhs_mat->HasDiag()) {
    return SparseMatrix::FromDiagPointer(
        lhs_mat->DiagPtr(), lhs_mat->value() * rhs_mat->value(),
        lhs_mat->shape());
  }
  TORCH_CHECK(
      !lhs_mat->HasDuplicate() && !rhs_mat->HasDuplicate(),
      "Only support SpSpMul on sparse matrices without duplicate values");
  auto results = SpSpMulAutoGrad::apply(
      lhs_mat, lhs_mat->value(), rhs_mat, rhs_mat->value());
  const auto& indices = results[0];
  const auto& val = results[1];
  return SparseMatrix::FromCOO(indices, val, lhs_mat->shape());
}
} // namespace sparse
} // namespace dgl
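For intuition, the gradient rule above can be reproduced in a few lines of plain PyTorch. In this sketch (MaskedMul, lhs_idx, and rhs_idx are illustrative names, not DGL API), z = x[lhs_idx] * y[rhs_idx] is the product over intersected entries, and each input's gradient is the other input's intersected values times the upstream gradient, scattered back to its own nonzero slots, just as the index_put_ calls do in SpSpMulAutoGrad::backward:

import torch
from torch.autograd import Function

class MaskedMul(Function):
    # z = x[lhs_idx] * y[rhs_idx], mirroring SpSpMulAutoGrad on value tensors.
    @staticmethod
    def forward(ctx, x, y, lhs_idx, rhs_idx):
        xs = x.index_select(0, lhs_idx)
        ys = y.index_select(0, rhs_idx)
        ctx.save_for_backward(xs, ys, lhs_idx, rhs_idx)
        ctx.x_len, ctx.y_len = x.shape[0], y.shape[0]
        return xs * ys

    @staticmethod
    def backward(ctx, grad_z):
        xs, ys, lhs_idx, rhs_idx = ctx.saved_tensors
        grad_x = torch.zeros(ctx.x_len, dtype=grad_z.dtype, device=grad_z.device)
        grad_y = torch.zeros(ctx.y_len, dtype=grad_z.dtype, device=grad_z.device)
        # Scatter the masked gradients back to each side's own positions.
        grad_x[lhs_idx] = ys * grad_z
        grad_y[rhs_idx] = xs * grad_z
        return grad_x, grad_y, None, None

x = torch.tensor([2.0, 4.0, 6.0], requires_grad=True)
y = torch.tensor([3.0, 2.0, 1.0], requires_grad=True)
# Suppose only position 1 of x and position 1 of y share a coordinate.
z = MaskedMul.apply(x, y, torch.tensor([1]), torch.tensor([1]))
z.backward(torch.ones_like(z))
assert x.grad[1] == 2.0 and y.grad[1] == 4.0  # each receives the other's value

The scatter-based gradient assumes neither side has duplicate coordinates, which is exactly what the TORCH_CHECK in SpSpMul enforces.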
/**
* Copyright (c) 2023 by Contributors
* @file matrix_ops.cc
* @brief DGL C++ matrix operators.
*/
#include <sparse/matrix_ops.h>
#include <torch/script.h>
namespace dgl {
namespace sparse {
/**
* @brief Compute the intersection of two COO matrices. Return the intersection
* COO matrix, and the indices of the intersection in the left-hand-side and
* right-hand-side COO matrices.
*
* @param lhs The left-hand-side COO matrix.
* @param rhs The right-hand-side COO matrix.
*
* @return A tuple of COO matrix, lhs indices, and rhs indices.
*/
std::tuple<std::shared_ptr<COO>, torch::Tensor, torch::Tensor> COOIntersection(
    const std::shared_ptr<COO>& lhs, const std::shared_ptr<COO>& rhs) {
  // 1. Encode each (row, col) coordinate of the two COO matrices into a
  //    unique integer: row * num_cols + col.
  auto lhs_arr =
      lhs->indices.index({0}) * lhs->num_cols + lhs->indices.index({1});
  auto rhs_arr =
      rhs->indices.index({0}) * rhs->num_cols + rhs->indices.index({1});
  // 2. Concatenate the two arrays.
  auto arr = torch::cat({lhs_arr, rhs_arr});
  // 3. Unique the concatenated array, keeping the inverse mapping and counts.
  torch::Tensor unique, inverse, counts;
  std::tie(unique, inverse, counts) =
      torch::unique_dim(arr, 0, false, true, true);
  // 4. An entry with count > 1 appears in both matrices (the inputs are
  //    coalesced), i.e., it belongs to the intersection.
  auto mask = counts > 1;
  // 5. Map the inverse array back to positions in the original arrays to
  //    recover where each intersected entry sits in lhs and rhs.
  auto lhs_inverse = inverse.slice(0, 0, lhs_arr.numel());
  auto rhs_inverse = inverse.slice(0, lhs_arr.numel(), arr.numel());
  auto map_to_original = torch::empty_like(unique);
  map_to_original.index_put_(
      {lhs_inverse},
      torch::arange(lhs_inverse.numel(), map_to_original.options()));
  auto lhs_indices = map_to_original.index({mask});
  map_to_original.index_put_(
      {rhs_inverse},
      torch::arange(rhs_inverse.numel(), map_to_original.options()));
  auto rhs_indices = map_to_original.index({mask});
  // 6. Decode the surviving keys back into (row, col) indices to build the
  //    intersection COO matrix.
  auto ret_arr = unique.index({mask});
  auto ret_indices = torch::stack(
      {ret_arr.floor_divide(lhs->num_cols), ret_arr % lhs->num_cols}, 0);
  auto ret_coo = std::make_shared<COO>(
      COO{lhs->num_rows, lhs->num_cols, ret_indices, false, false});
  return {ret_coo, lhs_indices, rhs_indices};
}
} // namespace sparse
} // namespace dgl
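The encode-unique-decode routine above is easy to sanity-check in standalone PyTorch. This sketch (function and variable names are illustrative, not DGL's) reproduces the six steps on the pair of matrices used in the sp_mul docstring further down, where (0, 3) is the only shared coordinate:

import torch

def coo_intersection(lhs_idx, rhs_idx, num_cols):
    # 1-2. Encode (row, col) as row * num_cols + col and concatenate.
    lhs_key = lhs_idx[0] * num_cols + lhs_idx[1]
    rhs_key = rhs_idx[0] * num_cols + rhs_idx[1]
    keys = torch.cat([lhs_key, rhs_key])
    # 3-4. Unique; count > 1 means the coordinate occurs on both sides
    # (assuming each side is coalesced).
    unique, inverse, counts = torch.unique(
        keys, return_inverse=True, return_counts=True
    )
    mask = counts > 1
    # 5. Recover each side's position of every intersected entry.
    pos = torch.empty_like(unique)
    pos[inverse[: lhs_key.numel()]] = torch.arange(lhs_key.numel())
    lhs_pos = pos[mask]
    pos[inverse[lhs_key.numel() :]] = torch.arange(rhs_key.numel())
    rhs_pos = pos[mask]
    # 6. Decode the surviving keys back into (row, col) indices.
    inter = unique[mask]
    return torch.stack([inter // num_cols, inter % num_cols]), lhs_pos, rhs_pos

lhs = torch.tensor([[1, 0, 2], [0, 3, 2]])  # coords (1,0), (0,3), (2,2)
rhs = torch.tensor([[2, 0, 1], [0, 3, 2]])  # coords (2,0), (0,3), (1,2)
indices, lhs_pos, rhs_pos = coo_intersection(lhs, rhs, num_cols=4)
assert indices.tolist() == [[0], [3]] and lhs_pos.item() == rhs_pos.item() == 1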
@@ -39,6 +39,7 @@ TORCH_LIBRARY(dgl_sparse, m) {
      .def("from_csc", &SparseMatrix::FromCSC)
      .def("from_diag", &SparseMatrix::FromDiag)
      .def("spsp_add", &SpSpAdd)
      .def("spsp_mul", &SpSpMul)
      .def("reduce", &Reduce)
      .def("sum", &ReduceSum)
      .def("smean", &ReduceMean)
...
@@ -23,6 +23,9 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::Coalesce() {
bool SparseMatrix::HasDuplicate() {
  aten::CSRMatrix dgl_csr;
  // A diagonal matrix never stores duplicate entries.
  if (HasDiag()) {
    return false;
  }
  // The format for calculation will be chosen in the following order: CSR,
  // CSC. CSR is created if the sparse matrix only has CSC format.
  if (HasCSR() || !HasCSC()) {
...
@@ -14,6 +14,13 @@ def spsp_add(A, B):
    )

def spsp_mul(A, B):
    """Invoke the C++ sparse library for elementwise multiplication."""
    return SparseMatrix(
        torch.ops.dgl_sparse.spsp_mul(A.c_sparse_matrix, B.c_sparse_matrix)
    )

def sp_add(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:
    """Elementwise addition
@@ -83,8 +90,8 @@ def sp_sub(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:
def sp_mul(A: SparseMatrix, B: Union[SparseMatrix, Scalar]) -> SparseMatrix:
    """Elementwise multiplication

-    If :attr:`B` is a sparse matrix, both :attr:`A` and :attr:`B` must be
-    diagonal matrices.
    Note that if both :attr:`A` and :attr:`B` are sparse matrices, they must
    both be diagonal or on CPU.

    Parameters
    ----------
@@ -116,20 +123,19 @@ def sp_mul(A: SparseMatrix, B: Union[SparseMatrix, Scalar]) -> SparseMatrix:
                                 [0, 3, 2]]),
                 values=tensor([2, 4, 6]),
                 shape=(3, 4), nnz=3)
    >>> indices2 = torch.tensor([[2, 0, 1], [0, 3, 2]])
    >>> val2 = torch.tensor([3, 2, 1])
    >>> B = dglsp.spmatrix(indices2, val2, shape=(3, 4))
    >>> A * B
    SparseMatrix(indices=tensor([[0],
                                 [3]]),
                 values=tensor([4]),
                 shape=(3, 4), nnz=1)
""" """
if is_scalar(B): if is_scalar(B):
return val_like(A, A.val * B) return val_like(A, A.val * B)
if A.is_diag() and B.is_diag(): return spsp_mul(A, B)
assert A.shape == B.shape, (
f"The shape of diagonal matrix A {A.shape} and B {B.shape} must"
f"match for elementwise multiplication."
)
return diag(A.val * B.val, A.shape)
# Python falls back to B.__rmul__(A) then TypeError when NotImplemented is
# returned.
# So this also handles the case of scalar * SparseMatrix since we set
# SparseMatrix.__rmul__ to be the same as SparseMatrix.__mul__.
return NotImplemented

def sp_div(A: SparseMatrix, B: Union[SparseMatrix, Scalar]) -> SparseMatrix:
...
@@ -225,7 +225,7 @@ def test_sub_sparse_diag(val_shape):
    assert torch.allclose(dense_diff, -diff4)

-@pytest.mark.parametrize("op", ["mul", "truediv", "pow"])
@pytest.mark.parametrize("op", ["truediv", "pow"])
def test_error_op_sparse_diag(op):
    ctx = F.ctx()
    row = torch.tensor([1, 0, 2]).to(ctx)
...
@@ -4,7 +4,15 @@ import backend as F
import pytest
import torch

-from dgl.sparse import from_coo, power
from dgl.sparse import from_coo, mul, power, val_like

from .utils import (
    rand_coo,
    rand_csc,
    rand_csr,
    rand_diag,
    sparse_matrix_to_dense,
)

def all_close_sparse(A, row, col, val, shape):
@@ -91,3 +99,38 @@ def test_error_op_scalar(op, v_scalar):
        A - v_scalar
    with pytest.raises(TypeError):
        v_scalar - A

@pytest.mark.parametrize(
    "create_func1", [rand_coo, rand_csr, rand_csc, rand_diag]
)
@pytest.mark.parametrize(
    "create_func2", [rand_coo, rand_csr, rand_csc, rand_diag]
)
@pytest.mark.parametrize("shape", [(5, 5), (5, 3)])
@pytest.mark.parametrize("nnz1", [5, 15])
@pytest.mark.parametrize("nnz2", [1, 14])
@pytest.mark.parametrize("nz_dim", [None, 3])
def test_spspmul(create_func1, create_func2, shape, nnz1, nnz2, nz_dim):
    dev = F.ctx()
    A = create_func1(shape, nnz1, dev, nz_dim)
    B = create_func2(shape, nnz2, dev, nz_dim)
    C = mul(A, B)
    assert not C.has_duplicate()

    # Compare against a dense reference, both forward and backward.
    DA = sparse_matrix_to_dense(A)
    DB = sparse_matrix_to_dense(B)
    DC = DA * DB

    grad = torch.rand_like(C.val)
    C.val.backward(grad)
    DC_grad = sparse_matrix_to_dense(val_like(C, grad))
    DC.backward(DC_grad)

    assert torch.allclose(sparse_matrix_to_dense(C), DC, atol=1e-05)
    assert torch.allclose(
        val_like(A, A.val.grad).to_dense(), DA.grad, atol=1e-05
    )
    assert torch.allclose(
        val_like(B, B.val.grad).to_dense(), DB.grad, atol=1e-05
    )
import numpy as np
import torch

-from dgl.sparse import from_csc, from_csr, SparseMatrix, spmatrix
from dgl.sparse import diag, from_csc, from_csr, SparseMatrix, spmatrix
np.random.seed(42)
torch.random.manual_seed(42)

@@ -83,6 +83,15 @@ def rand_csc(shape, nnz, dev, nz_dim=None):
    return from_csc(indptr, indices, val, shape=shape)

def rand_diag(shape, nnz, dev, nz_dim=None):
    # A diagonal matrix has exactly min(shape) entries; the nnz argument is
    # ignored and kept only for interface compatibility with the other
    # rand_* creators.
    nnz = min(shape)
    if nz_dim is None:
        val = torch.randn(nnz, device=dev, requires_grad=True)
    else:
        val = torch.randn(nnz, nz_dim, device=dev, requires_grad=True)
    return diag(val, shape)

def rand_coo_uncoalesced(shape, nnz, dev):
    # Create a sparse matrix with possible duplicate entries.
    row = torch.randint(shape[0], (nnz,), device=dev)
...
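The tests above also import sparse_matrix_to_dense from the same utils module, which this diff does not touch. A minimal reconstruction consistent with how test_spspmul uses it (a dense tensor that stays on the autograd graph so DA.grad can be read after backward) might look like the following; this is an assumption, not the repository's code:

def sparse_matrix_to_dense(A):
    # Densify while keeping A.val on the autograd graph.
    shape = tuple(A.shape) + tuple(A.val.shape[1:])  # vector-valued nonzeros
    dense = torch.zeros(shape, device=A.val.device, dtype=A.val.dtype)
    row, col = A.coo()
    dense[row, col] = A.val
    dense.retain_grad()  # dense is a non-leaf tensor; the tests read its .grad
    return dense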