/*!
 *  Copyright (c) 2019 by Contributors
 * \file array/cuda/atomic.cuh
 * \brief Atomic functions
 */
#ifndef DGL_ARRAY_CUDA_ATOMIC_H_
#define DGL_ARRAY_CUDA_ATOMIC_H_

#include <cuda_runtime.h>
#include <cassert>
#include "fp16.cuh"

#if __CUDA_ARCH__ >= 600
#include <cuda_fp16.h>
#endif

namespace dgl {
namespace aten {
namespace cuda {

// Type trait mapping a byte width to the same-sized unsigned integer "code"
// type accepted by atomicCAS.
template <int Bytes> struct Code { };

template <> struct Code<2> {
  typedef unsigned short int Type;
};

template <> struct Code<4> {
  typedef unsigned int Type;
};

template <> struct Code<8> {
  typedef unsigned long long int Type;
};

// Helper class for converting to/from atomicCAS compatible types.
template <typename T> struct Cast {
  typedef typename Code<sizeof(T)>::Type Type;
  static __device__ __forceinline__ Type Encode(T val) {
    return static_cast<Type>(val);
  }
  static __device__ __forceinline__ T Decode(Type code) {
    return static_cast<T>(code);
  }
};

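// Floating-point types must be reinterpreted bit-for-bit (rather than
// value-converted) so that atomicCAS compares the exact stored representation.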
#ifdef USE_FP16
template <> struct Cast<half> {
  typedef Code<sizeof(half)>::Type Type;
  static __device__ __forceinline__ Type Encode(half val) {
    return __half_as_ushort(val);
  }
  static __device__ __forceinline__ half Decode(Type code) {
    return __ushort_as_half(code);
  }
};
#endif

template <> struct Cast<float> {
  typedef Code<sizeof(float)>::Type Type;
  static __device__ __forceinline__ Type Encode(float val) {
    return __float_as_uint(val);
  }
  static __device__ __forceinline__ float Decode(Type code) {
    return __uint_as_float(code);
  }
};

template <> struct Cast<double> {
  typedef Code<sizeof(double)>::Type Type;
  static __device__ __forceinline__ Type Encode(double val) {
    return __double_as_longlong(val);
  }
  static __device__ __forceinline__ double Decode(Type code) {
    return __longlong_as_double(code);
  }
};

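/**
* \brief Performs an atomic compare-and-swap on 16 bit unsigned integers,
* using the native 16 bit `atomicCAS` where the toolkit and architecture
* support it.
*
* \param address The address to perform the atomic operation on.
* \param compare The value to compare to.
* \param val The new value to conditionally store.
*
* \return The old value at the address.
*/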
static __device__ __forceinline__ unsigned short int atomicCASshort(
    unsigned short int *address,
    unsigned short int compare,
    unsigned short int val) {
#if (defined(CUDART_VERSION) && (CUDART_VERSION > 10000)) && \
    (defined(__CUDA_ARCH__) && (__CUDA_ARCH__) >= 700)
  // 16-bit atomicCAS requires CUDA > 10.0 and compute capability 7.0+.
  return atomicCAS(address, compare, val);
#else
  // Fallback for older toolkits/architectures: emulate the compare-and-swap
  // without any atomicity guarantee.
  const unsigned short int old = *address;
  if (old == compare) {
    *address = val;
  }
  return old;
#endif  // CUDART_VERSION > 10000 && __CUDA_ARCH__ >= 700
}

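// Expands to a type-generic Atomic##NAME(addr, val) that applies the binary
// macro OP through an atomicCAS loop: the word at addr is reinterpreted as a
// same-width unsigned integer (via Cast<T>), combined with val, and the store
// is retried until no other thread has modified the word in between.
// OP must be #defined before the macro is expanded.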
#define DEFINE_ATOMIC(NAME) \
  template <typename T>                                          \
  __device__ __forceinline__ T Atomic##NAME(T* addr, T val) {    \
    typedef typename Cast<T>::Type CT;                           \
    CT* addr_as_ui = reinterpret_cast<CT*>(addr);                \
    CT old = *addr_as_ui;                                        \
    CT assumed = old;                                            \
    do {                                                         \
      assumed = old;                                             \
      old = atomicCAS(addr_as_ui, assumed,                       \
          Cast<T>::Encode(OP(val, Cast<T>::Decode(old))));       \
    } while (assumed != old);                                    \
    return Cast<T>::Decode(old);                                 \
  }

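// Same CAS loop specialized for half: the regular atomicCAS has no 16-bit
// overload, so the loop goes through atomicCASshort instead.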
#define DEFINE_ATOMIC_HALF(NAME) \
  template <>                                                    \
  __device__ __forceinline__ half Atomic##NAME<half>(half* addr, half val) {  \
    typedef unsigned short int CT;                               \
    CT* addr_as_ui = reinterpret_cast<CT*>(addr);                \
    CT old = *addr_as_ui;                                        \
    CT assumed = old;                                            \
    do {                                                         \
      assumed = old;                                             \
      old = atomicCASshort(addr_as_ui, assumed,                  \
          Cast<half>::Encode(OP(val, Cast<half>::Decode(old)))); \
    } while (assumed != old);                                    \
    return Cast<half>::Decode(old);                              \
  }

#define OP(a, b) max(a, b)
DEFINE_ATOMIC(Max)
#ifdef USE_FP16
DEFINE_ATOMIC_HALF(Max)
#endif  // USE_FP16
#undef OP

#define OP(a, b) min(a, b)
DEFINE_ATOMIC(Min)
#ifdef USE_FP16
DEFINE_ATOMIC_HALF(Min)
#endif  // USE_FP16
#undef OP

#define OP(a, b) ((a) + (b))
DEFINE_ATOMIC(Add)
#undef OP

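// Illustrative usage (hypothetical kernel, not part of this header): the
// helpers above can be used to reduce per-edge values into destination slots,
// e.g. a scatter-max:
//
//   __global__ void ScatterMaxKernel(const float* src, const int64_t* idx,
//                                    float* out, int64_t num_edges) {
//     int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < num_edges) {
//       AtomicMax(out + idx[i], src[i]);
//     }
//   }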

/**
* \brief Performs an atomic compare-and-swap on 64 bit integers. That is,
* it reads the word `old` at the memory location `address`, computes
* `(old == compare ? val : old)`, and stores the result back to memory at
* the same address.
*
* \param address The address to perform the atomic operation on.
* \param compare The value to compare to.
* \param val The new value to conditionally store.
*
* \return The old value at the address.
*/
inline __device__ int64_t AtomicCAS(
    int64_t * const address,
    const int64_t compare,
    const int64_t val) {
  // match the type of "::atomicCAS", so ignore lint warning
  using Type = unsigned long long int; // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicCAS(reinterpret_cast<Type*>(address),
                   static_cast<Type>(compare),
                   static_cast<Type>(val));
}

/**
* \brief Performs an atomic compare-and-swap on 32 bit integers. That is,
* it reads the word `old` at the memory location `address`, computes
* `(old == compare ? val : old)`, and stores the result back to memory at
* the same address.
*
* \param address The address to perform the atomic operation on.
* \param compare The value to compare to.
* \param val The new value to conditionally store.
*
* \return The old value at the address.
*/
inline __device__ int32_t AtomicCAS(
    int32_t * const address,
    const int32_t compare,
    const int32_t val) {
  // match the type of "::atomicCAS", so ignore lint warning
  using Type = int; // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicCAS(reinterpret_cast<Type*>(address),
                   static_cast<Type>(compare),
                   static_cast<Type>(val));
}

inline __device__ int64_t AtomicMax(
    int64_t * const address,
    const int64_t val) {
  // match the unsigned 64-bit overload of "::atomicMax", so ignore lint
  // warning; note the comparison is done on the unsigned bit pattern
  using Type = unsigned long long int; // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicMax(reinterpret_cast<Type*>(address),
                   static_cast<Type>(val));
}

inline __device__ int32_t AtomicMax(
    int32_t * const address,
    const int32_t val) {
  // match the "int" overload of "::atomicMax", so ignore lint warning
  using Type = int; // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicMax(reinterpret_cast<Type*>(address),
                   static_cast<Type>(val));
}


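// The specializations below replace the generic CAS-loop AtomicAdd with the
// hardware atomicAdd instruction where the target architecture provides one.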
template <>
__device__ __forceinline__ float AtomicAdd<float>(float* addr, float val) {
#if __CUDA_ARCH__ >= 200
  return atomicAdd(addr, val);
#else
  // Single-precision atomicAdd requires compute capability 2.0+; emulate it
  // with an atomicCAS loop on the 32-bit bit pattern.
  typedef Cast<float>::Type CT;
  CT* addr_as_ui = reinterpret_cast<CT*>(addr);
  CT old = *addr_as_ui;
  CT assumed;
  do {
    assumed = old;
    old = atomicCAS(addr_as_ui, assumed,
        Cast<float>::Encode(val + Cast<float>::Decode(assumed)));
  } while (assumed != old);
  return Cast<float>::Decode(old);
#endif  // __CUDA_ARCH__
}

template <>
__device__ __forceinline__ double AtomicAdd<double>(double* addr, double val) {
#if __CUDA_ARCH__ >= 600
  return atomicAdd(addr, val);
#else
  // Double-precision atomicAdd requires compute capability 6.0+; emulate it
  // with an atomicCAS loop on the 64-bit bit pattern.
  typedef Cast<double>::Type CT;
  CT* addr_as_ull = reinterpret_cast<CT*>(addr);
  CT old = *addr_as_ull;
  CT assumed;
  do {
    assumed = old;
    old = atomicCAS(addr_as_ull, assumed,
        Cast<double>::Encode(val + Cast<double>::Decode(assumed)));
  } while (assumed != old);
  return Cast<double>::Decode(old);
#endif  // __CUDA_ARCH__
}

#ifdef USE_FP16
#if defined(CUDART_VERSION) && CUDART_VERSION >= 10000
template <>
__device__ __forceinline__ half AtomicAdd<half>(half* addr, half val) {
#if __CUDA_ARCH__ >= 700
  return atomicAdd(addr, val);
#else
  // Half-precision atomicAdd requires compute capability 7.0+; this fallback
  // is not atomic: it adds through float and returns the previous value.
  const half old = *addr;
  *addr = __float2half(__half2float(old) + __half2float(val));
  return old;
#endif  // __CUDA_ARCH__
}
#endif  // defined(CUDART_VERSION) && CUDART_VERSION >= 10000
#endif  // USE_FP16


}  // namespace cuda
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CUDA_ATOMIC_H_