/**
 *  Copyright (c) 2022 by Contributors
 * @file sparse_matrix.cc
 * @brief DGL C++ sparse matrix implementations.
 */
// clang-format off
#include <sparse/dgl_headers.h>
// clang-format on

#include <sparse/sparse_format.h>
#include <sparse/sparse_matrix.h>
#include <torch/script.h>

#include <memory>
#include <tuple>
#include <utility>
#include <vector>

namespace dgl {
namespace sparse {

SparseMatrix::SparseMatrix(
    const std::shared_ptr<COO>& coo, const std::shared_ptr<CSR>& csr,
    const std::shared_ptr<CSR>& csc, torch::Tensor value,
    const std::vector<int64_t>& shape)
    : coo_(coo), csr_(csr), csc_(csc), value_(value), shape_(shape) {
  TORCH_CHECK(
      coo != nullptr || csr != nullptr || csc != nullptr, "At least ",
      "one of CSR/COO/CSC is required to construct a SparseMatrix.");
  TORCH_CHECK(
      shape.size() == 2, "The shape of a sparse matrix should be ",
      "2-dimensional.");
  // NOTE: Currently all the tensors of a SparseMatrix should be on the same
  // device. Should we allow the graph structure and the values to live on
  // different devices?
  if (coo != nullptr) {
    TORCH_CHECK(coo->indices.dim() == 2);
    TORCH_CHECK(coo->indices.size(0) == 2);
    TORCH_CHECK(coo->indices.size(1) == value.size(0));
    TORCH_CHECK(coo->indices.device() == value.device());
  }
  if (csr != nullptr) {
    TORCH_CHECK(csr->indptr.dim() == 1);
    TORCH_CHECK(csr->indices.dim() == 1);
    TORCH_CHECK(csr->indptr.size(0) == shape[0] + 1);
    TORCH_CHECK(csr->indices.size(0) == value.size(0));
    TORCH_CHECK(csr->indptr.device() == value.device());
    TORCH_CHECK(csr->indices.device() == value.device());
  }
  if (csc != nullptr) {
    TORCH_CHECK(csc->indptr.dim() == 1);
    TORCH_CHECK(csc->indices.dim() == 1);
    TORCH_CHECK(csc->indptr.size(0) == shape[1] + 1);
    TORCH_CHECK(csc->indices.size(0) == value.size(0));
    TORCH_CHECK(csc->indptr.device() == value.device());
    TORCH_CHECK(csc->indices.device() == value.device());
  }
}

c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOOPointer(
    const std::shared_ptr<COO>& coo, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  return c10::make_intrusive<SparseMatrix>(coo, nullptr, nullptr, value, shape);
}

c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSRPointer(
    const std::shared_ptr<CSR>& csr, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  return c10::make_intrusive<SparseMatrix>(nullptr, csr, nullptr, value, shape);
}

c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSCPointer(
    const std::shared_ptr<CSR>& csc, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  return c10::make_intrusive<SparseMatrix>(nullptr, nullptr, csc, value, shape);
}

c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOO(
    torch::Tensor indices, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  auto coo =
      std::make_shared<COO>(COO{shape[0], shape[1], indices, false, false});
  return SparseMatrix::FromCOOPointer(coo, value, shape);
}

c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSR(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  auto csr = std::make_shared<CSR>(CSR{
      shape[0], shape[1], indptr, indices, torch::optional<torch::Tensor>(),
      false});
  return SparseMatrix::FromCSRPointer(csr, value, shape);
}

c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSC(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  // A CSC is stored as the CSR of the transposed matrix, so the numbers of
  // rows and columns are swapped here.
  auto csc = std::make_shared<CSR>(CSR{
      shape[1], shape[0], indptr, indices, torch::optional<torch::Tensor>(),
      false});
  return SparseMatrix::FromCSCPointer(csc, value, shape);
}

c10::intrusive_ptr<SparseMatrix> SparseMatrix::ValLike(
    const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value) {
  TORCH_CHECK(
      mat->value().size(0) == value.size(0), "The first dimension of ",
      "the old values and the new values must be the same.");
  TORCH_CHECK(
      mat->value().device() == value.device(), "The device of the ",
      "old values and the new values must be the same.");
  auto shape = mat->shape();
  if (mat->HasCOO()) {
    return SparseMatrix::FromCOOPointer(mat->COOPtr(), value, shape);
  } else if (mat->HasCSR()) {
    return SparseMatrix::FromCSRPointer(mat->CSRPtr(), value, shape);
  } else {
    return SparseMatrix::FromCSCPointer(mat->CSCPtr(), value, shape);
  }
}

std::shared_ptr<COO> SparseMatrix::COOPtr() {
  if (coo_ == nullptr) {
    _CreateCOO();
  }
  return coo_;
}

std::shared_ptr<CSR> SparseMatrix::CSRPtr() {
  if (csr_ == nullptr) {
    _CreateCSR();
  }
  return csr_;
}

std::shared_ptr<CSR> SparseMatrix::CSCPtr() {
  if (csc_ == nullptr) {
    _CreateCSC();
  }
  return csc_;
}

std::tuple<torch::Tensor, torch::Tensor> SparseMatrix::COOTensors() {
  auto coo = COOPtr();
  return std::make_tuple(coo->indices.index({0}), coo->indices.index({1}));
}

torch::Tensor SparseMatrix::Indices() {
  auto coo = COOPtr();
  return coo->indices;
}

std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
SparseMatrix::CSRTensors() {
  auto csr = CSRPtr();
  return std::make_tuple(csr->indptr, csr->indices, csr->value_indices);
}

std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
SparseMatrix::CSCTensors() {
  auto csc = CSCPtr();
  return std::make_tuple(csc->indptr, csc->indices, csc->value_indices);
}

c10::intrusive_ptr<SparseMatrix> SparseMatrix::Transpose() const {
  auto shape = shape_;
  std::swap(shape[0], shape[1]);
  auto value = value_;
  if (HasCOO()) {
    auto coo = COOTranspose(coo_);
    return SparseMatrix::FromCOOPointer(coo, value, shape);
  } else if (HasCSR()) {
    // The CSR of this matrix is exactly the CSC of its transpose, and vice
    // versa, so the existing format is reused without conversion.
    return SparseMatrix::FromCSCPointer(csr_, value, shape);
  } else {
    return SparseMatrix::FromCSRPointer(csc_, value, shape);
  }
}

// The three helpers below lazily create a missing format by converting from
// whichever format the matrix already has.
void SparseMatrix::_CreateCOO() {
  if (HasCOO()) return;
  if (HasCSR()) {
    coo_ = CSRToCOO(csr_);
  } else if (HasCSC()) {
    coo_ = CSCToCOO(csc_);
  } else {
    LOG(FATAL) << "SparseMatrix does not have any sparse format";
  }
}

void SparseMatrix::_CreateCSR() {
  if (HasCSR()) return;
  if (HasCOO()) {
    csr_ = COOToCSR(coo_);
  } else if (HasCSC()) {
    csr_ = CSCToCSR(csc_);
  } else {
    LOG(FATAL) << "SparseMatrix does not have any sparse format";
  }
}

void SparseMatrix::_CreateCSC() {
  if (HasCSC()) return;
  if (HasCOO()) {
    csc_ = COOToCSC(coo_);
  } else if (HasCSR()) {
    csc_ = CSRToCSC(csr_);
  } else {
    LOG(FATAL) << "SparseMatrix does not have any sparse format";
  }
}

}  // namespace sparse
}  // namespace dgl
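
// A minimal, hypothetical usage sketch (not part of the original file): it
// illustrates how the factory and conversion methods above might be exercised
// from C++, assuming the dgl_sparse extension is built and linked. The tensor
// dtypes and literal values below are illustrative assumptions only.
//
//   // 2 x nnz indices for a 2 x 3 matrix with three nonzero entries.
//   auto indices = torch::stack(
//       {torch::tensor({0, 1, 1}, torch::kInt64),
//        torch::tensor({1, 0, 2}, torch::kInt64)});
//   auto value = torch::ones({3});
//   auto mat = dgl::sparse::SparseMatrix::FromCOO(indices, value, {2, 3});
//   auto csr = mat->CSRTensors();   // lazily converts COO -> CSR
//   auto mat_t = mat->Transpose();  // reuses an existing format when possible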