/** * Copyright (c) 2022 by Contributors * @file utils.h * @brief DGL C++ sparse API utilities */ #ifndef DGL_SPARSE_UTILS_H_ #define DGL_SPARSE_UTILS_H_ // clang-format off #include // clang-format on #include #include #include #include namespace dgl { namespace sparse { /** @brief Find a proper sparse format for two sparse matrices. It chooses * COO if anyone of the sparse matrices has COO format. If none of them has * COO, it tries CSR and CSC in the same manner. */ inline static SparseFormat FindAnyExistingFormat( const c10::intrusive_ptr& A, const c10::intrusive_ptr& B) { SparseFormat fmt; if (A->HasCOO() || B->HasCOO()) { fmt = SparseFormat::kCOO; } else if (A->HasCSR() || B->HasCSR()) { fmt = SparseFormat::kCSR; } else { fmt = SparseFormat::kCSC; } return fmt; } /** @brief Check whether two matrices has the same dtype and shape for * elementwise operators. */ inline static void ElementwiseOpSanityCheck( const c10::intrusive_ptr& A, const c10::intrusive_ptr& B) { CHECK(A->value().dtype() == B->value().dtype()) << "Elementwise operators do not support two sparse matrices with " "different dtypes. (" << A->value().dtype() << " vs " << B->value().dtype() << ")"; CHECK(A->shape()[0] == B->shape()[0] && A->shape()[1] == B->shape()[1]) << "Elementwise operator do not support two sparse matrices with " "different shapes. ([" << A->shape()[0] << ", " << A->shape()[1] << "] vs [" << B->shape()[0] << ", " << B->shape()[1] << "])"; } /** @brief Convert a Torch tensor to a DGL array. */ inline static runtime::NDArray TorchTensorToDGLArray(torch::Tensor tensor) { return runtime::DLPackConvert::FromDLPack(at::toDLPack(tensor)); } /** @brief Convert a DGL array to a Torch tensor. */ inline static torch::Tensor DGLArrayToTorchTensor(runtime::NDArray array) { return at::fromDLPack(runtime::DLPackConvert::ToDLPack(array)); } /** @brief Convert an optional Torch tensor to a DGL array. */ inline static runtime::NDArray OptionalTorchTensorToDGLArray( torch::optional tensor) { if (!tensor.has_value()) { return aten::NullArray(); } return TorchTensorToDGLArray(tensor.value()); } /** @brief Convert a DGL array to an optional Torch tensor. */ inline static torch::optional DGLArrayToOptionalTorchTensor( runtime::NDArray array) { if (aten::IsNullArray(array)) { return torch::optional(); } return torch::make_optional(DGLArrayToTorchTensor(array)); } } // namespace sparse } // namespace dgl #endif // DGL_SPARSE_UTILS_H_