Unverified Commit 7c465d20 authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Sparse] Support spspdiv (#5541)


Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent bb1f8850
......@@ -14,20 +14,20 @@ namespace sparse {
/**
* @brief Adds two sparse matrices possibly with different sparsities.
*
* @param A SparseMatrix
* @param B SparseMatrix
* @param lhs_mat SparseMatrix
* @param rhs_mat SparseMatrix
*
* @return SparseMatrix
*/
// NOTE(review): the diff extraction interleaved the pre- and post-rename
// parameter lists; this is the post-commit declaration (renamed to
// lhs_mat/rhs_mat for consistency with SpSpMul/SpSpDiv below).
c10::intrusive_ptr<SparseMatrix> SpSpAdd(
    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
    const c10::intrusive_ptr<SparseMatrix>& rhs_mat);
/**
* @brief Multiplies two sparse matrices possibly with different sparsities.
*
* @param A SparseMatrix
* @param B SparseMatrix
* @param lhs_mat SparseMatrix
* @param rhs_mat SparseMatrix
*
* @return SparseMatrix
*/
......@@ -35,6 +35,18 @@ c10::intrusive_ptr<SparseMatrix> SpSpMul(
const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
const c10::intrusive_ptr<SparseMatrix>& rhs_mat);
/**
 * @brief Divides two sparse matrices elementwise. Both operands must have
 * exactly the same sparsity pattern (set of non-zero coordinates) and must
 * not contain duplicate entries; the returned matrix keeps the non-zero
 * ordering of @p lhs_mat.
 *
 * @param lhs_mat SparseMatrix, the dividend
 * @param rhs_mat SparseMatrix, the divisor
 *
 * @return SparseMatrix
 */
c10::intrusive_ptr<SparseMatrix> SpSpDiv(
const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
const c10::intrusive_ptr<SparseMatrix>& rhs_mat);
} // namespace sparse
} // namespace dgl
......
......@@ -14,6 +14,7 @@
#include <torch/script.h>
#include <memory>
#include <utility>
namespace dgl {
namespace sparse {
......@@ -113,6 +114,13 @@ std::shared_ptr<CSR> DiagToCSC(
/** @brief COO transposition. */
std::shared_ptr<COO> COOTranspose(const std::shared_ptr<COO>& coo);
/**
 * @brief Sort the COO matrix by row and column indices (row-major order).
 *
 * @param coo The COO matrix to sort; it is not modified.
 *
 * @return A pair of the sorted COO matrix and the permutation indices:
 * `perm[i]` is the position in the original COO of the i-th sorted entry,
 * so `original.indices.index_select(1, perm)` yields the sorted indices.
 */
std::pair<std::shared_ptr<COO>, torch::Tensor> COOSort(
const std::shared_ptr<COO>& coo);
} // namespace sparse
} // namespace dgl
......
......@@ -19,17 +19,18 @@ namespace sparse {
using namespace torch::autograd;
// Adds two sparse matrices that may have different sparsities.
// NOTE(review): the diff extraction interleaved old (A/B) and new
// (lhs_mat/rhs_mat) lines of this hunk; this is the coherent post-commit
// version of the function.
c10::intrusive_ptr<SparseMatrix> SpSpAdd(
    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
    const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {
  ElementwiseOpSanityCheck(lhs_mat, rhs_mat);
  // Fast path: two diagonal matrices add elementwise on their values and
  // keep the diagonal format.
  if (lhs_mat->HasDiag() && rhs_mat->HasDiag()) {
    return SparseMatrix::FromDiagPointer(
        lhs_mat->DiagPtr(), lhs_mat->value() + rhs_mat->value(),
        lhs_mat->shape());
  }
  // General path: convert both operands to torch sparse COO tensors and let
  // torch handle the union of the two sparsity patterns; coalesce() merges
  // entries that appear in both operands.
  auto torch_lhs = COOToTorchCOO(lhs_mat->COOPtr(), lhs_mat->value());
  auto torch_rhs = COOToTorchCOO(rhs_mat->COOPtr(), rhs_mat->value());
  auto sum = (torch_lhs + torch_rhs).coalesce();
  return SparseMatrix::FromCOO(sum.indices(), sum.values(), lhs_mat->shape());
}
class SpSpMulAutoGrad : public Function<SpSpMulAutoGrad> {
......@@ -117,5 +118,34 @@ c10::intrusive_ptr<SparseMatrix> SpSpMul(
return SparseMatrix::FromCOO(indices, val, lhs_mat->shape());
}
// Divides two sparse matrices elementwise. Requires identical sparsity
// patterns and no duplicate entries; the result keeps lhs_mat's non-zero
// ordering.
c10::intrusive_ptr<SparseMatrix> SpSpDiv(
    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
    const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {
  ElementwiseOpSanityCheck(lhs_mat, rhs_mat);
  // Fast path: two diagonal matrices divide elementwise on their values.
  if (lhs_mat->HasDiag() && rhs_mat->HasDiag()) {
    return SparseMatrix::FromDiagPointer(
        lhs_mat->DiagPtr(), lhs_mat->value() / rhs_mat->value(),
        lhs_mat->shape());
  }
  // Validate before sorting so an unsupported input fails fast instead of
  // paying for two sorts first.
  TORCH_CHECK(
      !lhs_mat->HasDuplicate() && !rhs_mat->HasDuplicate(),
      "Only support SpSpDiv on sparse matrices without duplicate values");
  // Sort both operands into row-major order so their sparsity patterns can
  // be compared positionally.
  std::shared_ptr<COO> sorted_lhs, sorted_rhs;
  torch::Tensor lhs_sorted_perm, rhs_sorted_perm;
  std::tie(sorted_lhs, lhs_sorted_perm) = COOSort(lhs_mat->COOPtr());
  std::tie(sorted_rhs, rhs_sorted_perm) = COOSort(rhs_mat->COOPtr());
  TORCH_CHECK(
      torch::equal(sorted_lhs->indices, sorted_rhs->indices),
      "Cannot divide two COO matrices with different sparsities.");
  // Align rhs values to lhs's original non-zero order so the returned
  // matrix reuses lhs_mat's COO structure unchanged:
  //   lhs_sorted_rperm inverts lhs_sorted_perm (original pos -> sorted pos),
  //   so rhs_perm_on_lhs[p] is the rhs position holding the same coordinate
  //   as lhs position p.
  auto lhs_sorted_rperm = lhs_sorted_perm.argsort();
  auto rhs_perm_on_lhs = rhs_sorted_perm.index_select(0, lhs_sorted_rperm);
  auto lhs_value = lhs_mat->value();
  auto rhs_value = rhs_mat->value().index_select(0, rhs_perm_on_lhs);
  auto ret_val = lhs_value / rhs_value;
  return SparseMatrix::FromCOOPointer(
      lhs_mat->COOPtr(), ret_val, lhs_mat->shape());
}
} // namespace sparse
} // namespace dgl
......@@ -40,6 +40,7 @@ TORCH_LIBRARY(dgl_sparse, m) {
.def("from_diag", &SparseMatrix::FromDiag)
.def("spsp_add", &SpSpAdd)
.def("spsp_mul", &SpSpMul)
.def("spsp_div", &SpSpDiv)
.def("reduce", &Reduce)
.def("sum", &ReduceSum)
.def("smean", &ReduceMean)
......
......@@ -140,5 +140,17 @@ std::shared_ptr<COO> COOTranspose(const std::shared_ptr<COO>& coo) {
return COOFromOldDGLCOO(dgl_coo_tr);
}
std::pair<std::shared_ptr<COO>, torch::Tensor> COOSort(
const std::shared_ptr<COO>& coo) {
auto encoded_coo =
coo->indices.index({0}) * coo->num_cols + coo->indices.index({1});
torch::Tensor sorted, perm;
std::tie(sorted, perm) = encoded_coo.sort();
auto sorted_coo = std::make_shared<COO>(
COO{coo->num_rows, coo->num_cols, coo->indices.index_select(1, perm),
true, true});
return {sorted_coo, perm};
}
} // namespace sparse
} // namespace dgl
......@@ -3,7 +3,7 @@ from typing import Union
import torch
from .sparse_matrix import diag, SparseMatrix, val_like
from .sparse_matrix import SparseMatrix, val_like
from .utils import is_scalar, Scalar
......@@ -21,6 +21,13 @@ def spsp_mul(A, B):
)
def spsp_div(A, B):
    """Invoke C++ sparse library for division"""
    c_result = torch.ops.dgl_sparse.spsp_div(
        A.c_sparse_matrix, B.c_sparse_matrix
    )
    return SparseMatrix(c_result)
def sp_add(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:
"""Elementwise addition
......@@ -141,8 +148,9 @@ def sp_mul(A: SparseMatrix, B: Union[SparseMatrix, Scalar]) -> SparseMatrix:
def sp_div(A: SparseMatrix, B: Union[SparseMatrix, Scalar]) -> SparseMatrix:
"""Elementwise division
If :attr:`B` is a sparse matrix, both :attr:`A` and :attr:`B` must be
diagonal matrices.
If :attr:`B` is a sparse matrix, both :attr:`A` and :attr:`B` must have the
same sparsity. And the returned matrix has the same order of non-zero
entries as :attr:`A`.
Parameters
----------
......@@ -169,15 +177,7 @@ def sp_div(A: SparseMatrix, B: Union[SparseMatrix, Scalar]) -> SparseMatrix:
"""
if is_scalar(B):
return val_like(A, A.val / B)
if A.is_diag() and B.is_diag():
assert A.shape == B.shape, (
f"The shape of diagonal matrix A {A.shape} and B {B.shape} must"
f"match for elementwise division."
)
return diag(A.val / B.val, A.shape)
# Python falls back to B.__rtruediv__(A) then TypeError when NotImplemented
# is returned.
return NotImplemented
return spsp_div(A, B)
def sp_power(A: SparseMatrix, scalar: Scalar) -> SparseMatrix:
......
......@@ -225,7 +225,7 @@ def test_sub_sparse_diag(val_shape):
assert torch.allclose(dense_diff, -diff4)
@pytest.mark.parametrize("op", ["truediv", "pow"])
@pytest.mark.parametrize("op", ["pow"])
def test_error_op_sparse_diag(op):
ctx = F.ctx()
row = torch.tensor([1, 0, 2]).to(ctx)
......
......@@ -4,7 +4,7 @@ import backend as F
import pytest
import torch
from dgl.sparse import from_coo, mul, power, val_like
from dgl.sparse import div, from_coo, mul, power, spmatrix, val_like
from .utils import (
rand_coo,
......@@ -134,3 +134,24 @@ def test_spspmul(create_func1, create_func2, shape, nnz1, nnz2, nz_dim):
assert torch.allclose(
val_like(B, B.val.grad).to_dense(), DB.grad, atol=1e-05
)
@pytest.mark.parametrize(
    "create_func", [rand_coo, rand_csr, rand_csc, rand_diag]
)
@pytest.mark.parametrize("shape", [(5, 5), (5, 3)])
@pytest.mark.parametrize("nnz", [1, 14])
@pytest.mark.parametrize("nz_dim", [None, 3])
def test_spspdiv(create_func, nnz, shape, nz_dim):
    """Divide A by a matrix B with the same sparsity but randomly permuted
    non-zero order; the result must follow A's non-zero ordering."""
    dev = F.ctx()
    A = create_func(shape, nnz, dev, nz_dim)
    # B holds the same non-zeros as A, stored in a shuffled order.
    perm = torch.randperm(A.nnz, device=dev)
    # rperm maps a position in A's order to the matching position in B.
    rperm = torch.argsort(perm)
    B = spmatrix(A.indices()[:, perm], A.val[perm], A.shape)
    C = div(A, B)
    assert not C.has_duplicate()
    assert torch.allclose(C.val, A.val / B.val[rperm], atol=1e-05)
    # Indices are integer tensors: compare exactly instead of approximately.
    assert torch.equal(C.indices(), A.indices())
    # No need to test backward here, since it is handled by Pytorch
......@@ -35,9 +35,9 @@ def test_softmax(val_D, csr, dim):
g = dgl.graph((row, col), num_nodes=max(A.shape))
val_g = val.clone().requires_grad_()
score = dgl.nn.functional.edge_softmax(g, val_g)
assert torch.allclose(A_max.val, score)
assert torch.allclose(A_max.val, score, atol=1e-05)
grad = torch.randn_like(score).to(dev)
A_max.val.backward(grad)
score.backward(grad)
assert torch.allclose(A.val.grad, val_g.grad)
assert torch.allclose(A.val.grad, val_g.grad, atol=1e-05)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment