array_cumsum.hip 1.77 KB
Newer Older
sangwzh's avatar
sangwzh committed
1
2
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
3
/**
4
 *  Copyright (c) 2020 by Contributors
5
6
 * @file array/cuda/array_cumsum.hip
 * @brief Array cumsum GPU implementation
7
8
 */
#include <dgl/array.h>
sangwzh's avatar
sangwzh committed
9
#include "../../../include/dgl/array.h"
10

sangwzh's avatar
sangwzh committed
11
#include <hipcub/hipcub.hpp>
12

13
#include "../../runtime/cuda/cuda_common.h"
sangwzh's avatar
sangwzh committed
14
#include "utils.h"
15
16
17
18
19
20

namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {

21
template <DGLDeviceType XPU, typename IdType>
22
23
24
IdArray CumSum(IdArray array, bool prepend_zero) {
  const int64_t len = array.NumElements();
  if (len == 0)
25
26
    return !prepend_zero ? array
                         : aten::Full(0, 1, array->dtype.bits, array->ctx);
27

28
  auto device = runtime::DeviceAPI::Get(array->ctx);
sangwzh's avatar
sangwzh committed
29
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
30
31
32
33
34
35
36
37
38
39
40
41
  const IdType* in_d = array.Ptr<IdType>();
  IdArray ret;
  IdType* out_d = nullptr;
  if (prepend_zero) {
    ret = aten::Full(0, len + 1, array->dtype.bits, array->ctx);
    out_d = ret.Ptr<IdType>() + 1;
  } else {
    ret = aten::NewIdArray(len, array->ctx, array->dtype.bits);
    out_d = ret.Ptr<IdType>();
  }
  // Allocate workspace
  size_t workspace_size = 0;
sangwzh's avatar
sangwzh committed
42
  CUDA_CALL(hipcub::DeviceScan::InclusiveSum(
43
      nullptr, workspace_size, in_d, out_d, len, stream));
44
45
46
  void* workspace = device->AllocWorkspace(array->ctx, workspace_size);

  // Compute cumsum
sangwzh's avatar
sangwzh committed
47
  CUDA_CALL(hipcub::DeviceScan::InclusiveSum(
48
      workspace, workspace_size, in_d, out_d, len, stream));
49
50

  device->FreeWorkspace(array->ctx, workspace);
sangwzh's avatar
sangwzh committed
51
  std::cout << "cuda ret : " << ret << std::endl;
52
53
54
  return ret;
}

55
56
// Explicit instantiations of CumSum for the kDGLCUDA device type with
// 32-bit and 64-bit integer id types.
template IdArray CumSum<kDGLCUDA, int32_t>(IdArray, bool);
template IdArray CumSum<kDGLCUDA, int64_t>(IdArray, bool);
57
58
59
60

}  // namespace impl
}  // namespace aten
}  // namespace dgl