Unverified Commit a03dec05 authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Sparse] Support Diag sparse format in C++ (#5432)

* [Sparse] Support Diag sparse format in C++

* update

* Update
parent b7ce4b6a
...@@ -19,7 +19,7 @@ namespace dgl { ...@@ -19,7 +19,7 @@ namespace dgl {
namespace sparse { namespace sparse {
/** @brief SparseFormat enumeration. */ /** @brief SparseFormat enumeration. */
enum SparseFormat { kCOO, kCSR, kCSC }; enum SparseFormat { kCOO, kCSR, kCSC, kDiag };
/** @brief COO sparse structure. */ /** @brief COO sparse structure. */
struct COO { struct COO {
...@@ -50,6 +50,11 @@ struct CSR { ...@@ -50,6 +50,11 @@ struct CSR {
bool sorted = false; bool sorted = false;
}; };
/** @brief Diag sparse structure, representing a (possibly rectangular)
 * diagonal matrix. Only the dense shape is stored here; the diagonal values
 * are held by the owning SparseMatrix. */
struct Diag {
  /** @brief The dense shape of the matrix. */
  int64_t num_rows = 0, num_cols = 0;
};
/** @brief Convert an old DGL COO format to a COO in the sparse library. */ /** @brief Convert an old DGL COO format to a COO in the sparse library. */
std::shared_ptr<COO> COOFromOldDGLCOO(const aten::COOMatrix& dgl_coo); std::shared_ptr<COO> COOFromOldDGLCOO(const aten::COOMatrix& dgl_coo);
...@@ -90,6 +95,21 @@ std::shared_ptr<CSR> COOToCSC(const std::shared_ptr<COO>& coo); ...@@ -90,6 +95,21 @@ std::shared_ptr<CSR> COOToCSC(const std::shared_ptr<COO>& coo);
/** @brief Convert a CSR format to CSC format. */ /** @brief Convert a CSR format to CSC format. */
std::shared_ptr<CSR> CSRToCSC(const std::shared_ptr<CSR>& csr); std::shared_ptr<CSR> CSRToCSC(const std::shared_ptr<CSR>& csr);
/** @brief Convert a Diag format to COO format.
 * @param diag The Diag format.
 * @param indices_options Tensor options (dtype/device) used to create the
 *        indices tensor of the result.
 * @return The equivalent COO structure, with both rows and columns sorted.
 */
std::shared_ptr<COO> DiagToCOO(
    const std::shared_ptr<Diag>& diag,
    const c10::TensorOptions& indices_options);

/** @brief Convert a Diag format to CSR format.
 * @param diag The Diag format.
 * @param indices_options Tensor options (dtype/device) used to create the
 *        indptr/indices tensors of the result.
 * @return The equivalent CSR structure.
 */
std::shared_ptr<CSR> DiagToCSR(
    const std::shared_ptr<Diag>& diag,
    const c10::TensorOptions& indices_options);

/** @brief Convert a Diag format to CSC format (stored as the CSR of the
 * transposed matrix).
 * @param diag The Diag format.
 * @param indices_options Tensor options (dtype/device) used to create the
 *        indptr/indices tensors of the result.
 * @return The equivalent CSC structure.
 */
std::shared_ptr<CSR> DiagToCSC(
    const std::shared_ptr<Diag>& diag,
    const c10::TensorOptions& indices_options);
/** @brief COO transposition. */ /** @brief COO transposition. */
std::shared_ptr<COO> COOTranspose(const std::shared_ptr<COO>& coo); std::shared_ptr<COO> COOTranspose(const std::shared_ptr<COO>& coo);
......
...@@ -38,8 +38,8 @@ class SparseMatrix : public torch::CustomClassHolder { ...@@ -38,8 +38,8 @@ class SparseMatrix : public torch::CustomClassHolder {
*/ */
SparseMatrix( SparseMatrix(
const std::shared_ptr<COO>& coo, const std::shared_ptr<CSR>& csr, const std::shared_ptr<COO>& coo, const std::shared_ptr<CSR>& csr,
const std::shared_ptr<CSR>& csc, torch::Tensor value, const std::shared_ptr<CSR>& csc, const std::shared_ptr<Diag>& diag,
const std::vector<int64_t>& shape); torch::Tensor value, const std::vector<int64_t>& shape);
/** /**
* @brief Construct a SparseMatrix from a COO format. * @brief Construct a SparseMatrix from a COO format.
...@@ -77,6 +77,18 @@ class SparseMatrix : public torch::CustomClassHolder { ...@@ -77,6 +77,18 @@ class SparseMatrix : public torch::CustomClassHolder {
const std::shared_ptr<CSR>& csc, torch::Tensor value, const std::shared_ptr<CSR>& csc, torch::Tensor value,
const std::vector<int64_t>& shape); const std::vector<int64_t>& shape);
/**
 * @brief Construct a SparseMatrix from a Diag format pointer.
 * @param diag The Diag format.
 * @param value Values of the sparse matrix, one entry per diagonal element.
 * @param shape Shape of the sparse matrix; must be 2-dimensional.
 *
 * @return SparseMatrix
 */
static c10::intrusive_ptr<SparseMatrix> FromDiagPointer(
    const std::shared_ptr<Diag>& diag, torch::Tensor value,
    const std::vector<int64_t>& shape);
/** /**
* @brief Create a SparseMatrix from tensors in COO format. * @brief Create a SparseMatrix from tensors in COO format.
* @param indices COO coordinates with shape (2, nnz). * @param indices COO coordinates with shape (2, nnz).
...@@ -115,6 +127,16 @@ class SparseMatrix : public torch::CustomClassHolder { ...@@ -115,6 +127,16 @@ class SparseMatrix : public torch::CustomClassHolder {
torch::Tensor indptr, torch::Tensor indices, torch::Tensor value, torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
const std::vector<int64_t>& shape); const std::vector<int64_t>& shape);
/**
 * @brief Create a SparseMatrix with Diag format.
 * @param value Values of the sparse matrix. Its first dimension must equal
 *        the diagonal length `min(shape[0], shape[1])`.
 * @param shape Shape of the sparse matrix; must be 2-dimensional.
 *
 * @return SparseMatrix
 */
static c10::intrusive_ptr<SparseMatrix> FromDiag(
    torch::Tensor value, const std::vector<int64_t>& shape);
/** /**
* @brief Create a SparseMatrix from a SparseMatrix using new values. * @brief Create a SparseMatrix from a SparseMatrix using new values.
* @param mat An existing sparse matrix * @param mat An existing sparse matrix
...@@ -142,6 +164,11 @@ class SparseMatrix : public torch::CustomClassHolder { ...@@ -142,6 +164,11 @@ class SparseMatrix : public torch::CustomClassHolder {
std::shared_ptr<CSR> CSRPtr(); std::shared_ptr<CSR> CSRPtr();
/** @return CSC of the sparse matrix. The CSC is created if not exists. */ /** @return CSC of the sparse matrix. The CSC is created if not exists. */
std::shared_ptr<CSR> CSCPtr(); std::shared_ptr<CSR> CSCPtr();
/**
* @return Diagonal format of the sparse matrix. An error will be raised if
* it does not have a diagonal format.
*/
std::shared_ptr<Diag> DiagPtr();
/** @brief Check whether this sparse matrix has COO format. */ /** @brief Check whether this sparse matrix has COO format. */
inline bool HasCOO() const { return coo_ != nullptr; } inline bool HasCOO() const { return coo_ != nullptr; }
...@@ -149,6 +176,8 @@ class SparseMatrix : public torch::CustomClassHolder { ...@@ -149,6 +176,8 @@ class SparseMatrix : public torch::CustomClassHolder {
inline bool HasCSR() const { return csr_ != nullptr; } inline bool HasCSR() const { return csr_ != nullptr; }
/** @brief Check whether this sparse matrix has CSC format. */ /** @brief Check whether this sparse matrix has CSC format. */
inline bool HasCSC() const { return csc_ != nullptr; } inline bool HasCSC() const { return csc_ != nullptr; }
/** @brief Check whether this sparse matrix has Diag format. */
inline bool HasDiag() const { return diag_ != nullptr; }
/** @return {row, col} tensors in the COO format. */ /** @return {row, col} tensors in the COO format. */
std::tuple<torch::Tensor, torch::Tensor> COOTensors(); std::tuple<torch::Tensor, torch::Tensor> COOTensors();
...@@ -191,9 +220,10 @@ class SparseMatrix : public torch::CustomClassHolder { ...@@ -191,9 +220,10 @@ class SparseMatrix : public torch::CustomClassHolder {
/** @brief Create the CSC format for the sparse matrix internally */ /** @brief Create the CSC format for the sparse matrix internally */
void _CreateCSC(); void _CreateCSC();
// COO/CSC/CSR pointers. Nullptr indicates non-existence. // COO/CSC/CSR/Diag pointers. Nullptr indicates non-existence.
std::shared_ptr<COO> coo_; std::shared_ptr<COO> coo_;
std::shared_ptr<CSR> csr_, csc_; std::shared_ptr<CSR> csr_, csc_;
std::shared_ptr<Diag> diag_;
// Value of the SparseMatrix // Value of the SparseMatrix
torch::Tensor value_; torch::Tensor value_;
// Shape of the SparseMatrix // Shape of the SparseMatrix
......
...@@ -22,6 +22,10 @@ c10::intrusive_ptr<SparseMatrix> SpSpAdd( ...@@ -22,6 +22,10 @@ c10::intrusive_ptr<SparseMatrix> SpSpAdd(
const c10::intrusive_ptr<SparseMatrix>& A, const c10::intrusive_ptr<SparseMatrix>& A,
const c10::intrusive_ptr<SparseMatrix>& B) { const c10::intrusive_ptr<SparseMatrix>& B) {
ElementwiseOpSanityCheck(A, B); ElementwiseOpSanityCheck(A, B);
if (A->HasDiag() && B->HasDiag()) {
return SparseMatrix::FromDiagPointer(
A->DiagPtr(), A->value() + B->value(), A->shape());
}
auto torch_A = COOToTorchCOO(A->COOPtr(), A->value()); auto torch_A = COOToTorchCOO(A->COOPtr(), A->value());
auto torch_B = COOToTorchCOO(B->COOPtr(), B->value()); auto torch_B = COOToTorchCOO(B->COOPtr(), B->value());
auto sum = (torch_A + torch_B).coalesce(); auto sum = (torch_A + torch_B).coalesce();
......
...@@ -36,6 +36,7 @@ TORCH_LIBRARY(dgl_sparse, m) { ...@@ -36,6 +36,7 @@ TORCH_LIBRARY(dgl_sparse, m) {
m.def("from_coo", &SparseMatrix::FromCOO) m.def("from_coo", &SparseMatrix::FromCOO)
.def("from_csr", &SparseMatrix::FromCSR) .def("from_csr", &SparseMatrix::FromCSR)
.def("from_csc", &SparseMatrix::FromCSC) .def("from_csc", &SparseMatrix::FromCSC)
.def("from_diag", &SparseMatrix::FromDiag)
.def("spsp_add", &SpSpAdd) .def("spsp_add", &SpSpAdd)
.def("reduce", &Reduce) .def("reduce", &Reduce)
.def("sum", &ReduceSum) .def("sum", &ReduceSum)
......
...@@ -99,6 +99,39 @@ std::shared_ptr<CSR> CSRToCSC(const std::shared_ptr<CSR>& csr) { ...@@ -99,6 +99,39 @@ std::shared_ptr<CSR> CSRToCSC(const std::shared_ptr<CSR>& csr) {
return CSRFromOldDGLCSR(dgl_csc); return CSRFromOldDGLCSR(dgl_csc);
} }
std::shared_ptr<COO> DiagToCOO(
    const std::shared_ptr<Diag>& diag,
    const c10::TensorOptions& indices_options) {
  // A diagonal matrix has one nonzero per diagonal position, i.e.
  // min(num_rows, num_cols) entries located at (i, i).
  const int64_t diag_len = std::min(diag->num_rows, diag->num_cols);
  auto diag_ids = torch::arange(diag_len, indices_options);
  // Row and column indices are identical, yielding a (2, diag_len) tensor.
  auto indices = torch::stack({diag_ids, diag_ids});
  // Both rows and columns are sorted by construction.
  return std::make_shared<COO>(
      COO{diag->num_rows, diag->num_cols, indices, true, true});
}
std::shared_ptr<CSR> DiagToCSR(
    const std::shared_ptr<Diag>& diag,
    const c10::TensorOptions& indices_options) {
  // Row i (for i < nnz) holds exactly one nonzero at column i; rows past the
  // diagonal are empty, so indptr is [0, 1, ..., nnz, nnz, ..., nnz].
  const int64_t nnz = std::min(diag->num_rows, diag->num_cols);
  // Build indptr with arange + clamp. The previous full + arange_out
  // approach assumed the out= variant writes into the pre-sized tensor, but
  // PyTorch out= operations resize the result on shape mismatch, which would
  // truncate indptr to nnz + 1 entries whenever num_rows > nnz (tall
  // matrices).
  auto indptr =
      torch::arange(diag->num_rows + 1, indices_options).clamp_max_(nnz);
  auto indices = torch::arange(nnz, indices_options);
  // value_indices is empty (values are already in row-major order); column
  // indices within each row are trivially sorted.
  return std::make_shared<CSR>(
      CSR{diag->num_rows, diag->num_cols, indptr, indices,
          torch::optional<torch::Tensor>(), true});
}
std::shared_ptr<CSR> DiagToCSC(
    const std::shared_ptr<Diag>& diag,
    const c10::TensorOptions& indices_options) {
  // The CSC is stored as the CSR of the transposed matrix, so rows and
  // columns swap roles relative to DiagToCSR.
  const int64_t nnz = std::min(diag->num_rows, diag->num_cols);
  // arange + clamp builds [0, 1, ..., nnz, nnz, ..., nnz] directly. The
  // previous full + arange_out approach would be truncated by the out=
  // resize-on-mismatch behavior whenever num_cols > nnz (wide matrices).
  auto indptr =
      torch::arange(diag->num_cols + 1, indices_options).clamp_max_(nnz);
  auto indices = torch::arange(nnz, indices_options);
  return std::make_shared<CSR>(
      CSR{diag->num_cols, diag->num_rows, indptr, indices,
          torch::optional<torch::Tensor>(), true});
}
std::shared_ptr<COO> COOTranspose(const std::shared_ptr<COO>& coo) { std::shared_ptr<COO> COOTranspose(const std::shared_ptr<COO>& coo) {
auto dgl_coo = COOToOldDGLCOO(coo); auto dgl_coo = COOToOldDGLCOO(coo);
auto dgl_coo_tr = aten::COOTranspose(dgl_coo); auto dgl_coo_tr = aten::COOTranspose(dgl_coo);
......
...@@ -17,12 +17,18 @@ namespace sparse { ...@@ -17,12 +17,18 @@ namespace sparse {
SparseMatrix::SparseMatrix( SparseMatrix::SparseMatrix(
const std::shared_ptr<COO>& coo, const std::shared_ptr<CSR>& csr, const std::shared_ptr<COO>& coo, const std::shared_ptr<CSR>& csr,
const std::shared_ptr<CSR>& csc, torch::Tensor value, const std::shared_ptr<CSR>& csc, const std::shared_ptr<Diag>& diag,
const std::vector<int64_t>& shape) torch::Tensor value, const std::vector<int64_t>& shape)
: coo_(coo), csr_(csr), csc_(csc), value_(value), shape_(shape) { : coo_(coo),
csr_(csr),
csc_(csc),
diag_(diag),
value_(value),
shape_(shape) {
TORCH_CHECK( TORCH_CHECK(
coo != nullptr || csr != nullptr || csc != nullptr, "At least ", coo != nullptr || csr != nullptr || csc != nullptr || diag != nullptr,
"one of CSR/COO/CSC is required to construct a SparseMatrix.") "At least one of CSR/COO/CSC/Diag is required to construct a "
"SparseMatrix.")
TORCH_CHECK( TORCH_CHECK(
shape.size() == 2, "The shape of a sparse matrix should be ", shape.size() == 2, "The shape of a sparse matrix should be ",
"2-dimensional."); "2-dimensional.");
...@@ -51,24 +57,37 @@ SparseMatrix::SparseMatrix( ...@@ -51,24 +57,37 @@ SparseMatrix::SparseMatrix(
TORCH_CHECK(csc->indptr.device() == value.device()); TORCH_CHECK(csc->indptr.device() == value.device());
TORCH_CHECK(csc->indices.device() == value.device()); TORCH_CHECK(csc->indices.device() == value.device());
} }
if (diag != nullptr) {
TORCH_CHECK(value.size(0) == std::min(diag->num_rows, diag->num_cols));
}
} }
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOOPointer(
    const std::shared_ptr<COO>& coo, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  // Only the COO format is materialized; CSR/CSC/Diag stay null and are
  // created lazily by the _Create* helpers when requested.
  return c10::make_intrusive<SparseMatrix>(
      coo, nullptr, nullptr, nullptr, value, shape);
}
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSRPointer(
    const std::shared_ptr<CSR>& csr, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  // Only the CSR format is materialized; COO/CSC/Diag stay null and are
  // created lazily by the _Create* helpers when requested.
  return c10::make_intrusive<SparseMatrix>(
      nullptr, csr, nullptr, nullptr, value, shape);
}
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSCPointer(
    const std::shared_ptr<CSR>& csc, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  // Only the CSC format is materialized; COO/CSR/Diag stay null and are
  // created lazily by the _Create* helpers when requested.
  return c10::make_intrusive<SparseMatrix>(
      nullptr, nullptr, csc, nullptr, value, shape);
}
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromDiagPointer(
    const std::shared_ptr<Diag>& diag, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  // Only the Diag format is materialized; COO/CSR/CSC stay null and are
  // created lazily by the _Create* helpers when requested.
  return c10::make_intrusive<SparseMatrix>(
      nullptr, nullptr, nullptr, diag, value, shape);
}
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOO( c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOO(
...@@ -97,6 +116,12 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSC( ...@@ -97,6 +116,12 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSC(
return SparseMatrix::FromCSCPointer(csc, value, shape); return SparseMatrix::FromCSCPointer(csc, value, shape);
} }
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromDiag(
    torch::Tensor value, const std::vector<int64_t>& shape) {
  // Validate before indexing: the SparseMatrix constructor checks the shape
  // as well, but only after shape[0]/shape[1] have already been read here,
  // which is undefined behavior for a vector with fewer than two elements.
  TORCH_CHECK(
      shape.size() == 2, "The shape of a sparse matrix should be ",
      "2-dimensional.");
  auto diag = std::make_shared<Diag>(Diag{shape[0], shape[1]});
  return SparseMatrix::FromDiagPointer(diag, value, shape);
}
c10::intrusive_ptr<SparseMatrix> SparseMatrix::ValLike( c10::intrusive_ptr<SparseMatrix> SparseMatrix::ValLike(
const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value) { const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value) {
TORCH_CHECK( TORCH_CHECK(
...@@ -136,6 +161,13 @@ std::shared_ptr<CSR> SparseMatrix::CSCPtr() { ...@@ -136,6 +161,13 @@ std::shared_ptr<CSR> SparseMatrix::CSCPtr() {
return csc_; return csc_;
} }
/** @brief Return the Diag format pointer. Unlike the other format accessors,
 * this performs no conversion: it raises an error if the matrix does not
 * already hold a Diag format. */
std::shared_ptr<Diag> SparseMatrix::DiagPtr() {
  TORCH_CHECK(
      diag_ != nullptr,
      "Cannot get Diag sparse format from a non-diagonal sparse matrix");
  return diag_;
}
std::tuple<torch::Tensor, torch::Tensor> SparseMatrix::COOTensors() { std::tuple<torch::Tensor, torch::Tensor> SparseMatrix::COOTensors() {
auto coo = COOPtr(); auto coo = COOPtr();
return std::make_tuple(coo->indices.index({0}), coo->indices.index({1})); return std::make_tuple(coo->indices.index({0}), coo->indices.index({1}));
...@@ -175,7 +207,13 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::Transpose() const { ...@@ -175,7 +207,13 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::Transpose() const {
void SparseMatrix::_CreateCOO() { void SparseMatrix::_CreateCOO() {
if (HasCOO()) return; if (HasCOO()) return;
if (HasCSR()) { if (HasDiag()) {
auto indices_options = torch::TensorOptions()
.dtype(torch::kInt64)
.layout(torch::kStrided)
.device(this->device());
coo_ = DiagToCOO(diag_, indices_options);
} else if (HasCSR()) {
coo_ = CSRToCOO(csr_); coo_ = CSRToCOO(csr_);
} else if (HasCSC()) { } else if (HasCSC()) {
coo_ = CSCToCOO(csc_); coo_ = CSCToCOO(csc_);
...@@ -186,7 +224,13 @@ void SparseMatrix::_CreateCOO() { ...@@ -186,7 +224,13 @@ void SparseMatrix::_CreateCOO() {
void SparseMatrix::_CreateCSR() { void SparseMatrix::_CreateCSR() {
if (HasCSR()) return; if (HasCSR()) return;
if (HasCOO()) { if (HasDiag()) {
auto indices_options = torch::TensorOptions()
.dtype(torch::kInt64)
.layout(torch::kStrided)
.device(this->device());
csr_ = DiagToCSR(diag_, indices_options);
} else if (HasCOO()) {
csr_ = COOToCSR(coo_); csr_ = COOToCSR(coo_);
} else if (HasCSC()) { } else if (HasCSC()) {
csr_ = CSCToCSR(csc_); csr_ = CSCToCSR(csc_);
...@@ -197,7 +241,13 @@ void SparseMatrix::_CreateCSR() { ...@@ -197,7 +241,13 @@ void SparseMatrix::_CreateCSR() {
void SparseMatrix::_CreateCSC() { void SparseMatrix::_CreateCSC() {
if (HasCSC()) return; if (HasCSC()) return;
if (HasCOO()) { if (HasDiag()) {
auto indices_options = torch::TensorOptions()
.dtype(torch::kInt64)
.layout(torch::kStrided)
.device(this->device());
csc_ = DiagToCSC(diag_, indices_options);
} else if (HasCOO()) {
csc_ = COOToCSC(coo_); csc_ = COOToCSC(coo_);
} else if (HasCSR()) { } else if (HasCSR()) {
csc_ = CSRToCSC(csr_); csc_ = CSRToCSC(csr_);
......
...@@ -116,10 +116,49 @@ tensor_list SpSpMMAutoGrad::backward( ...@@ -116,10 +116,49 @@ tensor_list SpSpMMAutoGrad::backward(
return {torch::Tensor(), lhs_val_grad, torch::Tensor(), rhs_val_grad}; return {torch::Tensor(), lhs_val_grad, torch::Tensor(), rhs_val_grad};
} }
/** @brief Sparse-sparse matrix multiplication where at least one operand has
 * a Diag format.
 *
 * Three cases are handled:
 *  - Diag @ Diag: elementwise product of the two diagonals over their common
 *    length, zero-padded up to the result's diagonal length; returns a Diag
 *    matrix.
 *  - Diag @ Sparse: each nonzero (r, c) of the rhs is scaled by lhs diagonal
 *    value r.
 *  - Sparse @ Diag: each nonzero (r, c) of the lhs is scaled by rhs diagonal
 *    value c.
 *
 * @param lhs_mat Left operand.
 * @param rhs_mat Right operand.
 * @return The product as a SparseMatrix.
 */
c10::intrusive_ptr<SparseMatrix> DiagSpSpMM(
    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
    const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {
  if (lhs_mat->HasDiag() && rhs_mat->HasDiag()) {
    // Diag @ Diag: (m, n) @ (n, p) -> (m, p).
    const int64_t m = lhs_mat->shape()[0];
    const int64_t n = lhs_mat->shape()[1];
    const int64_t p = rhs_mat->shape()[1];
    // Product entries exist only where both diagonals are defined.
    const int64_t common_diag_len = std::min({m, n, p});
    const int64_t new_diag_len = std::min(m, p);
    auto slice = torch::indexing::Slice(0, common_diag_len);
    auto new_val =
        lhs_mat->value().index({slice}) * rhs_mat->value().index({slice});
    // Zero-pad on the right up to the result's diagonal length.
    new_val =
        torch::constant_pad_nd(new_val, {0, new_diag_len - common_diag_len}, 0);
    return SparseMatrix::FromDiag(new_val, {m, p});
  }
  if (lhs_mat->HasDiag() && !rhs_mat->HasDiag()) {
    // Diag @ Sparse: row r of the rhs is scaled by lhs diagonal value r.
    auto row = rhs_mat->Indices().index({0});
    auto val = lhs_mat->value().index_select(0, row) * rhs_mat->value();
    return SparseMatrix::ValLike(rhs_mat, val);
  }
  if (!lhs_mat->HasDiag() && rhs_mat->HasDiag()) {
    // Sparse @ Diag: column c of the lhs is scaled by rhs diagonal value c.
    auto col = lhs_mat->Indices().index({1});
    auto val = rhs_mat->value().index_select(0, col) * lhs_mat->value();
    return SparseMatrix::ValLike(lhs_mat, val);
  }
  TORCH_CHECK(
      false,
      "For DiagSpSpMM, at least one of the sparse matrices needs to have "
      "kDiag format");
  return c10::intrusive_ptr<SparseMatrix>();
}
c10::intrusive_ptr<SparseMatrix> SpSpMM( c10::intrusive_ptr<SparseMatrix> SpSpMM(
const c10::intrusive_ptr<SparseMatrix>& lhs_mat, const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
const c10::intrusive_ptr<SparseMatrix>& rhs_mat) { const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {
_SpSpMMSanityCheck(lhs_mat, rhs_mat); _SpSpMMSanityCheck(lhs_mat, rhs_mat);
if (lhs_mat->HasDiag() || rhs_mat->HasDiag()) {
return DiagSpSpMM(lhs_mat, rhs_mat);
}
auto results = SpSpMMAutoGrad::apply( auto results = SpSpMMAutoGrad::apply(
lhs_mat, lhs_mat->value(), rhs_mat, rhs_mat->value()); lhs_mat, lhs_mat->value(), rhs_mat, rhs_mat->value());
std::vector<int64_t> ret_shape({lhs_mat->shape()[0], rhs_mat->shape()[1]}); std::vector<int64_t> ret_shape({lhs_mat->shape()[0], rhs_mat->shape()[1]});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment