"vscode:/vscode.git/clone" did not exist on "9211d34fb8b53f2f83b6f35f601a64822f9a0f0c"
cuda_common.h 7.98 KB
Newer Older
sangwzh's avatar
sangwzh committed
1
// !!! This is a file automatically generated by hipify!!!
2
/**
3
 *  Copyright (c) 2017 by Contributors
4
5
 * @file cuda_common.h
 * @brief Common utilities for CUDA
6
7
8
9
 */
#ifndef DGL_RUNTIME_CUDA_CUDA_COMMON_H_
#define DGL_RUNTIME_CUDA_CUDA_COMMON_H_

#include <hipblas/hipblas.h>
#include <hip/hip_runtime.h>
#include <hiprand/hiprand.h>
#include <hipsparse/hipsparse.h>
#include <dgl/runtime/packed_func.h>

#include <memory>
#include <string>

#include "../workspace_pool.h"

namespace dgl {
namespace runtime {

/*
  How to use this class to get a nonblocking thrust execution policy that uses
  DGL's memory pool and the current CUDA stream:

  runtime::CUDAWorkspaceAllocator allocator(ctx);
  const auto stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  const auto exec_policy = thrust::hip::par_nosync(allocator).on(stream);

  Now exec_policy can be passed to thrust functions.

  To get an integer array of size 1000 whose lifetime is managed by a
  unique_ptr, use: auto int_array = allocator.alloc_unique<int>(1000);
  int_array.get() gives the raw pointer.
*/
class CUDAWorkspaceAllocator {
  DGLContext ctx;

 public:
  typedef char value_type;

  void operator()(void* ptr) const {
    runtime::DeviceAPI::Get(ctx)->FreeWorkspace(ctx, ptr);
  }

  explicit CUDAWorkspaceAllocator(DGLContext ctx) : ctx(ctx) {}

  CUDAWorkspaceAllocator& operator=(const CUDAWorkspaceAllocator&) = default;

  template <typename T>
  std::unique_ptr<T, CUDAWorkspaceAllocator> alloc_unique(
      std::size_t size) const {
    return std::unique_ptr<T, CUDAWorkspaceAllocator>(
        reinterpret_cast<T*>(runtime::DeviceAPI::Get(ctx)->AllocWorkspace(
            ctx, sizeof(T) * size)),
        *this);
  }

  char* allocate(std::ptrdiff_t size) const {
    return reinterpret_cast<char*>(
        runtime::DeviceAPI::Get(ctx)->AllocWorkspace(ctx, size));
  }

  void deallocate(char* ptr, std::size_t) const {
    runtime::DeviceAPI::Get(ctx)->FreeWorkspace(ctx, ptr);
  }
};
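
// A minimal usage sketch combining the pieces above, following the recipe in
// the comment at the top of the class (`d_keys` and `n` are illustrative
// names, not part of this header):
//
//   runtime::CUDAWorkspaceAllocator allocator(ctx);
//   const auto stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
//   const auto exec_policy = thrust::hip::par_nosync(allocator).on(stream);
//   thrust::sort(exec_policy, d_keys, d_keys + n);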

template <typename T>
inline bool is_zero(T size) {
  return size == 0;
}

template <>
inline bool is_zero<dim3>(dim3 size) {
  return size.x == 0 || size.y == 0 || size.z == 0;
}

#define CUDA_DRIVER_CALL(x)                                          \
  {                                                                  \
    hipError_t result = x;                                           \
    if (result != hipSuccess && result != hipErrorDeinitialized) {   \
      const char* msg = hipGetErrorName(result);                     \
      LOG(FATAL) << "CUDAError: " #x " failed with error: " << msg;  \
    }                                                                \
  }

#define CUDA_CALL(func)                                    \
  {                                                        \
    hipError_t e = (func);                                 \
    CHECK(e == hipSuccess || e == hipErrorDeinitialized)   \
        << "CUDA: " << hipGetErrorString(e);               \
  }
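
// Example (sketch; `dst`, `src`, and `nbytes` are illustrative):
//
//   CUDA_CALL(hipMemcpy(dst, src, nbytes, hipMemcpyHostToDevice));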

#define CUDA_KERNEL_CALL(kernel, nblks, nthrs, shmem, stream, ...)            \
  {                                                                           \
    if (!dgl::runtime::is_zero((nblks)) && !dgl::runtime::is_zero((nthrs))) { \
      hipLaunchKernelGGL(                                                     \
          (kernel), dim3((nblks)), dim3((nthrs)), (shmem), (stream),          \
          __VA_ARGS__);                                                       \
      hipError_t e = hipGetLastError();                                       \
      CHECK(e == hipSuccess || e == hipErrorDeinitialized)                    \
          << "CUDA kernel launch error: " << hipGetErrorString(e);            \
    }                                                                         \
  }
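
// Example (sketch; `MyKernel`, its arguments, and the launch geometry are
// illustrative):
//
//   CUDA_KERNEL_CALL(
//       MyKernel, nblks, nthrs, 0, getCurrentHIPStreamMasqueradingAsCUDA(),
//       arg0, arg1);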

#define CUSPARSE_CALL(func)                                          \
  {                                                                  \
    hipsparseStatus_t e = (func);                                    \
    CHECK(e == HIPSPARSE_STATUS_SUCCESS) << "CUSPARSE ERROR: " << e; \
  }

#define CUBLAS_CALL(func)                                        \
  {                                                              \
    hipblasStatus_t e = (func);                                  \
    CHECK(e == HIPBLAS_STATUS_SUCCESS) << "CUBLAS ERROR: " << e; \
  }
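
// Example (sketch; `handle`, `n`, `alpha`, `d_x`, and `d_y` are illustrative):
//
//   CUBLAS_CALL(hipblasSaxpy(handle, n, &alpha, d_x, 1, d_y, 1));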

#define CURAND_CALL(func)                                                      \
  {                                                                            \
    hiprandStatus_t e = (func);                                                \
    CHECK(e == HIPRAND_STATUS_SUCCESS)                                         \
        << "CURAND Error: " << dgl::runtime::curandGetErrorString(e) << " at " \
        << __FILE__ << ":" << __LINE__;                                        \
  }
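
// Example (sketch; the generator `gen`, device buffer `d_out`, and length `n`
// are illustrative):
//
//   CURAND_CALL(hiprandGenerateUniform(gen, d_out, n));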

inline const char* curandGetErrorString(hiprandStatus_t error) {
  switch (error) {
    case HIPRAND_STATUS_SUCCESS:
      return "HIPRAND_STATUS_SUCCESS";
    case HIPRAND_STATUS_VERSION_MISMATCH:
      return "HIPRAND_STATUS_VERSION_MISMATCH";
    case HIPRAND_STATUS_NOT_INITIALIZED:
      return "HIPRAND_STATUS_NOT_INITIALIZED";
    case HIPRAND_STATUS_ALLOCATION_FAILED:
      return "HIPRAND_STATUS_ALLOCATION_FAILED";
    case HIPRAND_STATUS_TYPE_ERROR:
      return "HIPRAND_STATUS_TYPE_ERROR";
    case HIPRAND_STATUS_OUT_OF_RANGE:
      return "HIPRAND_STATUS_OUT_OF_RANGE";
    case HIPRAND_STATUS_LENGTH_NOT_MULTIPLE:
      return "HIPRAND_STATUS_LENGTH_NOT_MULTIPLE";
    case HIPRAND_STATUS_DOUBLE_PRECISION_REQUIRED:
      return "HIPRAND_STATUS_DOUBLE_PRECISION_REQUIRED";
    case HIPRAND_STATUS_LAUNCH_FAILURE:
      return "HIPRAND_STATUS_LAUNCH_FAILURE";
    case HIPRAND_STATUS_PREEXISTING_FAILURE:
      return "HIPRAND_STATUS_PREEXISTING_FAILURE";
    case HIPRAND_STATUS_INITIALIZATION_FAILED:
      return "HIPRAND_STATUS_INITIALIZATION_FAILED";
    case HIPRAND_STATUS_ARCH_MISMATCH:
      return "HIPRAND_STATUS_ARCH_MISMATCH";
    case HIPRAND_STATUS_INTERNAL_ERROR:
      return "HIPRAND_STATUS_INTERNAL_ERROR";
  }
  // To suppress compiler warning.
  return "Unrecognized hiprand error string";
}

/**
 * @brief Cast data type to hipDataType.
 */
template <typename T>
struct cuda_dtype {
  static constexpr hipDataType value = HIP_R_32F;
};

template <>
struct cuda_dtype<__half> {
  static constexpr hipDataType value = HIP_R_16F;
};

#if BF16_ENABLED
template <>
struct cuda_dtype<__hip_bfloat16> {
  static constexpr hipDataType value = HIP_R_16BF;
};
#endif  // BF16_ENABLED

template <>
struct cuda_dtype<float> {
  static constexpr hipDataType value = HIP_R_32F;
};

template <>
struct cuda_dtype<double> {
  static constexpr hipDataType value = HIP_R_64F;
};
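
// Example (sketch): the trait lets templated code select the hipDataType tag
// matching its C++ element type, e.g. when creating a hipSPARSE dense-matrix
// descriptor (`rows`, `cols`, `ld`, and `data` are illustrative):
//
//   hipsparseDnMatDescr_t mat;
//   CUSPARSE_CALL(hipsparseCreateDnMat(
//       &mat, rows, cols, ld, data, cuda_dtype<float>::value,
//       HIPSPARSE_ORDER_ROW));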

/**
 * @brief Accumulator type for SpMM.
 */
template <typename T>
struct accum_dtype {
  typedef float type;
};

template <>
struct accum_dtype<__half> {
  typedef float type;
};

#if BF16_ENABLED
template <>
struct accum_dtype<__hip_bfloat16> {
  typedef float type;
};
#endif  // BF16_ENABLED

template <>
struct accum_dtype<float> {
  typedef float type;
};

template <>
struct accum_dtype<double> {
  typedef double type;
};
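
// Example (sketch): accumulate in the wider type picked by the trait to limit
// rounding error for low-precision element types (`vals` and `n` are
// illustrative):
//
//   template <typename DType>
//   __device__ typename accum_dtype<DType>::type Sum(const DType* vals, int n) {
//     typename accum_dtype<DType>::type acc = 0;
//     for (int i = 0; i < n; ++i) acc += static_cast<decltype(acc)>(vals[i]);
//     return acc;
//   }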

#if DTKRT_VERSION >= 11000
/**
 * @brief Cast index data type to hipsparseIndexType_t.
 */
template <typename T>
struct cusparse_idtype {
  static constexpr hipsparseIndexType_t value = HIPSPARSE_INDEX_32I;
};

template <>
struct cusparse_idtype<int32_t> {
  static constexpr hipsparseIndexType_t value = HIPSPARSE_INDEX_32I;
};

template <>
struct cusparse_idtype<int64_t> {
  static constexpr hipsparseIndexType_t value = HIPSPARSE_INDEX_64I;
};
#endif
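
// Example (sketch): pick the hipSPARSE index type matching the graph's index
// width when building a CSR descriptor (`num_rows`, `num_cols`, `nnz`, and the
// array pointers are illustrative):
//
//   hipsparseSpMatDescr_t csr;
//   CUSPARSE_CALL(hipsparseCreateCsr(
//       &csr, num_rows, num_cols, nnz, indptr, indices, values,
//       cusparse_idtype<int64_t>::value, cusparse_idtype<int64_t>::value,
//       HIPSPARSE_INDEX_BASE_ZERO, cuda_dtype<float>::value));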

/** @brief Thread local workspace */
class CUDAThreadEntry {
 public:
  /** @brief The cusparse handler */
  hipsparseHandle_t cusparse_handle{nullptr};
  /** @brief The cublas handler */
  hipblasHandle_t cublas_handle{nullptr};
  /** @brief thread local pool*/
  WorkspacePool pool;
  /** @brief constructor */
  CUDAThreadEntry();
  // get the threadlocal workspace
  static CUDAThreadEntry* ThreadLocal();
};
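
// Example (sketch): fetch the thread-local entry to reuse its library handles;
// creating the handle lazily and binding it to the current stream as shown is
// illustrative:
//
//   auto* entry = CUDAThreadEntry::ThreadLocal();
//   if (!entry->cusparse_handle)
//     CUSPARSE_CALL(hipsparseCreate(&entry->cusparse_handle));
//   CUSPARSE_CALL(hipsparseSetStream(
//       entry->cusparse_handle, getCurrentHIPStreamMasqueradingAsCUDA()));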

/** @brief Get the current CUDA stream */
hipStream_t getCurrentHIPStreamMasqueradingAsCUDA();
}  // namespace runtime
}  // namespace dgl
#endif  // DGL_RUNTIME_CUDA_CUDA_COMMON_H_