/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/sort_impl.cu
 * @brief Sort implementation on CUDA.
 */
#include <c10/core/ScalarType.h>

#include <cub/cub.cuh>

#include "./common.h"
#include "./utils.h"

namespace graphbolt {
namespace ops {

17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
/**
 * @brief Sorts `num_items` keys read from `input_keys` with
 * cub::DeviceRadixSort and writes them into a freshly allocated CUDA tensor.
 * Optionally also reports, for every sorted key, its index in the input.
 *
 * @tparam return_original_positions When true, SortPairs is used with the
 * sequence 0..num_items-1 as values so the original index of each sorted key
 * is returned as well; when false, SortKeys is used and only keys are
 * returned.
 * @tparam scalar_t Key type; it also determines the dtype of the output
 * tensor.
 *
 * @param input_keys Pointer to the keys to sort.
 * NOTE(review): presumably has to be readable from the device since it is
 * handed straight to CUB -- confirm against the callers.
 * @param num_items Number of keys to sort.
 * @param num_bits Forwarded to CUB together with a begin bit of 0, i.e. only
 * the bit range [0, num_bits) of each key takes part in the comparison. Any
 * value outside [1, 8 * sizeof(scalar_t)] (this includes the conventional 0)
 * selects the full key width.
 *
 * @return The sorted keys, or the pair (sorted keys, original positions) when
 * return_original_positions is true. The positions tensor has dtype
 * torch::kLong.
 */
template <bool return_original_positions, typename scalar_t>
std::conditional_t<
    return_original_positions, std::pair<torch::Tensor, torch::Tensor>,
    torch::Tensor>
Sort(const scalar_t* input_keys, int64_t num_items, int num_bits) {
  const auto options = torch::TensorOptions().device(c10::DeviceType::CUDA);
  constexpr c10::ScalarType dtype = c10::CppTypeToScalarType<scalar_t>::value;
  auto sorted_array = torch::empty(num_items, options.dtype(dtype));
  auto sorted_keys = sorted_array.data_ptr<scalar_t>();

  // Sorting on every bit of the key is always a valid sort, so it is used as
  // the fallback for 0 (the "not specified" sentinel) as well as for negative
  // or too-wide requests, which must not reach CUB as an end bit.
  constexpr int kKeyBits = static_cast<int>(sizeof(scalar_t) * 8);
  if (num_bits <= 0 || num_bits > kKeyBits) {
    num_bits = kKeyBits;
  }

  if constexpr (return_original_positions) {
    // We utilize int64_t for the values array. (torch::kLong == int64_t)
    auto original_idx = torch::arange(num_items, options.dtype(torch::kLong));
    auto sorted_idx = torch::empty_like(original_idx);
    const int64_t* input_values = original_idx.data_ptr<int64_t>();
    int64_t* sorted_values = sorted_idx.data_ptr<int64_t>();
    // The values travel with their keys, so sorted_idx[i] ends up holding the
    // input index of sorted_array[i].
    CUB_CALL(
        DeviceRadixSort::SortPairs, input_keys, sorted_keys, input_values,
        sorted_values, num_items, 0, num_bits);
    return std::make_pair(sorted_array, sorted_idx);
  } else {
    CUB_CALL(
        DeviceRadixSort::SortKeys, input_keys, sorted_keys, num_items, 0,
        num_bits);
    return sorted_array;
  }
}

/**
 * @brief Tensor front end of Sort: dispatches on the integral dtype of
 * `input` and sorts its first `input.size(0)` elements with the pointer-based
 * overload.
 *
 * @tparam return_original_positions When true, the original positions of the
 * sorted items are returned alongside the sorted tensor.
 *
 * @param input Tensor holding the keys; its dtype must be one of the types
 * accepted by AT_DISPATCH_INTEGRAL_TYPES.
 * NOTE(review): only size(0) elements are sorted, so this presumably expects
 * a 1-D tensor -- confirm against the callers.
 * @param num_bits Forwarded unchanged to the pointer-based overload.
 *
 * @return The sorted tensor, or the pair (sorted tensor, original positions)
 * when return_original_positions is true.
 */
template <bool return_original_positions>
std::conditional_t<
    return_original_positions, std::pair<torch::Tensor, torch::Tensor>,
    torch::Tensor>
Sort(torch::Tensor input, int num_bits) {
  // data_ptr() does not account for strides, so reading size(0) consecutive
  // elements through it is only correct for a contiguous layout. contiguous()
  // returns the tensor itself when it is already contiguous and a packed copy
  // otherwise.
  const auto contiguous_input = input.contiguous();
  return AT_DISPATCH_INTEGRAL_TYPES(
      contiguous_input.scalar_type(), "SortImpl", ([&] {
        return Sort<return_original_positions>(
            contiguous_input.data_ptr<scalar_t>(), contiguous_input.size(0),
            num_bits);
      }));
}

60
61
62
63
// Explicit instantiations of the tensor overload of Sort, one for each value
// of return_original_positions.
// Keys together with their original positions.
template std::pair<torch::Tensor, torch::Tensor> Sort<true>(torch::Tensor, int);
// Keys only.
template torch::Tensor Sort<false>(torch::Tensor, int);

64
65
}  //  namespace ops
}  //  namespace graphbolt