"src/diffusers/models/controlnet_sparsectrl.py" did not exist on "3eb498e7b4868bca7460d41cda52d33c3ede5502"
atomic.cuh 12.2 KB
Newer Older
sangwzh's avatar
sangwzh committed
1
// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2019 by Contributors
 * @file array/cuda/atomic.cuh
 * @brief Atomic functions
 */
#ifndef DGL_ARRAY_CUDA_ATOMIC_CUH_
#define DGL_ARRAY_CUDA_ATOMIC_CUH_

#include <hip/hip_runtime.h>

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

#include "bf16.cuh"
#include "fp16.cuh"

#if __HIPCC__
#include <hip/hip_fp16.h>
#endif

namespace dgl {
namespace aten {
namespace cuda {
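// Device-side atomic helpers: AtomicCAS, AtomicMax, AtomicMin and AtomicAdd
// for integer and floating-point types. Types without a native hardware
// atomic are handled through a compare-and-swap loop over their raw bit
// patterns (see the Cast helpers and DEFINE_ATOMIC below).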

// Type trait for selecting code type
template <int Bytes>
struct Code {};

template <>
struct Code<2> {
  typedef unsigned short int Type;  // NOLINT
};

template <>
struct Code<4> {
  typedef unsigned int Type;  // NOLINT
};

template <>
struct Code<8> {
  typedef unsigned long long int Type;  // NOLINT
};

// Helper class for converting to/from atomicCAS compatible types.
template <typename T>
struct Cast {
  typedef typename Code<sizeof(T)>::Type Type;
  static __device__ __forceinline__ Type Encode(T val) {
    return static_cast<Type>(val);
  }
  static __device__ __forceinline__ T Decode(Type code) {
    return static_cast<T>(code);
  }
};

template <>
struct Cast<half> {
  // Encode/Decode reinterpret the 16-bit pattern; use the unsigned code type.
  typedef Code<sizeof(half)>::Type Type;
  static __host__ __device__ __forceinline__ Type Encode(half val) {
    return __half_as_ushort(val);
  }
  static __host__ __device__ __forceinline__ half Decode(Type code) {
    return __ushort_as_half(code);
  }
};

#if BF16_ENABLED
template <>
struct Cast<__hip_bfloat16> {
  // Encode/Decode reinterpret the 16-bit pattern; use the unsigned code type.
  typedef Code<sizeof(__hip_bfloat16)>::Type Type;
  static __host__ __device__ __forceinline__ Type Encode(__hip_bfloat16 val) {
#if defined(__HIP_DEVICE_COMPILE__)
    return __bfloat16_as_ushort(val);
#else
    printf(
        "Atomic operations are not supported for bfloat16 (BF16) "
        "on GPUs with compute capability less than 8.0.\n");
    // __trap();
    return static_cast<Type>(0);
#endif
  }
  static __host__ __device__ __forceinline__ __hip_bfloat16 Decode(Type code) {
#if defined(__HIP_DEVICE_COMPILE__)
    return __ushort_as_bfloat16(code);
#else
    printf(
        "Atomic operations are not supported for bfloat16 (BF16) "
        "on GPUs with compute capability less than 8.0.\n");
    // __trap();
    return static_cast<__hip_bfloat16>(0.0f);
#endif
  }
};
#endif  // BF16_ENABLED

template <>
struct Cast<float> {
  typedef Code<sizeof(float)>::Type Type;
  static __device__ __forceinline__ Type Encode(float val) {
    return __float_as_uint(val);
  }
  static __device__ __forceinline__ float Decode(Type code) {
    return __uint_as_float(code);
  }
};

template <>
struct Cast<double> {
  typedef Code<sizeof(double)>::Type Type;
  static __device__ __forceinline__ Type Encode(double val) {
    return __double_as_longlong(val);
  }
  static __device__ __forceinline__ double Decode(Type code) {
    return __longlong_as_double(code);
  }
};
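// For instance, Cast<float>::Encode(1.0f) yields 0x3F800000u (the IEEE-754
// bit pattern of 1.0f) and Cast<float>::Decode(0x3F800000u) returns 1.0f.
// The CAS loops below compare and swap these raw bit patterns rather than
// the floating-point values themselves.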

static __host__ __device__ __forceinline__ unsigned short int atomicCASshort(  // NOLINT
    unsigned short int* address,                                      // NOLINT
    unsigned short int compare,                                       // NOLINT
    unsigned short int val) {                                         // NOLINT
  static_assert(DTKRT_VERSION >= 10000, "Requires at least CUDA 10");
  // NOTE: the device atomicCAS path is disabled by the `&& 0` guard below, so
  // this helper currently always reports 16-bit CAS as unsupported.
#if defined(__HIP_DEVICE_COMPILE__) && 0
  return atomicCAS(address, compare, val);
#else
  (void)address;
  (void)compare;
  (void)val;
  printf(
      "Atomic operations are not supported for half precision (FP16) "
      "on this GPU.\n");
  abort();
  return val;
#endif  // defined(__HIP_DEVICE_COMPILE__)
}

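// DEFINE_ATOMIC stamps out a generic Atomic<NAME>(T*, T) as a compare-and-swap
// loop; the binary operation is supplied through an OP(a, b) macro defined
// immediately before each instantiation, e.g.
//   #define OP(a, b) a + b
//   DEFINE_ATOMIC(Add)  // defines template <typename T> T AtomicAdd(T*, T)
//   #undef OP
// The loop retries until atomicCAS confirms that no other thread modified the
// word between the read of `old` and the swap.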
#define DEFINE_ATOMIC(NAME)                                   \
  template <typename T>                                       \
  __device__ __forceinline__ T Atomic##NAME(T* addr, T val) { \
    typedef typename Cast<T>::Type CT;                        \
    CT* addr_as_ui = reinterpret_cast<CT*>(addr);             \
    CT old = *addr_as_ui;                                     \
    CT assumed = old;                                         \
    do {                                                      \
      assumed = old;                                          \
      old = atomicCAS(                                        \
          addr_as_ui, assumed,                                \
          Cast<T>::Encode(OP(val, Cast<T>::Decode(old))));    \
    } while (assumed != old);                                 \
    return Cast<T>::Decode(old);                              \
  }

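// 16-bit floating-point types (half and, when BF16_ENABLED, __hip_bfloat16)
// get dedicated specializations that go through atomicCASshort instead of the
// word-sized ::atomicCAS used by DEFINE_ATOMIC.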
#define DEFINE_ATOMIC_16BIT(NAME, dtype)                           \
  template <>                                                      \
  __device__ __forceinline__ dtype Atomic##NAME<dtype>(            \
      dtype * addr, dtype val) {                                   \
    typedef uint16_t CT;                                           \
    CT* addr_as_ui = reinterpret_cast<CT*>(addr);                  \
    CT old = *addr_as_ui;                                          \
    CT assumed = old;                                              \
    do {                                                           \
      assumed = old;                                               \
      old = atomicCASshort(                                        \
          addr_as_ui, assumed,                                     \
          Cast<dtype>::Encode(OP(val, Cast<dtype>::Decode(old)))); \
    } while (assumed != old);                                      \
    return Cast<dtype>::Decode(old);                               \
  }

// bfloat16-specific variants of DEFINE_ATOMIC_16BIT: DEFINE_ATOMIC_16BIT_BF
// hard-codes max() and DEFINE_ATOMIC_16BIT_Min hard-codes min(), with the
// comparison carried out in double precision.
#define DEFINE_ATOMIC_16BIT_BF(NAME, dtype)                          \
  template <>                                                        \
  __device__ __forceinline__ dtype Atomic##NAME<dtype>(              \
      dtype * addr, dtype val) {                                     \
    typedef uint16_t CT;                                             \
    CT* addr_as_ui = reinterpret_cast<CT*>(addr);                    \
    CT old = *addr_as_ui;                                            \
    CT assumed = old;                                                \
    do {                                                             \
      assumed = old;                                                 \
      old = atomicCASshort(                                          \
          addr_as_ui, assumed,                                       \
          Cast<dtype>::Encode(                                       \
              max((double)val, (double)Cast<dtype>::Decode(old))));  \
    } while (assumed != old);                                        \
    return Cast<dtype>::Decode(old);                                 \
  }

#define DEFINE_ATOMIC_16BIT_Min(NAME, dtype)                         \
  template <>                                                        \
  __device__ __forceinline__ dtype Atomic##NAME<dtype>(              \
      dtype * addr, dtype val) {                                     \
    typedef uint16_t CT;                                             \
    CT* addr_as_ui = reinterpret_cast<CT*>(addr);                    \
    CT old = *addr_as_ui;                                            \
    CT assumed = old;                                                \
    do {                                                             \
      assumed = old;                                                 \
      old = atomicCASshort(                                          \
          addr_as_ui, assumed,                                       \
          Cast<dtype>::Encode(                                       \
              min((double)val, (double)Cast<dtype>::Decode(old))));  \
    } while (assumed != old);                                        \
    return Cast<dtype>::Decode(old);                                 \
  }

#define OP(a, b) max((double)a, (double)b)
DEFINE_ATOMIC(Max)
DEFINE_ATOMIC_16BIT(Max, half)
#if BF16_ENABLED
DEFINE_ATOMIC_16BIT_BF(Max, __hip_bfloat16)
#endif  // BF16_ENABLED
#undef OP

#define OP(a, b) min((double)a, (double)b)
DEFINE_ATOMIC(Min)
DEFINE_ATOMIC_16BIT(Min, half)
#if BF16_ENABLED
// Use the min() variant here; DEFINE_ATOMIC_16BIT_BF hard-codes max().
DEFINE_ATOMIC_16BIT_Min(Min, __hip_bfloat16)
#endif  // BF16_ENABLED
#undef OP

#define OP(a, b) a + b
DEFINE_ATOMIC(Add)
#undef OP
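// Illustrative usage (a minimal sketch, not part of the DGL API): a kernel
// that scatters per-edge values into per-node accumulators with AtomicAdd.
// The kernel name and launch configuration are hypothetical.
//
//   template <typename DType, typename IdType>
//   __global__ void _ScatterAddSketchKernel(
//       const DType* src, const IdType* idx, DType* dst, int64_t n) {
//     int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
//     if (tx < n) {
//       dgl::aten::cuda::AtomicAdd(dst + idx[tx], src[tx]);
//     }
//   }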

/**
 * @brief Performs an atomic compare-and-swap on 64 bit integers. That is,
 * it reads the word `old` at the memory location `address`, computes
 * `(old == compare ? val : old)`, and stores the result back to memory at
 * the same address.
 *
 * @param address The address to perform the atomic operation on.
 * @param compare The value to compare to.
 * @param val The new value to conditionally store.
 *
 * @return The old value at the address.
 */
inline __device__ int64_t
AtomicCAS(int64_t* const address, const int64_t compare, const int64_t val) {
  // match the type of "::atomicCAS", so ignore lint warning
  using Type = unsigned long long int;  // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicCAS(
      reinterpret_cast<Type*>(address), static_cast<Type>(compare),
      static_cast<Type>(val));
}
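// For example, if *address currently holds 5, AtomicCAS(address, 5, 9) stores
// 9 and returns 5; if it holds 7 instead, the value is left unchanged and 7 is
// returned.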

/**
 * @brief Performs an atomic compare-and-swap on 32 bit integers. That is,
 * it reads the word `old` at the memory location `address`, computes
 * `(old == compare ? val : old)`, and stores the result back to memory at
 * the same address.
 *
 * @param address The address to perform the atomic operation on.
 * @param compare The value to compare to.
 * @param val The new value to conditionally store.
 *
 * @return The old value at the address.
 */
inline __device__ int32_t
AtomicCAS(int32_t* const address, const int32_t compare, const int32_t val) {
  // match the type of "::atomicCAS", so ignore lint warning
  using Type = int;  // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicCAS(
      reinterpret_cast<Type*>(address), static_cast<Type>(compare),
      static_cast<Type>(val));
}

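// Note: the 64-bit overload below forwards to ::atomicMax on
// `unsigned long long`, so negative int64_t values compare as very large
// unsigned values; the 32-bit overload uses the signed `int` overload.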
inline __device__ int64_t AtomicMax(int64_t* const address, const int64_t val) {
  // match the type of "::atomicMax", so ignore lint warning
  using Type = unsigned long long int;  // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicMax(reinterpret_cast<Type*>(address), static_cast<Type>(val));
}

inline __device__ int32_t AtomicMax(int32_t* const address, const int32_t val) {
  // match the type of "::atomicMax", so ignore lint warning
  using Type = int;  // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicMax(reinterpret_cast<Type*>(address), static_cast<Type>(val));
}

template <>
__device__ __forceinline__ float AtomicAdd<float>(float* addr, float val) {
#if __HIP_DEVICE_COMPILE__
  return atomicAdd(addr, val);
#else
  typedef float T;
  typedef typename Cast<T>::Type CT;
  CT* addr_as_ui = reinterpret_cast<CT*>(addr);
  CT old = *addr_as_ui;
  CT assumed = old;
  do {
    assumed = old;
    old = atomicCAS(
        addr_as_ui, assumed, Cast<T>::Encode(Cast<T>::Decode(old) + val));
  } while (assumed != old);
  return Cast<T>::Decode(old);
#endif  // __HIP_DEVICE_COMPILE__
}

template <>
__device__ __forceinline__ double AtomicAdd<double>(double* addr, double val) {
#if __HIP_DEVICE_COMPILE__
  return atomicAdd(addr, val);
#else
  typedef double T;
  typedef typename Cast<T>::Type CT;
  CT* addr_as_ui = reinterpret_cast<CT*>(addr);
  CT old = *addr_as_ui;
  CT assumed = old;
  do {
    assumed = old;
    old = atomicCAS(
        addr_as_ui, assumed, Cast<T>::Encode(Cast<T>::Decode(old) + val));
  } while (assumed != old);
  return Cast<T>::Decode(old);
#endif  // __HIP_DEVICE_COMPILE__
}

#if defined(DTKRT_VERSION) && DTKRT_VERSION >= 10000
template <>
__device__ __forceinline__ half AtomicAdd<half>(half* addr, half val) {
// make sure we have half support
#if __HIP_DEVICE_COMPILE__
  return atomicAdd(addr, val);
#else
  (void)addr;
  (void)val;
  printf(
      "Atomic operations are not supported for half precision (FP16) "
      "on this GPU.\n");
  // __trap();
  return val;
#endif  // __HIP_DEVICE_COMPILE__
}
#endif  // defined(DTKRT_VERSION) && DTKRT_VERSION >= 10000

#if BF16_ENABLED
template <>
__device__ __forceinline__ __hip_bfloat16
AtomicAdd<__hip_bfloat16>(__hip_bfloat16* addr, __hip_bfloat16 val) {
// make sure we have bfloat16 support
#if defined(__HIP_DEVICE_COMPILE__)
  return atomicAdd(addr, val);
#else
  (void)addr;
  (void)val;
  printf(
      "Atomic operations are not supported for bfloat16 (BF16) "
      "on GPUs with compute capability less than 8.0.\n");
  // __trap();
  return val;
#endif  // defined(__HIP_DEVICE_COMPILE__)
}
#endif  // BF16_ENABLED

}  // namespace cuda
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CUDA_ATOMIC_CUH_