/**
 *  Copyright (c) 2019 by Contributors
 * @file array/cuda/atomic.cuh
 * @brief Atomic functions
 */
#ifndef DGL_ARRAY_CUDA_ATOMIC_CUH_
#define DGL_ARRAY_CUDA_ATOMIC_CUH_

#include <cuda_runtime.h>
#include <cassert>
#include <cstdint>
#include "fp16.cuh"
#include "bf16.cuh"

#if __CUDA_ARCH__ >= 600
#include <cuda_fp16.h>
#endif

namespace dgl {
namespace aten {
namespace cuda {

// Type trait for selecting code type
template <int Bytes> struct Code { };

template <> struct Code<2> {
  typedef unsigned short int Type;  // NOLINT
};

template <> struct Code<4> {
  typedef unsigned int Type;  // NOLINT
};

template <> struct Code<8> {
  typedef unsigned long long int Type;  // NOLINT
};

// Helper class for converting to/from atomicCAS compatible types.
template <typename T> struct Cast {
  typedef typename Code<sizeof(T)>::Type Type;
  static __device__ __forceinline__ Type Encode(T val) {
    return static_cast<Type>(val);
  }
  static __device__ __forceinline__ T Decode(Type code) {
    return static_cast<T>(code);
  }
};

template <> struct Cast<half> {
  typedef Code<sizeof(half)>::Type Type;
  static __device__ __forceinline__ Type Encode(half val) {
    return __half_as_ushort(val);
  }
  static __device__ __forceinline__ half Decode(Type code) {
    return __ushort_as_half(code);
  }
};

#if BF16_ENABLED
template <> struct Cast<__nv_bfloat16> {
  typedef Code<sizeof(__nv_bfloat16)>::Type Type;
  static __device__ __forceinline__ Type Encode(__nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    return __bfloat16_as_ushort(val);
#else
    printf("Atomic operations are not supported for bfloat16 (BF16) "
        "on GPUs with compute capability less than 8.0.\n");
    __trap();
    return static_cast<Type>(0);
#endif
  }
  static __device__ __forceinline__ __nv_bfloat16 Decode(Type code) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    return __ushort_as_bfloat16(code);
#else
    printf("Atomic operations are not supported for bfloat16 (BF16) "
        "on GPUs with compute capability less than 8.0.\n");
    __trap();
    return static_cast<__nv_bfloat16>(0.0f);
#endif
  }
};
#endif  // BF16_ENABLED

template <> struct Cast<float> {
  typedef Code<sizeof(float)>::Type Type;
  static __device__ __forceinline__ Type Encode(float val) {
    return __float_as_uint(val);
  }
  static __device__ __forceinline__ float Decode(Type code) {
    return __uint_as_float(code);
  }
};

template <> struct Cast<double> {
  typedef Code<sizeof(double)>::Type Type;
  static __device__ __forceinline__ Type Encode(double val) {
    return __double_as_longlong(val);
  }
  static __device__ __forceinline__ double Decode(Type code) {
    return __longlong_as_double(code);
  }
};
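// Note that, unlike the generic Cast<T>, the floating-point specializations
// above reinterpret the raw bit pattern (__half_as_ushort, __float_as_uint,
// __double_as_longlong, ...) rather than value-converting, so Encode/Decode
// round-trip losslessly through the unsigned type consumed by atomicCAS.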

static __device__ __forceinline__ unsigned short int atomicCASshort(  // NOLINT
    unsigned short int *address,                                      // NOLINT
    unsigned short int compare,                                       // NOLINT
    unsigned short int val) {                                         // NOLINT
  static_assert(CUDART_VERSION >= 10000, "Requires at least CUDA 10");
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__) >= 700)
  return atomicCAS(address, compare, val);
#else
  (void)address;
  (void)compare;
  (void)val;
  printf("Atomic operations are not supported for half precision (FP16) "
      "on this GPU.\n");
  __trap();
  return val;
#endif  // (defined(__CUDA_ARCH__) && (__CUDA_ARCH__) >= 700)
}

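// The DEFINE_ATOMIC* macros below emulate an atomic read-modify-write for an
// arbitrary binary OP via a compare-and-swap retry loop: read the current
// word, compute OP(val, old), and try to publish the result with atomicCAS;
// if another thread modified the word in the meantime, retry against the
// newly observed value. The 16-bit variant goes through atomicCASshort and
// therefore requires compute capability 7.0 or newer.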
#define DEFINE_ATOMIC(NAME) \
  template <typename T>                                          \
  __device__ __forceinline__ T Atomic##NAME(T* addr, T val) {    \
    typedef typename Cast<T>::Type CT;                           \
    CT* addr_as_ui = reinterpret_cast<CT*>(addr);                \
    CT old = *addr_as_ui;                                        \
    CT assumed = old;                                            \
    do {                                                         \
      assumed = old;                                             \
      old = atomicCAS(addr_as_ui, assumed,                       \
          Cast<T>::Encode(OP(val, Cast<T>::Decode(old))));       \
    } while (assumed != old);                                    \
    return Cast<T>::Decode(old);                                 \
  }

#define DEFINE_ATOMIC_16BIT(NAME, dtype) \
  template <>                                                    \
  __device__ __forceinline__ dtype Atomic##NAME<dtype>(dtype* addr, dtype val) {  \
    typedef uint16_t CT;                                         \
    CT* addr_as_ui = reinterpret_cast<CT*>(addr);                \
    CT old = *addr_as_ui;                                        \
    CT assumed = old;                                            \
    do {                                                         \
      assumed = old;                                             \
      old = atomicCASshort(addr_as_ui, assumed,                  \
          Cast<dtype>::Encode(OP(val, Cast<dtype>::Decode(old)))); \
    } while (assumed != old);                                    \
    return Cast<dtype>::Decode(old);                              \
  }

#define OP(a, b) max(a, b)
DEFINE_ATOMIC(Max)
DEFINE_ATOMIC_16BIT(Max, half)
#if BF16_ENABLED
DEFINE_ATOMIC_16BIT(Max, __nv_bfloat16)
#endif  // BF16_ENABLED
#undef OP

#define OP(a, b) min(a, b)
DEFINE_ATOMIC(Min)
DEFINE_ATOMIC_16BIT(Min, half)
#if BF16_ENABLED
DEFINE_ATOMIC_16BIT(Min, __nv_bfloat16)
#endif  // BF16_ENABLED
#undef OP

#define OP(a, b) a + b
DEFINE_ATOMIC(Add)
#undef OP
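// Usage sketch (the kernel and buffer names below are illustrative, not part
// of this header): the generated AtomicMax/AtomicMin/AtomicAdd templates are
// called from device code just like the built-in CUDA atomics, e.g.
//
//   template <typename T>
//   __global__ void ScatterAddKernel(
//       const T* src, const int64_t* idx, T* out, int64_t n) {
//     const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n) {
//       // Resolves to a hardware atomic where a specialization exists,
//       // otherwise to the CAS retry loop above.
//       AtomicAdd(out + idx[i], src[i]);
//     }
//   }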


/**
* @brief Performs an atomic compare-and-swap on 64 bit integers. That is,
* it reads the word `old` at the memory location `address`, computes
* `(old == compare ? val : old)`, and stores the result back to memory at
* the same address.
*
* @param address The address to perform the atomic operation on.
* @param compare The value to compare to.
* @param val The new value to conditionally store.
*
* @return The old value at the address.
*/
inline __device__ int64_t AtomicCAS(
    int64_t * const address,
    const int64_t compare,
    const int64_t val) {
  // match the type of "::atomicCAS", so ignore lint warning
  using Type = unsigned long long int; // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicCAS(reinterpret_cast<Type*>(address),
                   static_cast<Type>(compare),
                   static_cast<Type>(val));
}

/**
* @brief Performs an atomic compare-and-swap on 32 bit integers. That is,
* it reads the word `old` at the memory location `address`, computes
* `(old == compare ? val : old)`, and stores the result back to memory at
* the same address.
*
* @param address The address to perform the atomic operation on.
* @param compare The value to compare to.
* @param val The new value to conditionally store.
*
* @return The old value at the address.
*/
inline __device__ int32_t AtomicCAS(
    int32_t * const address,
    const int32_t compare,
    const int32_t val) {
  // match the type of "::atomicCAS", so ignore lint warning
  using Type = int; // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicCAS(reinterpret_cast<Type*>(address),
                   static_cast<Type>(compare),
                   static_cast<Type>(val));
}
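// Example (hypothetical hash-table insert): a thread has claimed a slot
// exactly when the value returned by AtomicCAS equals the sentinel it
// compared against, e.g.
//
//   constexpr int64_t kEmptyKey = -1;
//   if (AtomicCAS(&keys[pos], kEmptyKey, my_key) == kEmptyKey) {
//     // this thread inserted `my_key` into slot `pos`
//   }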

inline __device__ int64_t AtomicMax(
    int64_t * const address,
    const int64_t val) {
  // match the type of "::atomicMax", so ignore lint warning
  using Type = unsigned long long int; // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicMax(reinterpret_cast<Type*>(address),
                   static_cast<Type>(val));
}

inline __device__ int32_t AtomicMax(
    int32_t * const address,
    const int32_t val) {
  // match the type of "::atomicMax", so ignore lint warning
  using Type = int; // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicMax(reinterpret_cast<Type*>(address),
                   static_cast<Type>(val));
}


template <>
__device__ __forceinline__ float AtomicAdd<float>(float* addr, float val) {
#if __CUDA_ARCH__ >= 200
  return atomicAdd(addr, val);
#else
  typedef float T;
  typedef typename Cast<T>::Type CT;
  CT* addr_as_ui = reinterpret_cast<CT*>(addr);
  CT old = *addr_as_ui;
  CT assumed = old;
  do {
    assumed = old;
    old = atomicCAS(addr_as_ui, assumed,
        Cast<T>::Encode(Cast<T>::Decode(old) + val));
  } while (assumed != old);
  return Cast<T>::Decode(old);
#endif  // __CUDA_ARCH__
}

template <>
__device__ __forceinline__ double AtomicAdd<double>(double* addr, double val) {
#if __CUDA_ARCH__ >= 600
  return atomicAdd(addr, val);
#else
  typedef double T;
  typedef typename Cast<T>::Type CT;
  CT* addr_as_ui = reinterpret_cast<CT*>(addr);
  CT old = *addr_as_ui;
  CT assumed = old;
  do {
    assumed = old;
    old = atomicCAS(addr_as_ui, assumed,
        Cast<T>::Encode(Cast<T>::Decode(old) + val));
  } while (assumed != old);
  return Cast<T>::Decode(old);
#endif
}

#if defined(CUDART_VERSION) && CUDART_VERSION >= 10000
template <>
__device__ __forceinline__ half AtomicAdd<half>(half* addr, half val) {
// make sure we have half support
#if __CUDA_ARCH__ >= 700
  return atomicAdd(addr, val);
#else
  (void)addr;
  (void)val;
  printf("Atomic operations are not supported for half precision (FP16) "
      "on this GPU.\n");
  __trap();
  return val;
#endif  // __CUDA_ARCH__ >= 700
}
#endif  // defined(CUDART_VERSION) && CUDART_VERSION >= 10000

#if BF16_ENABLED
template <>
__device__ __forceinline__ __nv_bfloat16 AtomicAdd<__nv_bfloat16>(
    __nv_bfloat16* addr, __nv_bfloat16 val) {
// make sure we have bfloat16 support
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  return atomicAdd(addr, val);
#else
  (void)addr;
  (void)val;
  printf("Atomic operations are not supported for bfloat16 (BF16) "
      "on GPUs with compute capability less than 8.0.\n");
  __trap();
  return val;
#endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
}
#endif  // BF16_ENABLED


}  // namespace cuda
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CUDA_ATOMIC_CUH_