/** * Copyright (c) 2022 by Contributors * @file sparse/sparse_matrix.h * @brief DGL C++ sparse matrix header. */ #ifndef SPARSE_SPARSE_MATRIX_H_ #define SPARSE_SPARSE_MATRIX_H_ // clang-format off #include // clang-format on #include #include #include #include #include #include #include namespace dgl { namespace sparse { /** @brief SparseMatrix bound to Python. */ class SparseMatrix : public torch::CustomClassHolder { public: /** * @brief General constructor to construct a sparse matrix for different * sparse formats. At least one of the sparse formats should be provided, * while others could be nullptrs. * * @param coo The COO format. * @param csr The CSR format. * @param csc The CSC format. * @param value Value of the sparse matrix. * @param shape Shape of the sparse matrix. */ SparseMatrix( const std::shared_ptr& coo, const std::shared_ptr& csr, const std::shared_ptr& csc, torch::Tensor value, const std::vector& shape); /** * @brief Construct a SparseMatrix from a COO format. * @param coo The COO format * @param value Values of the sparse matrix * @param shape Shape of the sparse matrix * * @return SparseMatrix */ static c10::intrusive_ptr FromCOO( const std::shared_ptr& coo, torch::Tensor value, const std::vector& shape); /** * @brief Construct a SparseMatrix from a CSR format. * @param csr The CSR format * @param value Values of the sparse matrix * @param shape Shape of the sparse matrix * * @return SparseMatrix */ static c10::intrusive_ptr FromCSR( const std::shared_ptr& csr, torch::Tensor value, const std::vector& shape); /** * @brief Construct a SparseMatrix from a CSC format. * @param csc The CSC format * @param value Values of the sparse matrix * @param shape Shape of the sparse matrix * * @return SparseMatrix */ static c10::intrusive_ptr FromCSC( const std::shared_ptr& csc, torch::Tensor value, const std::vector& shape); /** @return Value of the sparse matrix. */ inline torch::Tensor value() const { return value_; } /** @return Shape of the sparse matrix. */ inline const std::vector& shape() const { return shape_; } /** @return Number of non-zero values */ inline int64_t nnz() const { return value_.size(0); } /** @return Non-zero value data type */ inline caffe2::TypeMeta dtype() const { return value_.dtype(); } /** @return Device of the sparse matrix */ inline torch::Device device() const { return value_.device(); } /** @return COO of the sparse matrix. The COO is created if not exists. */ std::shared_ptr COOPtr(); /** @return CSR of the sparse matrix. The CSR is created if not exists. */ std::shared_ptr CSRPtr(); /** @return CSC of the sparse matrix. The CSC is created if not exists. */ std::shared_ptr CSCPtr(); /** @brief Check whether this sparse matrix has COO format. */ inline bool HasCOO() const { return coo_ != nullptr; } /** @brief Check whether this sparse matrix has CSR format. */ inline bool HasCSR() const { return csr_ != nullptr; } /** @brief Check whether this sparse matrix has CSC format. */ inline bool HasCSC() const { return csc_ != nullptr; } /** @return {row, col} tensors in the COO format. */ std::tuple COOTensors(); /** @return {row, col, value_indices} tensors in the CSR format. */ std::tuple> CSRTensors(); /** @return {row, col, value_indices} tensors in the CSC format. */ std::tuple> CSCTensors(); /** @brief Return the transposition of the sparse matrix. It transposes the * first existing sparse format by checking COO, CSR, and CSC. */ c10::intrusive_ptr Transpose() const; private: /** @brief Create the COO format for the sparse matrix internally */ void _CreateCOO(); /** @brief Create the CSR format for the sparse matrix internally */ void _CreateCSR(); /** @brief Create the CSC format for the sparse matrix internally */ void _CreateCSC(); // COO/CSC/CSR pointers. Nullptr indicates non-existence. std::shared_ptr coo_; std::shared_ptr csr_, csc_; // Value of the SparseMatrix torch::Tensor value_; // Shape of the SparseMatrix const std::vector shape_; }; /** * @brief Create a SparseMatrix from tensors in COO format. * @param row Row indices of the COO. * @param col Column indices of the COO. * @param value Values of the sparse matrix. * @param shape Shape of the sparse matrix. * * @return SparseMatrix */ c10::intrusive_ptr CreateFromCOO( torch::Tensor row, torch::Tensor col, torch::Tensor value, const std::vector& shape); /** * @brief Create a SparseMatrix from tensors in CSR format. * @param indptr Index pointer array of the CSR * @param indices Indices array of the CSR * @param value Values of the sparse matrix * @param shape Shape of the sparse matrix * * @return SparseMatrix */ c10::intrusive_ptr CreateFromCSR( torch::Tensor indptr, torch::Tensor indices, torch::Tensor value, const std::vector& shape); /** * @brief Create a SparseMatrix from tensors in CSC format. * @param indptr Index pointer array of the CSC * @param indices Indices array of the CSC * @param value Values of the sparse matrix * @param shape Shape of the sparse matrix * * @return SparseMatrix */ c10::intrusive_ptr CreateFromCSC( torch::Tensor indptr, torch::Tensor indices, torch::Tensor value, const std::vector& shape); /** * @brief Create a SparseMatrix from a SparseMatrix using new values. * @param mat An existing sparse matrix * @param value New values of the sparse matrix * * @return SparseMatrix */ c10::intrusive_ptr CreateValLike( const c10::intrusive_ptr& mat, torch::Tensor value); } // namespace sparse } // namespace dgl #endif // SPARSE_SPARSE_MATRIX_H_