/**
 *  Copyright (c) 2019 by Contributors
 * @file array/cuda/atomic.cuh
 * @brief Atomic functions
 */
#ifndef DGL_ARRAY_CUDA_ATOMIC_CUH_
#define DGL_ARRAY_CUDA_ATOMIC_CUH_

#include <cuda_runtime.h>

#include <cassert>

#include "bf16.cuh"
#include "fp16.cuh"

#if __CUDA_ARCH__ >= 600
#include <cuda_fp16.h>
#endif

namespace dgl {
namespace aten {
namespace cuda {

// Type trait selecting the unsigned integer "code" type of a given byte
// width, suitable for use with atomicCAS.
template <int Bytes>
struct Code {};

template <>
struct Code<2> {
  typedef unsigned short int Type;  // NOLINT
};

template <>
struct Code<4> {
  typedef unsigned int Type;  // NOLINT
};

template <>
struct Code<8> {
  typedef unsigned long long int Type;  // NOLINT
};

// Helper class for converting to/from atomicCAS compatible types.
template <typename T>
struct Cast {
  typedef typename Code<sizeof(T)>::Type Type;
  static __device__ __forceinline__ Type Encode(T val) {
    return static_cast<Type>(val);
  }
  static __device__ __forceinline__ T Decode(Type code) {
    return static_cast<T>(code);
  }
};
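// The generic version static_casts integral values; the specializations below
// bit-cast floating-point values (e.g. Cast<float> uses __float_as_uint /
// __uint_as_float), so Decode(Encode(x)) round-trips the value exactly.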

template <>
struct Cast<half> {
  typedef Code<sizeof(half)>::Type Type;
  static __device__ __forceinline__ Type Encode(half val) {
    return __half_as_ushort(val);
  }
  static __device__ __forceinline__ half Decode(Type code) {
    return __ushort_as_half(code);
  }
};

#if BF16_ENABLED
template <>
struct Cast<__nv_bfloat16> {
  typedef Code<sizeof(__nv_bfloat16)>::Type Type;
  static __device__ __forceinline__ Type Encode(__nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    return __bfloat16_as_ushort(val);
#else
    printf(
        "Atomic operations are not supported for bfloat16 (BF16) "
        "on GPUs with compute capability less than 8.0.\n");
    __trap();
    return static_cast<Type>(0);
#endif
  }
  static __device__ __forceinline__ __nv_bfloat16 Decode(Type code) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    return __ushort_as_bfloat16(code);
#else
    printf(
        "Atomic operations are not supported for bfloat16 (BF16) "
        "on GPUs with compute capability less than 8.0.\n");
    __trap();
    return static_cast<__nv_bfloat16>(0.0f);
#endif
  }
};
#endif  // BF16_ENABLED

template <>
struct Cast<float> {
  typedef Code<sizeof(float)>::Type Type;
  static __device__ __forceinline__ Type Encode(float val) {
    return __float_as_uint(val);
  }
  static __device__ __forceinline__ float Decode(Type code) {
    return __uint_as_float(code);
  }
};

template <>
struct Cast<double> {
  typedef Code<sizeof(double)>::Type Type;
  static __device__ __forceinline__ Type Encode(double val) {
    return __double_as_longlong(val);
  }
  static __device__ __forceinline__ double Decode(Type code) {
    return __longlong_as_double(code);
  }
};

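// Compare-and-swap on a 16-bit word. The native unsigned short atomicCAS is
// only available on compute capability 7.0+ (and CUDA 10+); on older
// architectures the fallback path reports the error and traps.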
static __device__ __forceinline__ unsigned short int atomicCASshort(  // NOLINT
    unsigned short int* address,                                      // NOLINT
    unsigned short int compare,                                       // NOLINT
    unsigned short int val) {                                         // NOLINT
  static_assert(CUDART_VERSION >= 10000, "Requires at least CUDA 10");
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__) >= 700)
  return atomicCAS(address, compare, val);
#else
  (void)address;
  (void)compare;
  (void)val;
  printf(
      "Atomic operations are not supported for half precision (FP16) "
      "on this GPU.\n");
  __trap();
  return val;
#endif  // (defined(__CUDA_ARCH__) && (__CUDA_ARCH__) >= 700)
}

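// Emulates Atomic##NAME for arbitrary element types with a compare-and-swap
// loop: read the current word, compute OP(val, old), and try to publish the
// result with atomicCAS, retrying until no other thread has modified the word
// in between. OP(a, b) must be #define'd before use and #undef'd afterwards.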
#define DEFINE_ATOMIC(NAME)                                   \
  template <typename T>                                       \
  __device__ __forceinline__ T Atomic##NAME(T* addr, T val) { \
    typedef typename Cast<T>::Type CT;                        \
    CT* addr_as_ui = reinterpret_cast<CT*>(addr);             \
    CT old = *addr_as_ui;                                     \
    CT assumed = old;                                         \
    do {                                                      \
      assumed = old;                                          \
      old = atomicCAS(                                        \
          addr_as_ui, assumed,                                \
          Cast<T>::Encode(OP(val, Cast<T>::Decode(old))));    \
    } while (assumed != old);                                 \
    return Cast<T>::Decode(old);                              \
  }

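// Same CAS loop as DEFINE_ATOMIC, specialized for 16-bit element types (half,
// and bfloat16 when BF16_ENABLED), which must go through atomicCASshort
// instead of the 32/64-bit atomicCAS overloads.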
#define DEFINE_ATOMIC_16BIT(NAME, dtype)                           \
  template <>                                                      \
  __device__ __forceinline__ dtype Atomic##NAME<dtype>(            \
      dtype * addr, dtype val) {                                   \
    typedef uint16_t CT;                                           \
    CT* addr_as_ui = reinterpret_cast<CT*>(addr);                  \
    CT old = *addr_as_ui;                                          \
    CT assumed = old;                                              \
    do {                                                           \
      assumed = old;                                               \
      old = atomicCASshort(                                        \
          addr_as_ui, assumed,                                     \
          Cast<dtype>::Encode(OP(val, Cast<dtype>::Decode(old)))); \
    } while (assumed != old);                                      \
    return Cast<dtype>::Decode(old);                               \
  }

#define OP(a, b) max(a, b)
DEFINE_ATOMIC(Max)
DEFINE_ATOMIC_16BIT(Max, half)
#if BF16_ENABLED
DEFINE_ATOMIC_16BIT(Max, __nv_bfloat16)
#endif  // BF16_ENABLED
#undef OP

#define OP(a, b) min(a, b)
DEFINE_ATOMIC(Min)
DEFINE_ATOMIC_16BIT(Min, half)
#if BF16_ENABLED
DEFINE_ATOMIC_16BIT(Min, __nv_bfloat16)
#endif  // BF16_ENABLED
#undef OP

#define OP(a, b) a + b
DEFINE_ATOMIC(Add)
#undef OP
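// Illustrative usage inside a kernel (names below are hypothetical):
//   AtomicAdd(&out[row], grad[i]);            // accumulate, e.g. scatter-add
//   float prev = AtomicMax(&maxbuf[col], v);  // each call returns the old value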

/**
 * @brief Performs an atomic compare-and-swap on 64-bit integers. That is, it
 * reads the word `old` at the memory location `address`, computes
 * `(old == compare ? val : old)`, and stores the result back to memory at
 * the same address.
 *
 * @param address The address to perform the atomic operation on.
 * @param compare The value to compare to.
 * @param val The new value to conditionally store.
 *
 * @return The old value at the address.
 */
inline __device__ int64_t
AtomicCAS(int64_t* const address, const int64_t compare, const int64_t val) {
  // match the type of "::atomicCAS", so ignore lint warning
  using Type = unsigned long long int;  // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicCAS(
      reinterpret_cast<Type*>(address), static_cast<Type>(compare),
      static_cast<Type>(val));
}

/**
 * @brief Performs an atomic compare-and-swap on 32-bit integers. That is, it
 * reads the word `old` at the memory location `address`, computes
 * `(old == compare ? val : old)`, and stores the result back to memory at
 * the same address.
 *
 * @param address The address to perform the atomic operation on.
 * @param compare The value to compare to.
 * @param val The new value to conditionally store.
 *
 * @return The old value at the address.
 */
inline __device__ int32_t
AtomicCAS(int32_t* const address, const int32_t compare, const int32_t val) {
  // match the type of "::atomicCAS", so ignore lint warning
  using Type = int;  // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicCAS(
      reinterpret_cast<Type*>(address), static_cast<Type>(compare),
      static_cast<Type>(val));
}

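// Integer AtomicMax overloads forward directly to the hardware ::atomicMax,
// reinterpreting the operands as the builtin type of matching width.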
inline __device__ int64_t AtomicMax(int64_t* const address, const int64_t val) {
  // match the type of "::atomicMax", so ignore lint warning
  using Type = unsigned long long int;  // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicMax(reinterpret_cast<Type*>(address), static_cast<Type>(val));
}

inline __device__ int32_t AtomicMax(int32_t* const address, const int32_t val) {
  // match the type of "::atomicMax", so ignore lint warning
  using Type = int;  // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicMax(reinterpret_cast<Type*>(address), static_cast<Type>(val));
}

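// AtomicAdd specializations: use the native atomicAdd where the architecture
// supports it for the given type, otherwise fall back to a CAS loop
// (float/double) or report an error and trap (half/bfloat16).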
template <>
__device__ __forceinline__ float AtomicAdd<float>(float* addr, float val) {
#if __CUDA_ARCH__ >= 200
  return atomicAdd(addr, val);
#else
  typedef float T;
  typedef typename Cast<T>::Type CT;
  CT* addr_as_ui = reinterpret_cast<CT*>(addr);
  CT old = *addr_as_ui;
  CT assumed = old;
  do {
    assumed = old;
    old = atomicCAS(
        addr_as_ui, assumed, Cast<T>::Encode(Cast<T>::Decode(old) + val));
  } while (assumed != old);
  return Cast<T>::Decode(old);
#endif  // __CUDA_ARCH__
}

template <>
__device__ __forceinline__ double AtomicAdd<double>(double* addr, double val) {
#if __CUDA_ARCH__ >= 600
  return atomicAdd(addr, val);
#else
  typedef double T;
  typedef typename Cast<T>::Type CT;
  CT* addr_as_ui = reinterpret_cast<CT*>(addr);
  CT old = *addr_as_ui;
  CT assumed = old;
  do {
    assumed = old;
    old = atomicCAS(
        addr_as_ui, assumed, Cast<T>::Encode(Cast<T>::Decode(old) + val));
  } while (assumed != old);
  return Cast<T>::Decode(old);
#endif
}

#if defined(CUDART_VERSION) && CUDART_VERSION >= 10000
template <>
__device__ __forceinline__ half AtomicAdd<half>(half* addr, half val) {
// make sure we have half support (requires compute capability 7.0+)
#if __CUDA_ARCH__ >= 700
  return atomicAdd(addr, val);
#else
  (void)addr;
  (void)val;
  printf(
      "Atomic operations are not supported for half precision (FP16) "
      "on this GPU.\n");
  __trap();
  return val;
#endif  // __CUDA_ARCH__ >= 700
}
#endif  // defined(CUDART_VERSION) && CUDART_VERSION >= 10000

#if BF16_ENABLED
template <>
__device__ __forceinline__ __nv_bfloat16
AtomicAdd<__nv_bfloat16>(__nv_bfloat16* addr, __nv_bfloat16 val) {
// make sure we have bfloat16 support (requires compute capability 8.0+)
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  return atomicAdd(addr, val);
#else
  (void)addr;
  (void)val;
  printf(
      "Atomic operations are not supported for bfloat16 (BF16) "
      "on GPUs with compute capability less than 8.0.\n");
  __trap();
  return val;
#endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
}
#endif  // BF16_ENABLED

}  // namespace cuda
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CUDA_ATOMIC_CUH_