"vscode:/vscode.git/clone" did not exist on "9088c6359299978390430821c23a2cfd0cb8ffeb"
/**
 *  Copyright (c) 2019 by Contributors
 * @file array/cuda/atomic.cuh
 * @brief Atomic functions
 */
#ifndef DGL_ARRAY_CUDA_ATOMIC_CUH_
#define DGL_ARRAY_CUDA_ATOMIC_CUH_

#include <cuda_runtime.h>

#include <cassert>
#include <cstdint>
#include <cstdio>

#include "bf16.cuh"
#include "fp16.cuh"

#if __CUDA_ARCH__ >= 600
#include <cuda_fp16.h>
#endif

namespace dgl {
namespace aten {
namespace cuda {

// Type trait for selecting code type
template <int Bytes>
struct Code {};

template <>
struct Code<2> {
  typedef unsigned short int Type;  // NOLINT
};

template <>
struct Code<4> {
  typedef unsigned int Type;  // NOLINT
};

template <>
struct Code<8> {
  typedef unsigned long long int Type;  // NOLINT
};

// Helper class for converting to/from atomicCAS compatible types.
template <typename T>
struct Cast {
  typedef typename Code<sizeof(T)>::Type Type;
  static __device__ __forceinline__ Type Encode(T val) {
    return static_cast<Type>(val);
  }
  static __device__ __forceinline__ T Decode(Type code) {
    return static_cast<T>(code);
  }
};

template <>
struct Cast<half> {
  typedef Code<sizeof(half)>::Type Type;
  static __device__ __forceinline__ Type Encode(half val) {
    return __half_as_ushort(val);
  }
  static __device__ __forceinline__ half Decode(Type code) {
    return __ushort_as_half(code);
  }
};

#if BF16_ENABLED
template <>
struct Cast<__nv_bfloat16> {
  typedef Code<sizeof(__nv_bfloat16)>::Type Type;
  static __device__ __forceinline__ Type Encode(__nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    return __bfloat16_as_ushort(val);
#else
    printf(
        "Atomic operations are not supported for bfloat16 (BF16) "
        "on GPUs with compute capability less than 8.0.\n");
    __trap();
    return static_cast<Type>(0);
#endif
  }
  static __device__ __forceinline__ __nv_bfloat16 Decode(Type code) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    return __ushort_as_bfloat16(code);
#else
    printf(
        "Atomic operations are not supported for bfloat16 (BF16) "
        "on GPUs with compute capability less than 8.0.\n");
    __trap();
    return static_cast<__nv_bfloat16>(0.0f);
#endif
  }
};
#endif  // BF16_ENABLED

template <>
struct Cast<float> {
  typedef Code<sizeof(float)>::Type Type;
  static __device__ __forceinline__ Type Encode(float val) {
    return __float_as_uint(val);
  }
  static __device__ __forceinline__ float Decode(Type code) {
    return __uint_as_float(code);
  }
};

template <>
struct Cast<double> {
  typedef Code<sizeof(double)>::Type Type;
  static __device__ __forceinline__ Type Encode(double val) {
    return __double_as_longlong(val);
  }
  static __device__ __forceinline__ double Decode(Type code) {
    return __longlong_as_double(code);
  }
};
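
// Illustrative sketch (not part of the original file): the Cast trait lets a
// CAS loop emulate an atomic update that the hardware does not provide
// directly, by round-tripping the value through an unsigned integer of the
// same width. `AtomicMul` below is a hypothetical example for 32/64-bit
// types; the real update loops used by this header are generated by the
// DEFINE_ATOMIC macros further down.
//
//   template <typename T>
//   __device__ __forceinline__ T AtomicMul(T* addr, T val) {
//     typedef typename Cast<T>::Type CT;
//     CT* addr_as_ui = reinterpret_cast<CT*>(addr);
//     CT old = *addr_as_ui;
//     CT assumed;
//     do {
//       assumed = old;
//       // Encode(Decode(old) * val) is the desired new bit pattern; atomicCAS
//       // installs it only if the word still equals `assumed`, else we retry.
//       old = atomicCAS(
//           addr_as_ui, assumed,
//           Cast<T>::Encode(Cast<T>::Decode(old) * val));
//     } while (assumed != old);
//     return Cast<T>::Decode(old);
//   }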

static __device__ __forceinline__ unsigned short int atomicCASshort(  // NOLINT
    unsigned short int* address,                                      // NOLINT
    unsigned short int compare,                                       // NOLINT
    unsigned short int val) {                                         // NOLINT
  static_assert(CUDART_VERSION >= 10000, "Requires at least CUDA 10");
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__) >= 700)
  return atomicCAS(address, compare, val);
#else
  (void)address;
  (void)compare;
  (void)val;
  printf(
      "Atomic operations are not supported for half precision (FP16) "
      "on this GPU.\n");
  __trap();
  return val;
#endif  // (defined(__CUDA_ARCH__) && (__CUDA_ARCH__) >= 700)
}

#define DEFINE_ATOMIC(NAME)                                   \
  template <typename T>                                       \
  __device__ __forceinline__ T Atomic##NAME(T* addr, T val) { \
    typedef typename Cast<T>::Type CT;                        \
    CT* addr_as_ui = reinterpret_cast<CT*>(addr);             \
    CT old = *addr_as_ui;                                     \
    CT assumed = old;                                         \
    do {                                                      \
      assumed = old;                                          \
      old = atomicCAS(                                        \
          addr_as_ui, assumed,                                \
          Cast<T>::Encode(OP(val, Cast<T>::Decode(old))));    \
    } while (assumed != old);                                 \
    return Cast<T>::Decode(old);                              \
  }

#define DEFINE_ATOMIC_16BIT(NAME, dtype)                           \
  template <>                                                      \
  __device__ __forceinline__ dtype Atomic##NAME<dtype>(            \
      dtype * addr, dtype val) {                                   \
    typedef uint16_t CT;                                           \
    CT* addr_as_ui = reinterpret_cast<CT*>(addr);                  \
    CT old = *addr_as_ui;                                          \
    CT assumed = old;                                              \
    do {                                                           \
      assumed = old;                                               \
      old = atomicCASshort(                                        \
          addr_as_ui, assumed,                                     \
          Cast<dtype>::Encode(OP(val, Cast<dtype>::Decode(old)))); \
    } while (assumed != old);                                      \
    return Cast<dtype>::Decode(old);                               \
  }

#define OP(a, b) max(a, b)
DEFINE_ATOMIC(Max)
DEFINE_ATOMIC_16BIT(Max, half)
#if BF16_ENABLED
DEFINE_ATOMIC_16BIT(Max, __nv_bfloat16)
#endif  // BF16_ENABLED
#undef OP

#define OP(a, b) min(a, b)
DEFINE_ATOMIC(Min)
DEFINE_ATOMIC_16BIT(Min, half)
#if BF16_ENABLED
DEFINE_ATOMIC_16BIT(Min, __nv_bfloat16)
#endif  // BF16_ENABLED
#undef OP

#define OP(a, b) a + b
DEFINE_ATOMIC(Add)
#undef OP
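
// Usage sketch (illustrative, not part of the original file). The macros
// above instantiate AtomicMax, AtomicMin and AtomicAdd for any type with a
// Cast specialization; a typical caller is a scatter-style kernel where
// several threads may update the same output row. `ScatterAddKernel`, `src`,
// `idx` and `out` are hypothetical names used only for this example.
//
//   template <typename DType, typename IdType>
//   __global__ void ScatterAddKernel(
//       const DType* src, const IdType* idx, DType* out, int64_t n) {
//     const int64_t tx =
//         blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
//     if (tx < n) {
//       // Several threads may target the same out[idx[tx]], so a plain +=
//       // would race; AtomicAdd serializes the conflicting updates.
//       AtomicAdd(out + idx[tx], src[tx]);
//     }
//   }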

/**
 * @brief Performs an atomic compare-and-swap on 64 bit integers. That is,
 * it reads the word `old` at the memory location `address`, computes
 * `(old == compare ? val : old)`, and stores the result back to memory at
 * the same address.
 *
 * @param address The address to perform the atomic operation on.
 * @param compare The value to compare to.
 * @param val The new value to conditionally store.
 *
 * @return The old value at the address.
 */
inline __device__ int64_t
AtomicCAS(int64_t* const address, const int64_t compare, const int64_t val) {
  // match the type of "::atomicCAS", so ignore lint warning
  using Type = unsigned long long int;  // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicCAS(
      reinterpret_cast<Type*>(address), static_cast<Type>(compare),
      static_cast<Type>(val));
}

/**
 * @brief Performs an atomic compare-and-swap on 32 bit integers. That is,
 * it reads the word `old` at the memory location `address`, computes
 * `(old == compare ? val : old)`, and stores the result back to memory at
 * the same address.
 *
 * @param address The address to perform the atomic operation on.
 * @param compare The value to compare to.
 * @param val The new value to conditionally store.
 *
 * @return The old value at the address.
 */
inline __device__ int32_t
AtomicCAS(int32_t* const address, const int32_t compare, const int32_t val) {
  // match the type of "::atomicCAS", so ignore lint warning
  using Type = int;  // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicCAS(
      reinterpret_cast<Type*>(address), static_cast<Type>(compare),
      static_cast<Type>(val));
}

inline __device__ int64_t AtomicMax(int64_t* const address, const int64_t val) {
  // match the type of "::atomicMax", so ignore lint warning
  using Type = unsigned long long int;  // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicMax(reinterpret_cast<Type*>(address), static_cast<Type>(val));
}

inline __device__ int32_t AtomicMax(int32_t* const address, const int32_t val) {
  // match the type of "::atomicMax", so ignore lint warning
  using Type = int;  // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicMax(reinterpret_cast<Type*>(address), static_cast<Type>(val));
}
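
// Usage sketch (illustrative, not part of the original file). The integer
// AtomicCAS overloads are handy for "first writer wins" updates, e.g. letting
// the first edge that reaches a node record itself as the owner. `kEmpty`,
// `ClaimOwnerKernel`, `node_ids` and `owner` are hypothetical names used only
// for this example; `owner` is assumed to be initialized to kEmpty.
//
//   constexpr int64_t kEmpty = -1;
//
//   __global__ void ClaimOwnerKernel(
//       const int64_t* node_ids, int64_t* owner, int64_t num_edges) {
//     const int64_t eid =
//         blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
//     if (eid < num_edges) {
//       // Only the CAS that still observes kEmpty succeeds; its return value
//       // equals kEmpty, telling this thread it won the slot.
//       const int64_t prev = AtomicCAS(owner + node_ids[eid], kEmpty, eid);
//       if (prev == kEmpty) {
//         // This edge owns the node; losers see the winner's id in `prev`.
//       }
//     }
//   }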

template <>
__device__ __forceinline__ float AtomicAdd<float>(float* addr, float val) {
#if __CUDA_ARCH__ >= 200
  return atomicAdd(addr, val);
#else
  typedef float T;
  typedef typename Cast<T>::Type CT;
  CT* addr_as_ui = reinterpret_cast<CT*>(addr);
  CT old = *addr_as_ui;
  CT assumed = old;
  do {
    assumed = old;
    old = atomicCAS(
        addr_as_ui, assumed, Cast<T>::Encode(Cast<T>::Decode(old) + val));
  } while (assumed != old);
  return Cast<T>::Decode(old);
#endif  // __CUDA_ARCH__
}

template <>
__device__ __forceinline__ double AtomicAdd<double>(double* addr, double val) {
#if __CUDA_ARCH__ >= 600
  return atomicAdd(addr, val);
#else
  typedef double T;
  typedef typename Cast<T>::Type CT;
  CT* addr_as_ui = reinterpret_cast<CT*>(addr);
  CT old = *addr_as_ui;
  CT assumed = old;
  do {
    assumed = old;
    old = atomicCAS(
        addr_as_ui, assumed, Cast<T>::Encode(Cast<T>::Decode(old) + val));
  } while (assumed != old);
  return Cast<T>::Decode(old);
#endif
}

#if defined(CUDART_VERSION) && CUDART_VERSION >= 10000
template <>
__device__ __forceinline__ half AtomicAdd<half>(half* addr, half val) {
// make sure we have half support
#if __CUDA_ARCH__ >= 700
  return atomicAdd(addr, val);
#else
  (void)addr;
  (void)val;
  printf(
      "Atomic operations are not supported for half precision (FP16) "
      "on this GPU.\n");
  __trap();
  return val;
#endif  // __CUDA_ARCH__ >= 700
}
#endif  // defined(CUDART_VERSION) && CUDART_VERSION >= 10000

#if BF16_ENABLED
template <>
__device__ __forceinline__ __nv_bfloat16
AtomicAdd<__nv_bfloat16>(__nv_bfloat16* addr, __nv_bfloat16 val) {
// make sure we have bfloat16 support
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  return atomicAdd(addr, val);
#else
  (void)addr;
  (void)val;
  printf(
      "Atomic operations are not supported for bfloat16 (BF16) "
      "on GPUs with compute capability less than 8.0.\n");
  __trap();
  return val;
#endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
}
#endif  // BF16_ENABLED
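
// Usage sketch (illustrative, not part of the original file). Because the
// floating point AtomicAdd specializations above either map to the native
// atomicAdd or fall back to a CAS loop (or trap on unsupported hardware), a
// global reduction can be written uniformly over float/double/half/bfloat16.
// `SumKernel`, `data` and `result` are hypothetical names for this example;
// `result` is assumed to be zero-initialized.
//
//   template <typename DType>
//   __global__ void SumKernel(const DType* data, DType* result, int64_t n) {
//     // Grid-stride loop: each thread folds several elements into the single
//     // accumulator via AtomicAdd.
//     const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
//     for (int64_t i =
//              blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
//          i < n; i += stride) {
//       AtomicAdd(result, data[i]);
//     }
//   }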

}  // namespace cuda
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CUDA_ATOMIC_CUH_