/**
 *  Copyright (c) 2023 by Contributors
 * @file index_select.cc
 * @brief Index select operators.
 */
#include "./index_select.h"

#include "./macro.h"

namespace graphbolt {
namespace ops {

/**
 * @brief Picks the entries of `input` addressed by `index`.
 *
 * Two paths exist:
 *  - When `input` is pinned and `index` is either pinned or located on a CUDA
 *    device, the call is routed to UVAIndexSelectImpl through
 *    GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE.
 *  - In every other case `index` is converted to torch::kLong and handed to
 *    `input.index(...)`.
 *
 * @param input The tensor to read from.
 * @param index The tensor holding the positions to read.
 *
 * @return The tensor produced by whichever path was taken.
 */
torch::Tensor IndexSelect(torch::Tensor input, torch::Tensor index) {
  if (input.is_pinned()) {
    // The index placement is only inspected once the input is known to be
    // pinned, so a non-pinned input goes straight to the generic path below.
    const bool index_pinned_or_on_cuda =
        index.is_pinned() ||
        index.device().type() == c10::DeviceType::CUDA;
    if (index_pinned_or_on_cuda) {
      // NOTE(review): the macro comes from macro.h; its behavior in builds
      // without CUDA is not visible from this file -- confirm there.
      GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
          c10::DeviceType::CUDA, "UVAIndexSelect",
          { return UVAIndexSelectImpl(input, index); });
    }
  }
  // Generic path: index with the positions converted to 64-bit integers.
  return input.index({index.to(torch::kLong)});
}

}  // namespace ops
}  // namespace graphbolt