/** * Copyright (c) 2022 by Contributors * @file sparse/sparse_format.h * @brief DGL C++ sparse format header. */ #ifndef SPARSE_SPARSE_FORMAT_H_ #define SPARSE_SPARSE_FORMAT_H_ // clang-format off #include // clang-format on #include #include #include #include namespace dgl { namespace sparse { /** @brief SparseFormat enumeration. */ enum SparseFormat { kCOO, kCSR, kCSC, kDiag }; /** @brief COO sparse structure. */ struct COO { /** @brief The shape of the matrix. */ int64_t num_rows = 0, num_cols = 0; /** * @brief COO tensor of shape (2, nnz), stacking the row and column indices. */ torch::Tensor indices; /** @brief Whether the row indices are sorted. */ bool row_sorted = false; /** @brief Whether the column indices per row are sorted. */ bool col_sorted = false; }; /** @brief CSR sparse structure. */ struct CSR { /** @brief The dense shape of the matrix. */ int64_t num_rows = 0, num_cols = 0; /** @brief CSR format index pointer array of the matrix. */ torch::Tensor indptr; /** @brief CSR format index array of the matrix. */ torch::Tensor indices; /** @brief Data index tensor. When it is null, assume it is from 0 to NNZ - 1. */ torch::optional value_indices; /** @brief Whether the column indices per row are sorted. */ bool sorted = false; }; struct Diag { /** @brief The dense shape of the matrix. */ int64_t num_rows = 0, num_cols = 0; }; /** @brief Convert an old DGL COO format to a COO in the sparse library. */ std::shared_ptr COOFromOldDGLCOO(const aten::COOMatrix& dgl_coo); /** @brief Convert a COO in the sparse library to an old DGL COO matrix. */ aten::COOMatrix COOToOldDGLCOO(const std::shared_ptr& coo); /** @brief Convert an old DGL CSR format to a CSR in the sparse library. */ std::shared_ptr CSRFromOldDGLCSR(const aten::CSRMatrix& dgl_csr); /** @brief Convert a CSR in the sparse library to an old DGL CSR matrix. */ aten::CSRMatrix CSRToOldDGLCSR(const std::shared_ptr& csr); /** * @brief Convert a COO and its nonzero values to a Torch COO matrix. * @param coo The COO format in the sparse library * @param value Values of the sparse matrix * * @return Torch Sparse Tensor in COO format */ torch::Tensor COOToTorchCOO( const std::shared_ptr& coo, torch::Tensor value); /** @brief Convert a CSR format to COO format. */ std::shared_ptr CSRToCOO(const std::shared_ptr& csr); /** @brief Convert a CSC format to COO format. */ std::shared_ptr CSCToCOO(const std::shared_ptr& csc); /** @brief Convert a COO format to CSR format. */ std::shared_ptr COOToCSR(const std::shared_ptr& coo); /** @brief Convert a CSC format to CSR format. */ std::shared_ptr CSCToCSR(const std::shared_ptr& csc); /** @brief Convert a COO format to CSC format. */ std::shared_ptr COOToCSC(const std::shared_ptr& coo); /** @brief Convert a CSR format to CSC format. */ std::shared_ptr CSRToCSC(const std::shared_ptr& csr); /** @brief Convert a Diag format to COO format. */ std::shared_ptr DiagToCOO( const std::shared_ptr& diag, const c10::TensorOptions& indices_options); /** @brief Convert a Diag format to CSR format. */ std::shared_ptr DiagToCSR( const std::shared_ptr& diag, const c10::TensorOptions& indices_options); /** @brief Convert a Diag format to CSC format. */ std::shared_ptr DiagToCSC( const std::shared_ptr& diag, const c10::TensorOptions& indices_options); /** @brief COO transposition. */ std::shared_ptr COOTranspose(const std::shared_ptr& coo); /** * @brief Sort the COO matrix by row and column indices. * @return A pair of the sorted COO matrix and the permutation indices. */ std::pair, torch::Tensor> COOSort( const std::shared_ptr& coo); } // namespace sparse } // namespace dgl #endif // SPARSE_SPARSE_FORMAT_H_