/*!
 *  Copyright (c) 2019 by Contributors
 * \file array/cuda/atomic.cuh
 * \brief Atomic functions
 */
#ifndef DGL_ARRAY_CUDA_ATOMIC_CUH_
#define DGL_ARRAY_CUDA_ATOMIC_CUH_

#include <cuda_runtime.h>
#include <cassert>
#include <cstdint>
#include "fp16.cuh"

#if __CUDA_ARCH__ >= 600
#include <cuda_fp16.h>
#endif

namespace dgl {
namespace aten {
namespace cuda {

// Type trait mapping a byte width to an unsigned integer "code" type of
// the same size, used as the storage type handed to atomicCAS.
template <int Bytes> struct Code { };

template <> struct Code<2> {
  typedef unsigned short int Type;  // NOLINT
};

template <> struct Code<4> {
  typedef unsigned int Type;  // NOLINT
};

template <> struct Code<8> {
  typedef unsigned long long int Type;  // NOLINT
};

// Helper class for converting to/from atomicCAS compatible types.
template <typename T> struct Cast {
  typedef typename Code<sizeof(T)>::Type Type;
  static __device__ __forceinline__ Type Encode(T val) {
    return static_cast<Type>(val);
  }
  static __device__ __forceinline__ T Decode(Type code) {
    return static_cast<T>(code);
  }
};

#ifdef USE_FP16
template <> struct Cast<half> {
  typedef Code<sizeof(half)>::Type Type;
  static __device__ __forceinline__ Type Encode(half val) {
    return __half_as_ushort(val);
  }
  static __device__ __forceinline__ half Decode(Type code) {
    return __ushort_as_half(code);
  }
};
#endif

template <> struct Cast<float> {
  typedef Code<sizeof(float)>::Type Type;
  static __device__ __forceinline__ Type Encode(float val) {
    return __float_as_uint(val);
  }
  static __device__ __forceinline__ float Decode(Type code) {
    return __uint_as_float(code);
  }
};

template <> struct Cast<double> {
  typedef Code<sizeof(double)>::Type Type;
  static __device__ __forceinline__ Type Encode(double val) {
    return __double_as_longlong(val);
  }
  static __device__ __forceinline__ double Decode(Type code) {
    return __longlong_as_double(code);
  }
};
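
// The floating-point specializations above reinterpret bits rather than
// convert values: Cast<float>::Decode(Cast<float>::Encode(x)) reproduces
// the exact bit pattern of x. This is what allows the CAS loops below to
// hand float/double/half payloads to the integer-typed ::atomicCAS.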

static __device__ __forceinline__ unsigned short int atomicCASshort(  // NOLINT
    unsigned short int *address,                                      // NOLINT
    unsigned short int compare,                                       // NOLINT
    unsigned short int val) {                                         // NOLINT
  static_assert(CUDART_VERSION >= 10000, "Requires at least CUDA 10");
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__) >= 700)
  return atomicCAS(address, compare, val);
#else
  (void)address;
  (void)compare;
  (void)val;
  printf("Atomic operations are not supported for half precision (FP16) "
      "on this GPU.\n");
  __trap();
  return val;
#endif  // (defined(__CUDA_ARCH__) && (__CUDA_ARCH__) >= 700)
}

#define DEFINE_ATOMIC(NAME) \
  template <typename T>                                          \
  __device__ __forceinline__ T Atomic##NAME(T* addr, T val) {    \
    typedef typename Cast<T>::Type CT;                           \
    CT* addr_as_ui = reinterpret_cast<CT*>(addr);                \
    CT old = *addr_as_ui;                                        \
    CT assumed = old;                                            \
    do {                                                         \
      assumed = old;                                             \
      old = atomicCAS(addr_as_ui, assumed,                       \
          Cast<T>::Encode(OP(val, Cast<T>::Decode(old))));       \
    } while (assumed != old);                                    \
    return Cast<T>::Decode(old);                                 \
  }
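
// For reference, a sketch of what DEFINE_ATOMIC(Max) with OP(a, b) = max(a, b)
// produces for T = float (the real code is the macro expansion above):
//
//   __device__ __forceinline__ float AtomicMax(float* addr, float val) {
//     unsigned int* addr_as_ui = reinterpret_cast<unsigned int*>(addr);
//     unsigned int old = *addr_as_ui;
//     unsigned int assumed;
//     do {
//       assumed = old;
//       old = atomicCAS(addr_as_ui, assumed,
//           __float_as_uint(max(val, __uint_as_float(old))));
//     } while (assumed != old);  // another thread won the race; retry
//     return __uint_as_float(old);
//   }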

#define DEFINE_ATOMIC_HALF(NAME) \
  template <>                                                    \
  __device__ __forceinline__ half Atomic##NAME<half>(half* addr, half val) {  \
    typedef uint16_t CT;                                         \
    CT* addr_as_ui = reinterpret_cast<CT*>(addr);                \
    CT old = *addr_as_ui;                                        \
    CT assumed = old;                                            \
    do {                                                         \
      assumed = old;                                             \
      old = atomicCASshort(addr_as_ui, assumed,                  \
          Cast<half>::Encode(OP(val, Cast<half>::Decode(old)))); \
    } while (assumed != old);                                    \
    return Cast<half>::Decode(old);                              \
  }

#define OP(a, b) max(a, b)
DEFINE_ATOMIC(Max)
#ifdef USE_FP16
DEFINE_ATOMIC_HALF(Max)
#endif  // USE_FP16
#undef OP

#define OP(a, b) min(a, b)
DEFINE_ATOMIC(Min)
#ifdef USE_FP16
DEFINE_ATOMIC_HALF(Min)
#endif  // USE_FP16
#undef OP

#define OP(a, b) a + b
DEFINE_ATOMIC(Add)
#undef OP
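
// Illustrative use from a kernel (not part of this header; the kernel and
// argument names below are hypothetical):
//
//   __global__ void ScatterAddKernel(const float* src, const int64_t* idx,
//                                    float* out, int64_t len) {
//     const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < len)
//       AtomicAdd(out + idx[i], src[i]);  // safe under concurrent writers
//   }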


/**
* \brief Performs an atomic compare-and-swap on 64 bit integers. That is,
* it reads the word `old` at the memory location `address`, computes
* `(old == compare ? val : old)`, and stores the result back to memory at
* the same address.
*
* \param address The address to perform the atomic operation on.
* \param compare The value to compare to.
* \param val The new value to conditionally store.
*
* \return The old value at the address.
*/
inline __device__ int64_t AtomicCAS(
    int64_t * const address,
    const int64_t compare,
    const int64_t val) {
  // match the type of "::atomicCAS", so ignore lint warning
  using Type = unsigned long long int; // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicCAS(reinterpret_cast<Type*>(address),
                   static_cast<Type>(compare),
                   static_cast<Type>(val));
}

/**
* \brief Performs an atomic compare-and-swap on 32 bit integers. That is,
* it reads the word `old` at the memory location `address`, computes
* `(old == compare ? val : old)`, and stores the result back to memory at
* the same address.
*
* \param address The address to perform the atomic operation on.
* \param compare The value to compare to.
* \param val The new value to conditionally store.
*
* \return The old value at the address.
*/
inline __device__ int32_t AtomicCAS(
    int32_t * const address,
    const int32_t compare,
    const int32_t val) {
  // match the type of "::atomicCAS", so ignore lint warning
  using Type = int; // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicCAS(reinterpret_cast<Type*>(address),
                   static_cast<Type>(compare),
                   static_cast<Type>(val));
}
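
// AtomicCAS returns the previous value, so a caller can tell whether its swap
// took effect. A hypothetical sketch of claiming an empty (-1) hash-table slot:
//
//   __device__ bool TryClaimSlot(int64_t* slot, int64_t key) {
//     constexpr int64_t kEmptySlot = -1;
//     const int64_t prev = AtomicCAS(slot, kEmptySlot, key);
//     return prev == kEmptySlot || prev == key;  // newly claimed, or already ours
//   }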

inline __device__ int64_t AtomicMax(
    int64_t * const address,
    const int64_t val) {
  // match the type of "::atomicMax", so ignore lint warning
  using Type = unsigned long long int; // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicMax(reinterpret_cast<Type*>(address),
                   static_cast<Type>(val));
}

inline __device__ int32_t AtomicMax(
    int32_t * const address,
    const int32_t val) {
  // match the type of "::atomicMax", so ignore lint warning
  using Type = int; // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicMax(reinterpret_cast<Type*>(address),
                   static_cast<Type>(val));
}


template <>
__device__ __forceinline__ float AtomicAdd<float>(float* addr, float val) {
#if __CUDA_ARCH__ >= 200
  return atomicAdd(addr, val);
#else
  typedef float T;
  typedef typename Cast<T>::Type CT;
  CT* addr_as_ui = reinterpret_cast<CT*>(addr);
  CT old = *addr_as_ui;
  CT assumed = old;
  do {
    assumed = old;
    old = atomicCAS(addr_as_ui, assumed,
        Cast<T>::Encode(Cast<T>::Decode(old) + val));
  } while (assumed != old);
  return Cast<T>::Decode(old);
#endif  // __CUDA_ARCH__
}

template <>
__device__ __forceinline__ double AtomicAdd<double>(double* addr, double val) {
#if __CUDA_ARCH__ >= 600
  return atomicAdd(addr, val);
#else
  typedef double T;
  typedef typename Cast<T>::Type CT;
  CT* addr_as_ui = reinterpret_cast<CT*>(addr);
  CT old = *addr_as_ui;
  CT assumed = old;
  do {
    assumed = old;
    old = atomicCAS(addr_as_ui, assumed,
        Cast<T>::Encode(Cast<T>::Decode(old) + val));
  } while (assumed != old);
  return Cast<T>::Decode(old);
#endif  // __CUDA_ARCH__ >= 600
}

#ifdef USE_FP16
#if defined(CUDART_VERSION) && CUDART_VERSION >= 10000
template <>
__device__ __forceinline__ half AtomicAdd<half>(half* addr, half val) {
  // make sure the target architecture has native half-precision atomicAdd
#if __CUDA_ARCH__ >= 700
  return atomicAdd(addr, val);
#else
  (void)addr;
  (void)val;
  printf("Atomic operations are not supported for half precision (FP16) "
      "on this GPU.\n");
  __trap();
  return val;
#endif  // __CUDA_ARCH__ >= 700
}
#endif  // defined(CUDART_VERSION) && CUDART_VERSION >= 10000
#endif  // USE_FP16


}  // namespace cuda
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CUDA_ATOMIC_CUH_