Unverified Commit 11c866ab authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files
parent 0d687968
/*! /**
* Copyright (c) 2022 by Contributors * Copyright (c) 2022 by Contributors
* \file sparse/elementwise_op.h * @file sparse/elementwise_op.h
* \brief DGL C++ sparse elementwise operators * @brief DGL C++ sparse elementwise operators
*/ */
#ifndef SPARSE_ELEMENTWISE_OP_H_ #ifndef SPARSE_ELEMENTWISE_OP_H_
#define SPARSE_ELEMENTWISE_OP_H_ #define SPARSE_ELEMENTWISE_OP_H_
...@@ -13,13 +13,13 @@ namespace dgl { ...@@ -13,13 +13,13 @@ namespace dgl {
namespace sparse { namespace sparse {
// TODO(zhenkun): support addition of matrices with different sparsity. // TODO(zhenkun): support addition of matrices with different sparsity.
/*! /**
* @brief Adds two sparse matrices. Currently does not support two matrices with * @brief Adds two sparse matrices. Currently does not support two matrices with
* different sparsity. * different sparsity.
* *
* @param A SparseMatrix * @param A SparseMatrix
* @param B SparseMatrix * @param B SparseMatrix
* *
* @return SparseMatrix * @return SparseMatrix
*/ */
c10::intrusive_ptr<SparseMatrix> SpSpAdd( c10::intrusive_ptr<SparseMatrix> SpSpAdd(
......
/*! /**
* Copyright (c) 2022 by Contributors * Copyright (c) 2022 by Contributors
* @file sparse/sparse_matrix.h * @file sparse/sparse_matrix.h
* @brief DGL C++ sparse matrix header * @brief DGL C++ sparse matrix header
...@@ -15,10 +15,10 @@ ...@@ -15,10 +15,10 @@
namespace dgl { namespace dgl {
namespace sparse { namespace sparse {
/*! @brief SparseFormat enumeration */ /** @brief SparseFormat enumeration */
enum SparseFormat { kCOO, kCSR, kCSC }; enum SparseFormat { kCOO, kCSR, kCSC };
/*! @brief CSR sparse structure */ /** @brief CSR sparse structure */
struct CSR { struct CSR {
// CSR format index pointer array of the matrix // CSR format index pointer array of the matrix
torch::Tensor indptr; torch::Tensor indptr;
...@@ -34,7 +34,7 @@ struct CSR { ...@@ -34,7 +34,7 @@ struct CSR {
torch::optional<torch::Tensor> value_indices; torch::optional<torch::Tensor> value_indices;
}; };
/*! @brief COO sparse structure */ /** @brief COO sparse structure */
struct COO { struct COO {
// COO format row array of the matrix // COO format row array of the matrix
torch::Tensor row; torch::Tensor row;
...@@ -42,10 +42,10 @@ struct COO { ...@@ -42,10 +42,10 @@ struct COO {
torch::Tensor col; torch::Tensor col;
}; };
/*! @brief SparseMatrix bound to Python */ /** @brief SparseMatrix bound to Python */
class SparseMatrix : public torch::CustomClassHolder { class SparseMatrix : public torch::CustomClassHolder {
public: public:
/*! /**
* @brief General constructor to construct a sparse matrix for different * @brief General constructor to construct a sparse matrix for different
* sparse formats. At least one of the sparse formats should be provided, * sparse formats. At least one of the sparse formats should be provided,
* while others could be nullptrs. * while others could be nullptrs.
...@@ -61,7 +61,7 @@ class SparseMatrix : public torch::CustomClassHolder { ...@@ -61,7 +61,7 @@ class SparseMatrix : public torch::CustomClassHolder {
const std::shared_ptr<CSR>& csc, torch::Tensor value, const std::shared_ptr<CSR>& csc, torch::Tensor value,
const std::vector<int64_t>& shape); const std::vector<int64_t>& shape);
/*! /**
* @brief Construct a SparseMatrix from a COO format. * @brief Construct a SparseMatrix from a COO format.
* @param coo The COO format * @param coo The COO format
* @param value Values of the sparse matrix * @param value Values of the sparse matrix
...@@ -73,7 +73,7 @@ class SparseMatrix : public torch::CustomClassHolder { ...@@ -73,7 +73,7 @@ class SparseMatrix : public torch::CustomClassHolder {
const std::shared_ptr<COO>& coo, torch::Tensor value, const std::shared_ptr<COO>& coo, torch::Tensor value,
const std::vector<int64_t>& shape); const std::vector<int64_t>& shape);
/*! /**
* @brief Construct a SparseMatrix from a CSR format. * @brief Construct a SparseMatrix from a CSR format.
* @param csr The CSR format * @param csr The CSR format
* @param value Values of the sparse matrix * @param value Values of the sparse matrix
...@@ -85,7 +85,7 @@ class SparseMatrix : public torch::CustomClassHolder { ...@@ -85,7 +85,7 @@ class SparseMatrix : public torch::CustomClassHolder {
const std::shared_ptr<CSR>& csr, torch::Tensor value, const std::shared_ptr<CSR>& csr, torch::Tensor value,
const std::vector<int64_t>& shape); const std::vector<int64_t>& shape);
/*! /**
* @brief Construct a SparseMatrix from a CSC format. * @brief Construct a SparseMatrix from a CSC format.
* @param csc The CSC format * @param csc The CSC format
* @param value Values of the sparse matrix * @param value Values of the sparse matrix
...@@ -97,44 +97,44 @@ class SparseMatrix : public torch::CustomClassHolder { ...@@ -97,44 +97,44 @@ class SparseMatrix : public torch::CustomClassHolder {
const std::shared_ptr<CSR>& csc, torch::Tensor value, const std::shared_ptr<CSR>& csc, torch::Tensor value,
const std::vector<int64_t>& shape); const std::vector<int64_t>& shape);
/*! @return Value of the sparse matrix. */ /** @return Value of the sparse matrix. */
inline torch::Tensor value() const { return value_; } inline torch::Tensor value() const { return value_; }
/*! @return Shape of the sparse matrix. */ /** @return Shape of the sparse matrix. */
inline const std::vector<int64_t>& shape() const { return shape_; } inline const std::vector<int64_t>& shape() const { return shape_; }
/*! @return Number of non-zero values */ /** @return Number of non-zero values */
inline int64_t nnz() const { return value_.size(0); } inline int64_t nnz() const { return value_.size(0); }
/*! @return Non-zero value data type */ /** @return Non-zero value data type */
inline caffe2::TypeMeta dtype() const { return value_.dtype(); } inline caffe2::TypeMeta dtype() const { return value_.dtype(); }
/*! @return Device of the sparse matrix */ /** @return Device of the sparse matrix */
inline torch::Device device() const { return value_.device(); } inline torch::Device device() const { return value_.device(); }
/*! @return COO of the sparse matrix. The COO is created if not exists. */ /** @return COO of the sparse matrix. The COO is created if not exists. */
std::shared_ptr<COO> COOPtr(); std::shared_ptr<COO> COOPtr();
/*! @return CSR of the sparse matrix. The CSR is created if not exists. */ /** @return CSR of the sparse matrix. The CSR is created if not exists. */
std::shared_ptr<CSR> CSRPtr(); std::shared_ptr<CSR> CSRPtr();
/*! @return CSC of the sparse matrix. The CSC is created if not exists. */ /** @return CSC of the sparse matrix. The CSC is created if not exists. */
std::shared_ptr<CSR> CSCPtr(); std::shared_ptr<CSR> CSCPtr();
/*! @brief Check whether this sparse matrix has COO format. */ /** @brief Check whether this sparse matrix has COO format. */
inline bool HasCOO() const { return coo_ != nullptr; } inline bool HasCOO() const { return coo_ != nullptr; }
/*! @brief Check whether this sparse matrix has CSR format. */ /** @brief Check whether this sparse matrix has CSR format. */
inline bool HasCSR() const { return csr_ != nullptr; } inline bool HasCSR() const { return csr_ != nullptr; }
/*! @brief Check whether this sparse matrix has CSC format. */ /** @brief Check whether this sparse matrix has CSC format. */
inline bool HasCSC() const { return csc_ != nullptr; } inline bool HasCSC() const { return csc_ != nullptr; }
/*! @return {row, col, value} tensors in the COO format. */ /** @return {row, col, value} tensors in the COO format. */
std::vector<torch::Tensor> COOTensors(); std::vector<torch::Tensor> COOTensors();
/*! @return {row, col, value} tensors in the CSR format. */ /** @return {row, col, value} tensors in the CSR format. */
std::vector<torch::Tensor> CSRTensors(); std::vector<torch::Tensor> CSRTensors();
/*! @return {row, col, value} tensors in the CSC format. */ /** @return {row, col, value} tensors in the CSC format. */
std::vector<torch::Tensor> CSCTensors(); std::vector<torch::Tensor> CSCTensors();
private: private:
/*! @brief Create the COO format for the sparse matrix internally */ /** @brief Create the COO format for the sparse matrix internally */
void _CreateCOO(); void _CreateCOO();
/*! @brief Create the CSR format for the sparse matrix internally */ /** @brief Create the CSR format for the sparse matrix internally */
void _CreateCSR(); void _CreateCSR();
/*! @brief Create the CSC format for the sparse matrix internally */ /** @brief Create the CSC format for the sparse matrix internally */
void _CreateCSC(); void _CreateCSC();
// COO/CSC/CSR pointers. Nullptr indicates non-existence. // COO/CSC/CSR pointers. Nullptr indicates non-existence.
...@@ -146,7 +146,7 @@ class SparseMatrix : public torch::CustomClassHolder { ...@@ -146,7 +146,7 @@ class SparseMatrix : public torch::CustomClassHolder {
const std::vector<int64_t> shape_; const std::vector<int64_t> shape_;
}; };
/*! /**
* @brief Create a SparseMatrix from tensors in COO format. * @brief Create a SparseMatrix from tensors in COO format.
* @param row Row indices of the COO. * @param row Row indices of the COO.
* @param col Column indices of the COO. * @param col Column indices of the COO.
...@@ -159,7 +159,7 @@ c10::intrusive_ptr<SparseMatrix> CreateFromCOO( ...@@ -159,7 +159,7 @@ c10::intrusive_ptr<SparseMatrix> CreateFromCOO(
torch::Tensor row, torch::Tensor col, torch::Tensor value, torch::Tensor row, torch::Tensor col, torch::Tensor value,
const std::vector<int64_t>& shape); const std::vector<int64_t>& shape);
/*! /**
* @brief Create a SparseMatrix from tensors in CSR format. * @brief Create a SparseMatrix from tensors in CSR format.
* @param indptr Index pointer array of the CSR * @param indptr Index pointer array of the CSR
* @param indices Indices array of the CSR * @param indices Indices array of the CSR
...@@ -172,7 +172,7 @@ c10::intrusive_ptr<SparseMatrix> CreateFromCSR( ...@@ -172,7 +172,7 @@ c10::intrusive_ptr<SparseMatrix> CreateFromCSR(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor value, torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
const std::vector<int64_t>& shape); const std::vector<int64_t>& shape);
/*! /**
* @brief Create a SparseMatrix from tensors in CSC format. * @brief Create a SparseMatrix from tensors in CSC format.
* @param indptr Index pointer array of the CSC * @param indptr Index pointer array of the CSC
* @param indices Indices array of the CSC * @param indices Indices array of the CSC
......
/*! /**
* Copyright (c) 2022 by Contributors * Copyright (c) 2022 by Contributors
* @file elementwise_op.cc * @file elementwise_op.cc
* @brief DGL C++ sparse elementwise operator implementation * @brief DGL C++ sparse elementwise operator implementation
......
/*! /**
* Copyright (c) 2022 by Contributors * Copyright (c) 2022 by Contributors
* @file python_binding.cc * @file python_binding.cc
* @brief DGL sparse library Python binding * @brief DGL sparse library Python binding
......
/*! /**
* Copyright (c) 2022 by Contributors * Copyright (c) 2022 by Contributors
* @file sparse_matrix.cc * @file sparse_matrix.cc
* @brief DGL C++ sparse matrix implementations * @brief DGL C++ sparse matrix implementations
......
/*! /**
* Copyright (c) 2022 by Contributors * Copyright (c) 2022 by Contributors
* @file utils.h * @file utils.h
* @brief DGL C++ sparse API utilities * @brief DGL C++ sparse API utilities
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
namespace dgl { namespace dgl {
namespace sparse { namespace sparse {
/*! @brief Find a proper sparse format for two sparse matrices. It chooses /** @brief Find a proper sparse format for two sparse matrices. It chooses
* COO if anyone of the sparse matrices has COO format. If none of them has * COO if anyone of the sparse matrices has COO format. If none of them has
* COO, it tries CSR and CSC in the same manner. */ * COO, it tries CSR and CSC in the same manner. */
inline static SparseFormat FindAnyExistingFormat( inline static SparseFormat FindAnyExistingFormat(
...@@ -29,7 +29,7 @@ inline static SparseFormat FindAnyExistingFormat( ...@@ -29,7 +29,7 @@ inline static SparseFormat FindAnyExistingFormat(
return fmt; return fmt;
} }
/*! @brief Check whether two matrices has the same dtype and shape for /** @brief Check whether two matrices has the same dtype and shape for
* elementwise operators. */ * elementwise operators. */
inline static void ElementwiseOpSanityCheck( inline static void ElementwiseOpSanityCheck(
const c10::intrusive_ptr<SparseMatrix>& A, const c10::intrusive_ptr<SparseMatrix>& A,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment