/**
 * Copyright (c) 2023 by Contributors
 * @file utils.h
 * @brief Graphbolt utils.
 */
#ifndef GRAPHBOLT_UTILS_H_
#define GRAPHBOLT_UTILS_H_

#include <torch/script.h>

namespace graphbolt {
namespace utils {

/**
 * @brief Checks whether the tensor is stored on the GPU or in pinned memory.
 */
inline bool is_accessible_from_gpu(torch::Tensor tensor) {
  return tensor.is_pinned() || tensor.device().type() == c10::DeviceType::CUDA;
}

/**
 * @brief Retrieves the value of the tensor at the given index.
 *
 * @note If the tensor is not contiguous, it will be copied to a contiguous
 * tensor.
 *
 * @tparam T The element type of the tensor.
 * @param tensor The tensor.
 * @param index The index.
 *
 * @return T The value of the tensor at the given index.
 */
template <typename T>
T GetValueByIndex(const torch::Tensor& tensor, int64_t index) {
  TORCH_CHECK(
      index >= 0 && index < tensor.numel(),
      "The index should be within the range of the tensor, but got index ",
      index, " and tensor size ", tensor.numel());
  auto contiguous_tensor = tensor.contiguous();
  auto data_ptr = contiguous_tensor.data_ptr<T>();
  return data_ptr[index];
}

}  // namespace utils
}  // namespace graphbolt

#endif  // GRAPHBOLT_UTILS_H_
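// ---------------------------------------------------------------------------
// Usage sketch (illustration only, not part of the header above): a minimal,
// hypothetical example of how these helpers might be called. The function
// name `ReadFirstIndptrEntry` and the parameter `indptr` are assumptions made
// for this sketch, not part of the GraphBolt API.
//
//   #include <torch/script.h>
//   #include "utils.h"
//
//   int64_t ReadFirstIndptrEntry(const torch::Tensor& indptr) {
//     // Decide whether a CUDA kernel could read this tensor directly
//     // (device memory or pinned host memory).
//     if (graphbolt::utils::is_accessible_from_gpu(indptr)) {
//       // ... dispatch to a GPU code path ...
//     }
//     // Read a single scalar element; GetValueByIndex bounds-checks the
//     // index and makes the tensor contiguous internally.
//     return graphbolt::utils::GetValueByIndex<int64_t>(indptr, 0);
//   }
// ---------------------------------------------------------------------------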