"git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "0da60e10ac8be54e1220ea2855b92eccec5e4d6a"
array_sort.hip 1.92 KB
Newer Older
sangwzh's avatar
sangwzh committed
1
2
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
3
/**
4
 *  Copyright (c) 2020 by Contributors
5
6
 * @file array/cuda/array_sort.hip
 * @brief Array sort GPU implementation
7
8
 */
#include <dgl/array.h>
sangwzh's avatar
sangwzh committed
9
#include "../../../include/dgl/array.h"
10

sangwzh's avatar
sangwzh committed
11
12

#include <hipcub/hipcub.hpp>
13

14
#include "../../runtime/cuda/cuda_common.h"
sangwzh's avatar
sangwzh committed
15
#include "utils.h"
16
17
18
19
20
21

namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {

22
template <DGLDeviceType XPU, typename IdType>
23
std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits) {
24
25
26
27
28
29
30
31
32
33
34
35
  const auto& ctx = array->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  const int64_t nitems = array->shape[0];
  IdArray orig_idx = Range(0, nitems, 64, ctx);
  IdArray sorted_array = NewIdArray(nitems, ctx, array->dtype.bits);
  IdArray sorted_idx = NewIdArray(nitems, ctx, 64);

  const IdType* keys_in = array.Ptr<IdType>();
  const int64_t* values_in = orig_idx.Ptr<int64_t>();
  IdType* keys_out = sorted_array.Ptr<IdType>();
  int64_t* values_out = sorted_idx.Ptr<int64_t>();

sangwzh's avatar
sangwzh committed
36
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
37
  if (num_bits == 0) {
38
    num_bits = sizeof(IdType) * 8;
39
40
  }

41
42
  // Allocate workspace
  size_t workspace_size = 0;
sangwzh's avatar
sangwzh committed
43
  CUDA_CALL(hipcub::DeviceRadixSort::SortPairs(
44
45
      nullptr, workspace_size, keys_in, keys_out, values_in, values_out, nitems,
      0, num_bits, stream));
46
47
48
  void* workspace = device->AllocWorkspace(ctx, workspace_size);

  // Compute
sangwzh's avatar
sangwzh committed
49
  CUDA_CALL(hipcub::DeviceRadixSort::SortPairs(
50
51
      workspace, workspace_size, keys_in, keys_out, values_in, values_out,
      nitems, 0, num_bits, stream));
52
53
54
55
56
57

  device->FreeWorkspace(ctx, workspace);

  return std::make_pair(sorted_array, sorted_idx);
}

58
59
60
61
template std::pair<IdArray, IdArray> Sort<kDGLCUDA, int32_t>(
    IdArray, int num_bits);
template std::pair<IdArray, IdArray> Sort<kDGLCUDA, int64_t>(
    IdArray, int num_bits);
62
63
64
65

}  // namespace impl
}  // namespace aten
}  // namespace dgl