// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2019 by Contributors
 * @file array/cuda/atomic.cuh
 * @brief Atomic functions
 */
#ifndef DGL_ARRAY_CUDA_ATOMIC_CUH_
#define DGL_ARRAY_CUDA_ATOMIC_CUH_

#include <hip/hip_runtime.h>

#include <cassert>
#include <cstdint>
#include <cstdio>

#include "bf16.cuh"
#include "fp16.cuh"

#if __HIPCC__
#include <hip/hip_fp16.h>
#endif

namespace dgl {
namespace aten {
namespace cuda {

// Type trait for selecting code type
template <int Bytes>
struct Code {};

template <>
struct Code<2> {
  typedef unsigned short int Type;  // NOLINT
};

template <>
struct Code<4> {
  typedef unsigned int Type;  // NOLINT
};

template <>
struct Code<8> {
  typedef unsigned long long int Type;  // NOLINT
};

// Helper class for converting to/from atomicCAS compatible types.
template <typename T>
struct Cast {
  typedef typename Code<sizeof(T)>::Type Type;
  static __device__ __forceinline__ Type Encode(T val) {
    return static_cast<Type>(val);
  }
  static __device__ __forceinline__ T Decode(Type code) {
    return static_cast<T>(code);
  }
};

template <>
struct Cast<half> {
  // Use the 16-bit code type so Encode/Decode reinterpret bits rather than
  // numerically converting the value.
  typedef Code<sizeof(half)>::Type Type;
  static __host__ __device__ __forceinline__ Type Encode(half val) {
    return __half_as_ushort(val);
  }
  static __host__ __device__ __forceinline__ half Decode(Type code) {
    return __ushort_as_half(code);
  }
};

#if BF16_ENABLED
template <>
struct Cast<__hip_bfloat16> {
  // Use the 16-bit code type so Encode/Decode reinterpret bits rather than
  // numerically converting the value.
  typedef Code<sizeof(__hip_bfloat16)>::Type Type;
  static __host__ __device__ __forceinline__ Type Encode(__hip_bfloat16 val) {
#if defined(__HIP_DEVICE_COMPILE__)
    return __bfloat16_as_ushort(val);
#else
    printf(
        "Atomic operations are not supported for bfloat16 (BF16) "
        "on GPUs with compute capability less than 8.0.\n");
    // __trap();
    return static_cast<Type>(0);
#endif
  }
  static __host__ __device__ __forceinline__ __hip_bfloat16 Decode(Type code) {
#if defined(__HIP_DEVICE_COMPILE__)
    return __ushort_as_bfloat16(code);
#else
    printf(
        "Atomic operations are not supported for bfloat16 (BF16) "
        "on GPUs with compute capability less than 8.0.\n");
    // __trap();
    return static_cast<__hip_bfloat16>(0.0f);
#endif
  }
};
#endif  // BF16_ENABLED

template <>
struct Cast<float> {
  typedef Code<sizeof(float)>::Type Type;
  static __device__ __forceinline__ Type Encode(float val) {
    return __float_as_uint(val);
  }
  static __device__ __forceinline__ float Decode(Type code) {
    return __uint_as_float(code);
  }
};

template <>
struct Cast<double> {
  typedef Code<sizeof(double)>::Type Type;
  static __device__ __forceinline__ Type Encode(double val) {
    return __double_as_longlong(val);
  }
  static __device__ __forceinline__ double Decode(Type code) {
    return __longlong_as_double(code);
  }
};
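
// Illustrative sketch (not part of the original header): the Cast helpers
// round-trip raw bit patterns rather than numeric values, which is what the
// CAS loops below rely on. In device code, something like the following
// holds (hypothetical variable names):
//
//   float x = 3.5f;
//   Cast<float>::Type bits = Cast<float>::Encode(x);  // __float_as_uint
//   float y = Cast<float>::Decode(bits);              // __uint_as_float
//   // y is bit-identical to x; no numeric conversion takes place.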

static __host__ __device__ __forceinline__ unsigned short int atomicCASshort(  // NOLINT
    unsigned short int* address,                                      // NOLINT
    unsigned short int compare,                                       // NOLINT
    unsigned short int val) {                                         // NOLINT
  static_assert(DTKRT_VERSION >= 10000, "Requires at least CUDA 10");
  // Note: the native 16-bit atomicCAS path is disabled ("&& 0") in this port,
  // so the fallback below always runs and reports the lack of FP16 atomics.
#if defined(__HIP_DEVICE_COMPILE__) && 0
  return atomicCAS(address, compare, val);
#else
  (void)address;
  (void)compare;
  (void)val;
  printf(
      "Atomic operations are not supported for half precision (FP16) "
      "on this GPU.\n");
  abort();
  return val;
#endif  // defined(__HIP_DEVICE_COMPILE__)
}

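// Generic CAS-loop atomic: read the current word, apply OP to the new value
// and the decoded current value, and retry atomicCAS until no other thread
// has modified the word in between. OP(a, b) must be defined before the macro
// is instantiated and #undef'd afterwards. Returns the previous value,
// matching the convention of the hardware atomics.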
#define DEFINE_ATOMIC(NAME)                                   \
  template <typename T>                                       \
  __device__ __forceinline__ T Atomic##NAME(T* addr, T val) { \
    typedef typename Cast<T>::Type CT;                        \
    CT* addr_as_ui = reinterpret_cast<CT*>(addr);             \
    CT old = *addr_as_ui;                                     \
    CT assumed = old;                                         \
    do {                                                      \
      assumed = old;                                          \
      old = atomicCAS(                                        \
          addr_as_ui, assumed,                                \
          Cast<T>::Encode(OP(val, Cast<T>::Decode(old))));    \
    } while (assumed != old);                                 \
    return Cast<T>::Decode(old);                              \
  }

#define DEFINE_ATOMIC_16BIT(NAME, dtype)                           \
  template <>                                                      \
  __device__ __forceinline__ dtype Atomic##NAME<dtype>(            \
      dtype * addr, dtype val) {                                   \
    typedef uint16_t CT;                                           \
    CT* addr_as_ui = reinterpret_cast<CT*>(addr);                  \
    CT old = *addr_as_ui;                                          \
    CT assumed = old;                                              \
    do {                                                           \
      assumed = old;                                               \
      old = atomicCASshort(                                        \
          addr_as_ui, assumed,                                     \
          Cast<dtype>::Encode(OP(val, Cast<dtype>::Decode(old)))); \
    } while (assumed != old);                                      \
    return Cast<dtype>::Decode(old);                               \
  }

// 16-bit Max specialization: the comparison is done in float on the decoded
// values.
#define DEFINE_ATOMIC_16BIT_MAX(NAME, dtype)                       \
  template <>                                                      \
  __device__ __forceinline__ dtype Atomic##NAME<dtype>(            \
      dtype * addr, dtype val) {                                   \
    typedef uint16_t CT;                                           \
    CT* addr_as_ui = reinterpret_cast<CT*>(addr);                  \
    CT old = *addr_as_ui;                                          \
    CT assumed = old;                                              \
    do {                                                           \
      assumed = old;                                               \
      old = atomicCASshort(                                        \
          addr_as_ui, assumed,                                     \
          Cast<dtype>::Encode(dtype(                               \
              max((float)val, (float)Cast<dtype>::Decode(old))))); \
    } while (assumed != old);                                      \
    return Cast<dtype>::Decode(old);                               \
  }

// 16-bit Min specialization: the comparison is done in float on the decoded
// values.
#define DEFINE_ATOMIC_16BIT_MIN(NAME, dtype)                       \
  template <>                                                      \
  __device__ __forceinline__ dtype Atomic##NAME<dtype>(            \
      dtype * addr, dtype val) {                                   \
    typedef uint16_t CT;                                           \
    CT* addr_as_ui = reinterpret_cast<CT*>(addr);                  \
    CT old = *addr_as_ui;                                          \
    CT assumed = old;                                              \
    do {                                                           \
      assumed = old;                                               \
      old = atomicCASshort(                                        \
          addr_as_ui, assumed,                                     \
          Cast<dtype>::Encode(dtype(                               \
              min((float)val, (float)Cast<dtype>::Decode(old))))); \
    } while (assumed != old);                                      \
    return Cast<dtype>::Decode(old);                               \
  }

#define OP(a, b) max(a, b)
DEFINE_ATOMIC(Max)
DEFINE_ATOMIC_16BIT_MAX(Max, half)
#if BF16_ENABLED
DEFINE_ATOMIC_16BIT_MAX(Max, __hip_bfloat16)
#endif  // BF16_ENABLED
#undef OP

#define OP(a, b) min(a, b)
DEFINE_ATOMIC(Min)
DEFINE_ATOMIC_16BIT_MIN(Min, half)
#if BF16_ENABLED
DEFINE_ATOMIC_16BIT_MIN(Min, __hip_bfloat16)
#endif  // BF16_ENABLED
#undef OP

#define OP(a, b) a + b
DEFINE_ATOMIC(Add)
#undef OP
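
// Illustrative usage sketch (hypothetical, not part of this header): a kernel
// that reduces per-edge scores into per-node buffers with the atomics above.
// The kernel and buffer names are made up for illustration only.
//
//   __global__ void ScatterReduce(
//       const float* score, const int* dst, float* sum, float* peak, int n) {
//     int i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n) {
//       AtomicAdd(&sum[dst[i]], score[i]);   // atomic accumulate
//       AtomicMax(&peak[dst[i]], score[i]);  // atomic running max
//     }
//   }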

/**
 * @brief Performs an atomic compare-and-swap on 64 bit integers. That is,
 * it reads the word `old` at the memory location `address`, computes
 * `(old == compare ? val : old)`, and stores the result back to memory at
 * the same address.
 *
 * @param address The address to perform the atomic operation on.
 * @param compare The value to compare to.
 * @param val The new value to conditionally store.
 *
 * @return The old value at the address.
 */
inline __device__ int64_t
AtomicCAS(int64_t* const address, const int64_t compare, const int64_t val) {
  // match the type of "::atomicCAS", so ignore lint warning
  using Type = unsigned long long int;  // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicCAS(
      reinterpret_cast<Type*>(address), static_cast<Type>(compare),
      static_cast<Type>(val));
}

/**
 * @brief Performs an atomic compare-and-swap on 32 bit integers. That is,
 * it reads the word `old` at the memory location `address`, computes
 * `(old == compare ? val : old)`, and stores the result back to memory at
 * the same address.
 *
 * @param address The address to perform the atomic operation on.
 * @param compare The value to compare to.
 * @param val The new value to conditionally store.
 *
 * @return The old value at the address.
 */
inline __device__ int32_t
AtomicCAS(int32_t* const address, const int32_t compare, const int32_t val) {
  // match the type of "::atomicCAS", so ignore lint warning
  using Type = int;  // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicCAS(
      reinterpret_cast<Type*>(address), static_cast<Type>(compare),
      static_cast<Type>(val));
}

inline __device__ int64_t AtomicMax(int64_t* const address, const int64_t val) {
  // match the type of "::atomicMax", so ignore lint warning
  using Type = unsigned long long int;  // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicMax(reinterpret_cast<Type*>(address), static_cast<Type>(val));
}

inline __device__ int32_t AtomicMax(int32_t* const address, const int32_t val) {
  // match the type of "::atomicMax", so ignore lint warning
  using Type = int;  // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicMax(reinterpret_cast<Type*>(address), static_cast<Type>(val));
}
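
// Illustrative sketch (hypothetical, not part of this header): AtomicCAS can
// be used to claim a slot exactly once, e.g. marking a node as visited where
// -1 means "unclaimed". Variable names below are made up for illustration.
//
//   int64_t prev = AtomicCAS(&owner[node], int64_t(-1), int64_t(my_rank));
//   bool claimed = (prev == -1);  // true only for the single winning thread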

template <>
__device__ __forceinline__ float AtomicAdd<float>(float* addr, float val) {
#if __HIP_DEVICE_COMPILE__
  return atomicAdd(addr, val);
#else
  typedef float T;
  typedef typename Cast<T>::Type CT;
  CT* addr_as_ui = reinterpret_cast<CT*>(addr);
  CT old = *addr_as_ui;
  CT assumed = old;
  do {
    assumed = old;
    old = atomicCAS(
        addr_as_ui, assumed, Cast<T>::Encode(Cast<T>::Decode(old) + val));
  } while (assumed != old);
  return Cast<T>::Decode(old);
#endif  // __HIP_DEVICE_COMPILE__
}

template <>
__device__ __forceinline__ double AtomicAdd<double>(double* addr, double val) {
#if __HIP_DEVICE_COMPILE__
  return atomicAdd(addr, val);
#else
  typedef double T;
  typedef typename Cast<T>::Type CT;
  CT* addr_as_ui = reinterpret_cast<CT*>(addr);
  CT old = *addr_as_ui;
  CT assumed = old;
  do {
    assumed = old;
    old = atomicCAS(
        addr_as_ui, assumed, Cast<T>::Encode(Cast<T>::Decode(old) + val));
  } while (assumed != old);
  return Cast<T>::Decode(old);
#endif  // __HIP_DEVICE_COMPILE__
}

#if defined(DTKRT_VERSION) && DTKRT_VERSION >= 10000
template <>
__device__ __forceinline__ half AtomicAdd<half>(half* addr, half val) {
// make sure we have half support
#if __HIP_DEVICE_COMPILE__
  return atomicAdd(addr, val);
#else
  (void)addr;
  (void)val;
  printf(
      "Atomic operations are not supported for half precision (FP16) "
      "on this GPU.\n");
  // __trap();
  return val;
#endif  // __HIP_DEVICE_COMPILE__
}
#endif  // defined(DTKRT_VERSION) && DTKRT_VERSION >= 10000

#if BF16_ENABLED
template <>
__device__ __forceinline__ __hip_bfloat16
AtomicAdd<__hip_bfloat16>(__hip_bfloat16* addr, __hip_bfloat16 val) {
// make sure we have bfloat16 support
#if defined(__HIP_DEVICE_COMPILE__)
  return atomicAdd(addr, val);
#else
  (void)addr;
  (void)val;
  printf(
      "Atomic operations are not supported for bfloat16 (BF16) "
      "on GPUs with compute capability less than 8.0.\n");
  // __trap();
  return val;
#endif  // defined(__HIP_DEVICE_COMPILE__)
}
#endif  // BF16_ENABLED
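
// Illustrative sketch (hypothetical, names made up): accumulating into a
// half-precision buffer through AtomicAdd<half>, which forwards to the
// hardware atomicAdd when compiling device code.
//
//   __global__ void AccumulateHalf(const float* x, half* out, int n) {
//     int i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n) AtomicAdd(&out[i % 32], __float2half(x[i]));
//   }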

}  // namespace cuda
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CUDA_ATOMIC_CUH_