// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/sort_impl.cu
 * @brief Sort implementation on CUDA.
 */
#include <c10/core/ScalarType.h>

#include <hipcub/hipcub.hpp>
#include "common.h"
#include "utils.h"

namespace graphbolt {
namespace ops {

18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
template <bool return_original_positions, typename scalar_t>
std::conditional_t<
    return_original_positions, std::pair<torch::Tensor, torch::Tensor>,
    torch::Tensor>
Sort(const scalar_t* input_keys, int64_t num_items, int num_bits) {
  const auto options = torch::TensorOptions().device(c10::DeviceType::CUDA);
  constexpr c10::ScalarType dtype = c10::CppTypeToScalarType<scalar_t>::value;
  auto sorted_array = torch::empty(num_items, options.dtype(dtype));
  auto sorted_keys = sorted_array.data_ptr<scalar_t>();
  if (num_bits == 0) {
    num_bits = sizeof(scalar_t) * 8;
  }

  if constexpr (return_original_positions) {
    // We utilize int64_t for the values array. (torch::kLong == int64_t)
    auto original_idx = torch::arange(num_items, options.dtype(torch::kLong));
    auto sorted_idx = torch::empty_like(original_idx);
    const int64_t* input_values = original_idx.data_ptr<int64_t>();
    int64_t* sorted_values = sorted_idx.data_ptr<int64_t>();
37
38
39
    CUB_CALL(
        DeviceRadixSort::SortPairs, input_keys, sorted_keys, input_values,
        sorted_values, num_items, 0, num_bits);
40
41
    return std::make_pair(sorted_array, sorted_idx);
  } else {
42
43
44
    CUB_CALL(
        DeviceRadixSort::SortKeys, input_keys, sorted_keys, num_items, 0,
        num_bits);
45
46
47
48
49
50
51
52
53
54
55
56
57
58
    return sorted_array;
  }
}

template <bool return_original_positions>
std::conditional_t<
    return_original_positions, std::pair<torch::Tensor, torch::Tensor>,
    torch::Tensor>
Sort(torch::Tensor input, int num_bits) {
  return AT_DISPATCH_INTEGRAL_TYPES(input.scalar_type(), "SortImpl", ([&] {
                                      return Sort<return_original_positions>(
                                          input.data_ptr<scalar_t>(),
                                          input.size(0), num_bits);
                                    }));
59
60
}

61
62
63
64
template torch::Tensor Sort<false>(torch::Tensor input, int num_bits);
template std::pair<torch::Tensor, torch::Tensor> Sort<true>(
    torch::Tensor input, int num_bits);

65
66
}  //  namespace ops
}  //  namespace graphbolt