// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/sort_impl.cu
 * @brief Sort implementation on CUDA.
 */
#include <hip/hip_runtime.h>

#include <c10/core/ScalarType.h>

#include "common.h"
#include "utils.h"

namespace graphbolt {
namespace ops {

template <bool return_original_positions, typename scalar_t>
std::conditional_t<
    return_original_positions, std::pair<torch::Tensor, torch::Tensor>,
    torch::Tensor>
Sort(const scalar_t* input_keys, int64_t num_items, int num_bits) {
  const auto options = torch::TensorOptions().device(c10::DeviceType::CUDA);
  constexpr c10::ScalarType dtype = c10::CppTypeToScalarType<scalar_t>::value;
  auto sorted_array = torch::empty(num_items, options.dtype(dtype));
  auto sorted_keys = sorted_array.data_ptr<scalar_t>();
  if (num_bits == 0) {
    // Sort on the full bit width of the key type by default.
    num_bits = sizeof(scalar_t) * 8;
  }

  if constexpr (return_original_positions) {
    // We utilize int64_t for the values array. (torch::kLong == int64_t)
    auto original_idx = torch::arange(num_items, options.dtype(torch::kLong));
    auto sorted_idx = torch::empty_like(original_idx);
    const int64_t* input_values = original_idx.data_ptr<int64_t>();
    int64_t* sorted_values = sorted_idx.data_ptr<int64_t>();
    CUB_CALL(
        DeviceRadixSort::SortPairs, input_keys, sorted_keys, input_values,
        sorted_values, num_items, 0, num_bits);
    return std::make_pair(sorted_array, sorted_idx);
  } else {
    CUB_CALL(
        DeviceRadixSort::SortKeys, input_keys, sorted_keys, num_items, 0,
        num_bits);
    return sorted_array;
  }
}

template <bool return_original_positions>
std::conditional_t<
    return_original_positions, std::pair<torch::Tensor, torch::Tensor>,
    torch::Tensor>
Sort(torch::Tensor input, int num_bits) {
  return AT_DISPATCH_INTEGRAL_TYPES(
      input.scalar_type(), "SortImpl", ([&] {
        return Sort<return_original_positions, scalar_t>(
            input.data_ptr<scalar_t>(), input.size(0), num_bits);
      }));
}

// Explicit instantiations for the two supported configurations.
template torch::Tensor Sort<false>(torch::Tensor input, int num_bits);
template std::pair<torch::Tensor, torch::Tensor> Sort<true>(
    torch::Tensor input, int num_bits);

}  // namespace ops
}  // namespace graphbolt
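
// A minimal caller-side sketch (illustrative only; the tensor setup below is
// an assumption, not part of this file). Sort<true> returns the sorted keys
// together with each key's original position; Sort<false> returns only the
// sorted keys. Passing num_bits == 0 sorts on the full bit width of the key
// type.
//
//   torch::Tensor keys = torch::randint(
//       0, 100, {10},
//       torch::TensorOptions().device(torch::kCUDA).dtype(torch::kLong));
//   auto [sorted, positions] = graphbolt::ops::Sort<true>(keys, 0);
//   // sorted[i] == keys[positions[i]] for every i.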