/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_UTILS_CUH_
#define TRANSFORMER_ENGINE_COMMON_UTILS_CUH_

#include <cuda_fp16.h>

#ifdef __HIP_PLATFORM_AMD__
#ifndef __HIPCC_RTC__
#include <cstdint>
#else
#include <hip/amd_detail/hip_assert.h>
using namespace __hip_internal;
#endif
#endif

#include <cuda_bf16.h>
#include <cuda_fp8.h>

#if CUDA_VERSION >= 12080
#include <cuda_fp4.h>
#endif

#ifdef __HIP_PLATFORM_AMD__
// HIP has no native bf16x2 vector type; declare a 2-wide ext_vector of 16-bit words.
typedef uint16_t hip_bfloat16x2 __attribute__((ext_vector_type(2)));
#else

#if !defined(__CUDACC_RTC__)
#include <cstdint>
#else
// Importing C++ standard headers is a pain with NVRTC
using uint8_t = unsigned char;
using uint16_t = unsigned short int;  // NOLINT(*)
using uint32_t = unsigned int;
using uint64_t = unsigned long long int;  // NOLINT(*)
static_assert(sizeof(uint8_t) == 1);
static_assert(sizeof(uint16_t) == 2);
static_assert(sizeof(uint32_t) == 4);
static_assert(sizeof(uint64_t) == 8);
#endif

#endif  // __HIP_PLATFORM_AMD__

////////////////////////////////////////////////////////////////////////////////////////////////////

constexpr uint32_t THREADS_PER_WARP = 32;

////////////////////////////////////////////////////////////////////////////////////////////////////

#if !defined(USE_HIPBLASLT) && !defined(__HIPCC_RTC__)

// Element-wise sum of two float2 values.
inline __device__ float2 operator+(const float2 &a, const float2 &b) {  // NOLINT(*)
  return {a.x + b.x, a.y + b.y};
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// In-place element-wise accumulation of a float2.
inline __device__ void operator+=(float2 &a, const float2 &b) {  // NOLINT(*)
  a.x += b.x;
  a.y += b.y;
}

#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

// Binary functor returning a + b; used as the reduction operator for the
// Reducer/Stats helpers below.
template <typename T>
struct Sum {
  inline __device__ Sum() {}
  inline __device__ T operator()(const T &a, const T &b) const { return a + b; }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Butterfly (XOR) warp shuffle. Uses the full-warp mask *_sync variant on
// CUDA; HIP uses the non-sync __shfl_xor with an explicit warp width.
template <typename T>
inline __device__ T warp_shuffle_xor(const T &x, uint32_t idx) {
#ifdef __HIP_PLATFORM_AMD__
  return __shfl_xor(x, idx, THREADS_PER_WARP);
#else
  return __shfl_xor_sync(static_cast<uint32_t>(-1), x, idx);
#endif
}

// float2 has no native shuffle; exchange the two components independently.
template <>
inline __device__ float2 warp_shuffle_xor<float2>(const float2 &x, uint32_t idx) {
  return {warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx)};
}

// Down warp shuffle (lane i reads lane i + idx). Full-warp mask on CUDA;
// HIP uses the non-sync __shfl_down with an explicit warp width.
template <typename T>
inline __device__ T warp_shuffle_down(const T &x, uint32_t idx) {
#ifdef __HIP_PLATFORM_AMD__
  return __shfl_down(x, idx, THREADS_PER_WARP);
#else
  return __shfl_down_sync(static_cast<uint32_t>(-1), x, idx);
#endif
}

// float2 has no native shuffle; exchange the two components independently.
template <>
inline __device__ float2 warp_shuffle_down<float2>(const float2 &x, uint32_t idx) {
  return {warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx)};
}

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace transformer_engine {

////////////////////////////////////////////////////////////////////////////////////////////////////

// 64-byte POD aggregate (4 x uint4); widest vector load/store unit below.
struct uint16 {
  uint4 u;
  uint4 v;
  uint4 s;
  uint4 t;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// 32-byte POD aggregate (2 x uint4).
struct uint8 {
  uint4 u;
  uint4 v;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Maps a byte count to an unsigned POD type of exactly that size; the
// static_asserts in each specialization verify the size on every platform.
// Used by Vec to perform single-instruction vectorized loads/stores.
template <int BYTES>
struct BytesToType {};

template <>
struct BytesToType<64> {
  using Type = uint16;
  static_assert(sizeof(Type) == 64);
};

template <>
struct BytesToType<32> {
  using Type = uint8;
  static_assert(sizeof(Type) == 32);
};

template <>
struct BytesToType<16> {
  using Type = uint4;
  static_assert(sizeof(Type) == 16);
};

template <>
struct BytesToType<8> {
  using Type = uint64_t;
  static_assert(sizeof(Type) == 8);
};

template <>
struct BytesToType<4> {
  using Type = uint32_t;
  static_assert(sizeof(Type) == 4);
};

template <>
struct BytesToType<2> {
  using Type = uint16_t;
  static_assert(sizeof(Type) == 2);
};

template <>
struct BytesToType<1> {
  using Type = uint8_t;
  static_assert(sizeof(Type) == 1);
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Maps a scalar type to its 2-element vector counterpart
// (float -> float2, half -> half2, bf16 -> bf16x2).
template <typename T>
struct TypeToVec2 {};

template <>
struct TypeToVec2<float> {
  using Type = float2;
};

template <>
struct TypeToVec2<half> {
  using Type = half2;
};

#ifdef __HIP_PLATFORM_AMD__
template <>
struct TypeToVec2<__hip_bfloat16> {
  using Type = hip_bfloat16x2;
};
#else
template <>
struct TypeToVec2<nv_bfloat16> {
  using Type = nv_bfloat162;
};
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

// Parameter pack passed to fused cast-transpose / dbias / dact kernels.
// All pointers are device pointers; members are plain data (no behavior).
template <typename IType, typename IType2, typename OType, typename CType>
struct CTDBiasDActParam {
  using InputType = IType;
  using InputType2 = IType2;
  using OutputType = OType;
  using ComputeType = CType;
  const IType *input;
  const IType2 *act_input;
  OType *output_c;   // non-transposed output
  OType *output_t;   // transposed output
  const CType *scale_ptr;
  CType *amax;
  CType *scale_inv;
  CType *workspace;
  CType *warp_scales_inv;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Compile-time accessor for the INDEX-th component (.x/.y/.z/.w) of a CUDA
// vector type; lets templated code pick a component by integer index.
template <int INDEX>
struct Get {
  template <typename T, typename R>
  static inline __device__ R of(const T &vec);
};

template <>
template <typename T, typename R>
inline __device__ R Get<0>::of(const T &vec) {
  return vec.x;
}

template <>
template <typename T, typename R>
inline __device__ R Get<1>::of(const T &vec) {
  return vec.y;
}

template <>
template <typename T, typename R>
inline __device__ R Get<2>::of(const T &vec) {
  return vec.z;
}

template <>
template <typename T, typename R>
inline __device__ R Get<3>::of(const T &vec) {
  return vec.w;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Numeric conversion helper. The generic case defers to the destination
// type's converting constructor; specializations use hardware conversion
// intrinsics for the packed half/bf16 pairs.
template <typename Src, typename Dst>
struct Converter {
  static inline __device__ Dst convert(const Src &from) { return Dst(from); }
};

template <>
struct Converter<float2, half2> {
  static inline __device__ half2 convert(const float2 &x) { return __float22half2_rn(x); }
};

#ifdef __HIP_PLATFORM_AMD__
template <>
struct Converter<float2, hip_bfloat16x2> {
  static inline __device__ hip_bfloat16x2 convert(const float2 &x) {
    // Convert element-wise and reinterpret the pair through a union.
    union {
      hip_bfloat16x2 raw;
      __hip_bfloat16 elt[2];
    } tmp;
    tmp.elt[0] = __hip_bfloat16(x.x);
    tmp.elt[1] = __hip_bfloat16(x.y);
    return tmp.raw;
  }
};
#else
template <>
struct Converter<float2, nv_bfloat162> {
  static inline __device__ nv_bfloat162 convert(const float2 &x) {
#if __CUDA_ARCH__ >= 800
    return __float22bfloat162_rn(x);
#else
    // Pre-SM80 has no paired bf16 conversion; convert each element and
    // reinterpret the pair through a union.
    union {
      nv_bfloat162 raw;
      nv_bfloat16 elt[2];
    } tmp;
    tmp.elt[0] = __float2bfloat16_rn(x.x);
    tmp.elt[1] = __float2bfloat16_rn(x.y);
    return tmp.raw;
#endif
  }
};
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

// Zero-value factory so templated reduction code can zero both scalar types
// and float2 uniformly.
template <typename T>
struct Zeros {
  static inline __device__ T get() { return T(0.f); }
};

template <>
struct Zeros<float2> {
  static inline __device__ float2 get() { return make_float2(0.f, 0.f); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Fixed-size array of NUM_ELT elements of Elt_type, union-aliased onto one
// POD word (BytesToType<BYTES>::Type) so it can be moved with a single
// vectorized load/store when the pointer is suitably aligned.
template <typename Elt_type, uint32_t NUM_ELT>
struct Vec {
  enum { BYTES = NUM_ELT * sizeof(Elt_type) };

  using Vec_type = typename BytesToType<BYTES>::Type;
  using type = Elt_type;

  using Alias_type = union {
    Vec_type vec;
    Elt_type elt[NUM_ELT];
  };

  Alias_type data;

#ifdef __HIP_PLATFORM_AMD__
  // Copy through the vector member so the union assignment is well-defined.
  __HOST_DEVICE__ Vec& operator=(const Vec& rhs) {
    data.vec = rhs.data.vec;
    return *this;
  }
#endif

  // Element-wise conversion into a Vec of another element type.
  template <typename S>
  inline __device__ void to(Vec<S, NUM_ELT> &other) {  // NOLINT(*)
#pragma unroll
    for (int it = 0; it < NUM_ELT; it++) {
      other.data.elt[it] = S(this->data.elt[it]);
    }
  }

  // Fill each element with op(index).
  template <typename Op>
  inline __device__ void assign(const Op &op) {
#pragma unroll
    for (int it = 0; it < NUM_ELT; it++) {
      this->data.elt[it] = op(it);
    }
  }

  // Pointer is cast to vector type; caller guarantees BYTES alignment.
  inline __device__ void load_from(const void *base_ptr, size_t idx = 0) {
    this->data.vec = static_cast<const Vec_type *>(base_ptr)[idx];
  }

  // Pointer is cast to vector type; caller guarantees BYTES alignment.
  inline __device__ void store_to(void *base_ptr, size_t idx = 0) const {
    static_cast<Vec_type *>(base_ptr)[idx] = this->data.vec;
  }

  // Pointer is cast to element type. Loads min(count, NUM_ELT) elements and
  // zeros the remainder. Falls back to an element-wise loop when the span is
  // short or misaligned; otherwise takes the vectorized path.
  inline __device__ void load_from_elts(const void *base_ptr, size_t idx = 0,
                                        size_t count = NUM_ELT) {
    const Elt_type *elt_ptr = static_cast<const Elt_type *>(base_ptr) + idx;
    if (count < NUM_ELT || reinterpret_cast<uint64_t>(elt_ptr) % BYTES != 0) {
#pragma unroll
      for (int it = 0; it < NUM_ELT; it++) {
        this->data.elt[it] = (it < count ? elt_ptr[it] : Elt_type(0.f));
      }
    } else {
      this->load_from(elt_ptr);
    }
  }

  // Pointer is cast to element type. Stores min(count, NUM_ELT) elements,
  // element-wise when short or misaligned, vectorized otherwise.
  inline __device__ void store_to_elts(void *base_ptr, size_t idx = 0,
                                       size_t count = NUM_ELT) const {
    Elt_type *elt_ptr = static_cast<Elt_type *>(base_ptr) + idx;
    if (count < NUM_ELT || reinterpret_cast<uint64_t>(elt_ptr) % BYTES != 0) {
#pragma unroll
      for (int it = 0; it < NUM_ELT; it++) {
        if (it < count) {
          elt_ptr[it] = this->data.elt[it];
        }
      }
    } else {
      this->store_to(elt_ptr);
    }
  }

  // Zero all elements.
  inline __device__ void clear() {
#pragma unroll
    for (int it = 0; it < NUM_ELT; it++) {
      this->data.elt[it] = Elt_type(0.f);
    }
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Barrier synchronizing a group of CTAs that cooperate on one row.
// Two global counters (b0_/b1_) are used in alternation, counting up then
// down, giving four phases before the state repeats. Counters must be
// zero-initialized before first use.
struct InterCTASync {
  inline __device__ InterCTASync(int *barrier, int group, int num_groups, int group_size)
      : phase_counter_(0),
        b0_(barrier + group)  // The barrier for this group of CTAs.
        ,
        b1_(barrier + group + num_groups)  // The barrier for this group of CTAs.
        ,
        group_size_(group_size) {
    // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!
  }

#ifdef __HIP_PLATFORM_AMD__
  // Add `step` to the counter (release), then spin with acquire loads until
  // it reaches `expected`.
  inline __device__ void spin_wait_(int *barrier, int step, int expected) {
    __hip_atomic_fetch_add(barrier, step, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT);
    for (int found = -1; found != expected;) {
      found = __hip_atomic_load(barrier, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT);
    }
  }
#else
  // Add `step` to the counter with a GPU-scope release reduction, then spin
  // with GPU-scope acquire loads until it reaches `expected`.
  inline __device__ void spin_wait_(int *barrier, int step, int expected) {
    asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step));
    for (int found = -1; found != expected;) {
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier));
    }
  }
#endif

  inline __device__ void sync() {
    // ALL THREADS MUST ENTER!

    // We switch barrier every iteration.
    int *barrier = phase_counter_ & 0x1 ? b1_ : b0_;
    // We decrement every other iteration.
    bool dec = phase_counter_ & 0x2;
    int step = dec ? -1 : 1;
    int expected = dec ? 0 : group_size_;
    // There are only 4 phases: up/down for b0/b1.
    phase_counter_ = (phase_counter_ + 1) & 0x3;

    if (threadIdx.x == 0) {
      spin_wait_(barrier, step, expected);
    }
    // CTA waits for thread 0
    __syncthreads();
  }

  int phase_counter_;
  int *b0_;
  int *b1_;
  int group_size_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Reducer spanning CTAS_PER_ROW CTAs: reduce within each CTA via the base
// class, publish per-CTA partials to a double-buffered global workspace,
// sync the CTA group, then finalize with a single warp.
template <typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
  using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
  using Type = typename Base::Type;

  enum { SMEM_BYTES = Base::SMEM_BYTES };

  enum { WS_BARRIER_BYTES = 2 * sizeof(int) };
  enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) };

  // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total)
  enum {
    WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES
  };

  template <typename Params>
  inline __device__ Reducer(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                            uint32_t warp_n, uint32_t lane, void *smem)
      : Base(params, bidm, bidn, warp_m, warp_n, lane, smem),
        inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW),
        bidn_(bidn)  // CTA id within the group.
        ,
        w0_(static_cast<T *>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW),
        w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) {}

  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    data = Base::reduce(data, op);
    // We switch workspace every iteration.
    T *const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;

    // Warp leaders 0 hold the CTA-local results.
    if (this->warp_n_ == 0 && this->lane_ == 0) {
      workspace[bidn_] = data;
    }
    inter_cta_.sync();
    // Finalization uses one warp, so the group must fit in 32 lanes.
    static_assert(CTAS_PER_ROW <= 32);
    T total = Zeros<T>::get();
    if (this->lane_ < CTAS_PER_ROW) {
      total = workspace[this->lane_];
    }
    total = Reducer<T, 1, 1, 1>::allreduce_(total, op);

    return total;
  }

  InterCTASync inter_cta_;

  T *const w0_;
  T *const w1_;
  int bidn_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Warp-level reducer (single CTA, WARPS_N == 1): all exchanges happen via
// warp shuffles, so no shared memory or global workspace is required.
template <typename T, uint32_t WARPS_M>
struct Reducer<T, 1, WARPS_M, 1> {
  using Type = T;
  enum { SMEM_BYTES = 0 };
  enum { WORKSPACE_BYTES_PER_GROUP = 0 };

  enum { THREADS_PER_WARP = 32 };

  template <typename Params>
  inline __device__ Reducer(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                            uint32_t warp_n, uint32_t lane, void *smem)
      : warp_n_(warp_n), lane_(lane) {}

  // Butterfly allreduce: after log2(32) XOR-shuffle steps every lane holds
  // the full result.
  template <typename Op>
  static inline __device__ T allreduce_(T data, const Op &op) {
#pragma unroll
    for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
      data = op(data, warp_shuffle_xor(data, it));
    }
    return data;
  }

  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    return allreduce_(data, op);
  }

  template <typename Op>
  inline __device__ T reduce(T data, const Op &op) {
// only lane 0 holds the result!
#pragma unroll
    for (int it = THREADS_PER_WARP / 2; it > 0; it /= 2) {
      data = op(data, warp_shuffle_down(data, it));
    }
    return data;
  }
  int warp_n_;
  int lane_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// CTA-level reducer: warp-level reduction (base class) followed by a
// cross-warp combine through double-buffered shared memory. The two smem
// buffers alternate each call so back-to-back reductions need no extra
// barrier between them.
template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> {
  using Base = Reducer<T, 1, WARPS_M, 1>;

  using Type = T;

  enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 };
  enum { WORKSPACE_BYTES_PER_GROUP = 0 };

  enum { THREADS_PER_WARP = 32 };

  template <typename Params>
  inline __device__ Reducer(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                            uint32_t warp_n, uint32_t lane, void *smem)
      : Base(params, bidm, bidn, warp_m, warp_n, lane, smem),
        use0_(true),
        smem0_(&(static_cast<T *>(smem)[warp_m * WARPS_N])),
        smem1_(smem0_ + WARPS_M * WARPS_N) {}

  // Every thread in the CTA row receives the reduced value.
  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    T *const smem = use0_ ? smem0_ : smem1_;
    use0_ = !use0_;
    data = Base::reduce(data, op);
    if (this->lane_ == 0) {
      smem[this->warp_n_] = data;
    }
    __syncthreads();
    T out = Zeros<T>::get();
#pragma unroll
    for (int it = 0; it < WARPS_N; it++) {
      out = op(out, smem[it]);
    }
    return out;
  }

  template <typename Op>
  inline __device__ T reduce(T data, const Op &op) {
    T *const smem = use0_ ? smem0_ : smem1_;
    use0_ = !use0_;
    // only intra-CTA group leader holds the result!
    data = Base::reduce(data, op);
    if (this->lane_ == 0) {
      smem[this->warp_n_] = data;
    }
    __syncthreads();
    T out = Zeros<T>::get();
    if (this->warp_n_ == 0 && this->lane_ == 0) {
#pragma unroll
      for (int it = 0; it < WARPS_N; it++) {
        out = op(out, smem[it]);
      }
    }
    return out;
  }

  T *const smem0_;
  T *const smem1_;
  bool use0_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Reducer whose CTA-group size comes from runtime parameters
// (params.ctas_per_row) instead of a compile-time constant, so the final
// combine loops over the group with a warp stride rather than assuming the
// group fits in one warp.
template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct DynamicReducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
  using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
  using Type = typename Base::Type;

  template <typename Params>
  inline __device__ DynamicReducer(const Params &params, uint32_t bidm, uint32_t bidn,
                                   uint32_t warp_m, uint32_t warp_n, uint32_t lane, void *smem)
      : Base(params, bidm, bidn, warp_m, warp_n, lane, smem),
        inter_cta_(params.barrier, bidm, params.ctas_per_col, params.ctas_per_row),
        bidn_(bidn)  // CTA id within the group.
        ,
        w0_(static_cast<T *>(params.workspace) + (bidm * WARPS_M + warp_m) * params.ctas_per_row),
        w1_(w0_ + params.ctas_per_col * WARPS_M * params.ctas_per_row) {}

  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    // Trivial case
    if (inter_cta_.group_size_ == 1) {
      return Base::allreduce(data, op);
    }

    data = Base::reduce(data, op);
    // We switch workspace every iteration.
    T *const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;

    // Warp leaders 0 hold the CTA-local results.
    if (this->warp_n_ == 0 && this->lane_ == 0) {
      workspace[bidn_] = data;
    }
    inter_cta_.sync();
    T total = Zeros<T>::get();
    for (int it = this->lane_; it < inter_cta_.group_size_; it += THREADS_PER_WARP) {
      total = op(total, workspace[it]);
    }
    total = Reducer<T, 1, 1, 1>::allreduce_(total, op);

    return total;
  }

  template <typename Op>
  inline __device__ T reduce(T data, const Op &op) {
    return allreduce(data, op);
  }

  InterCTASync inter_cta_;

  T *const w0_;
  T *const w1_;
  int bidn_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

/*
This is an implementation of the parallel Welford algorithm for incrementally computing variance

This algorithm is known as Chan's update formulae (Chan et al '79):
http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf

An introduction is provided by Wikipedia here:
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance?section=5#Parallel_algorithm

A detailed reference on the exact version implemented (with better numerical stability) is provided here:
https://dbs.ifi.uni-heidelberg.de/files/Team/eschubert/publications/SSDBM18-covariance-authorcopy.pdf
*/

// Merge per-lane partial statistics (mean m_a, sum of squared deviations
// m2_a, count n_a) across the first `num_active` lanes using Chan's pairwise
// update, then broadcast lane 0's merged mean/M2 to all lanes.
template <typename T>
inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a,
                                             int num_active) {  // NOLINT(*)
  // Assume at least leftmost is valid and
  // init: step = next_pow2(num_active) / 2 (might get NaN otherwise)
  int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);

#pragma unroll
  for (int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2) {
    // Exchange
    T n_b = warp_shuffle_down(n_a, step);
    T m_b = warp_shuffle_down(m_a, step);
    T m2_b = warp_shuffle_down(m2_a, step);

    // Update
    const T n_ab = n_a + n_b;  // We can handle one of them being 0, not both.
    // Might have different n per thread, otherwise this would simplify :(
    const T rn_ab = 1.f / n_ab;
    const T delta = m_a - m_b;
    const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;
    const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab;

    n_a = n_ab;
    m_a = m_ab;
    m2_a = m2_ab;
  }
  // Intra-warp broadcast (only lane 0 has valid stats).
#ifdef __HIP_PLATFORM_AMD__
  m_a = __shfl(m_a, 0, THREADS_PER_WARP);
  m2_a = __shfl(m2_a, 0, THREADS_PER_WARP);
#else
  m_a = __shfl_sync(static_cast<uint32_t>(-1), m_a, 0);
  m2_a = __shfl_sync(static_cast<uint32_t>(-1), m2_a, 0);
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Mean/M2 statistics across CTAS_PER_ROW CTAs: each CTA computes block-local
// stats, publishes them to a double-buffered global workspace, and after the
// inter-CTA sync a single warp finalizes with warp_chan_upd_dynamic.
template <typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats {
  // This could be done generically with the Reducer. But then we
  // would have to exchange 3 instead of 2 fields.

  using BlockStats = Stats<T, 1, WARPS_M, WARPS_N>;
  using stats_t = typename BlockStats::stats_t;

  enum { SMEM_BYTES = BlockStats::SMEM_BYTES };

  template <typename Params>
  inline __device__ Stats(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                          uint32_t warp_n, uint32_t lane, void *smem)
      : inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW),
        block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem),
        bidn_(bidn)  // CTA id within the group.
        ,
        w0_(static_cast<stats_t *>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW),
        w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW),
        warp_n_(warp_n),
        lane_(lane) {}

  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP;
    // TODO(ptredak) rn is not really needed here..
    constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA);
    stats_t block_stats = block_stats_.compute(elts, block_rn);

    // We switch workspace every iteration.
    stats_t *const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;

    if (warp_n_ == 0 && lane_ == 0) {
      workspace[bidn_] = block_stats;
    }

    // Wait for all CTAS_PER_ROW CTAS in the group to have written their result.
    inter_cta_.sync();

    T n = Zeros<T>::get();
    T m = Zeros<T>::get();
    T m2 = Zeros<T>::get();

    // Assume CTA group size in N less than 32, such that we can finalize with a single warp.
    static_assert(CTAS_PER_ROW <= 32);

    // Every warp does the final reduction locally.
    if (lane_ < CTAS_PER_ROW) {
      stats_t result = workspace[lane_];
      n = ELTS_PER_ROW_PER_CTA;
      m = transformer_engine::Get<0>::of<stats_t, T>(result);
      m2 = transformer_engine::Get<1>::of<stats_t, T>(result);
    }

    warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW);

    return {m, m2};
  }

  InterCTASync inter_cta_;
  BlockStats block_stats_;

  stats_t *const w0_;
  stats_t *const w1_;
  int bidn_;
  int warp_n_;
  int lane_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

783
template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats<T, 1, WARPS_M, WARPS_N> {
  // Single-CTA, multi-warp specialization: WARPS_N warps cooperate through
  // shared memory to produce per-row (mean, m2) statistics.
  using WarpStats = Stats<T, 1, WARPS_M, 1>;
  using stats_t = typename WarpStats::stats_t;

  // Two buffers (ping-pong) of WARPS_M * WARPS_N per-warp partial stats.
  enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 };

  template <typename Params>
  inline __device__ Stats(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                          uint32_t warp_n, uint32_t lane, void *smem)
      : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem), use0_(true) {
    // Each warp_m row gets its own WARPS_N-wide slice of both buffers.
    smem0_ = static_cast<stats_t *>(smem) + warp_m * WARPS_N;
    smem1_ = smem0_ + WARPS_M * WARPS_N;
  }

  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    // Alternate between the two shared buffers on consecutive calls so the
    // next call's writes cannot clobber values still being read below
    // without an extra barrier.
    stats_t *smem = use0_ ? smem0_ : smem1_;
    use0_ = !use0_;
    // Compute warp local for all WARPS_N
    constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP);
    stats_t warp_stats = warp_stats_.compute(elts, warp_rn);

    // Each warp warp leader stores its stats
    const auto warp_n = warp_stats_.reducer_.warp_n_;
    const auto lane = warp_stats_.reducer_.lane_;
    if (lane == 0) {
      smem[warp_n] = warp_stats;
    }
    __syncthreads();

    T n = Zeros<T>::get();
    T m = Zeros<T>::get();
    T m2 = Zeros<T>::get();

    // Assume that there are less than 32 warps, such that we can finalize with a single warp
    static_assert(WARPS_N <= 32);
    if (lane < WARPS_N) {
      stats_t result = smem[lane];
      n = N * THREADS_PER_WARP;
      m = transformer_engine::Get<0>::of<stats_t, T>(result);
      m2 = transformer_engine::Get<1>::of<stats_t, T>(result);
    }

    // Merge the per-warp (m, m2, n) partials into the final row statistics.
    warp_chan_upd_dynamic(m, m2, n, WARPS_N);

    return {m, m2};
  }
  WarpStats warp_stats_;  // per-warp reduction helper
  stats_t *smem0_;        // ping buffer (slice for this warp_m row)
  stats_t *smem1_;        // pong buffer
  bool use0_;             // selects which buffer the next compute() uses
};

////////////////////////////////////////////////////////////////////////////////////////////////////

839
template <typename T, uint32_t WARPS_M>
struct Stats<T, 1, WARPS_M, 1> {
  // Base case: a single warp per row. The statistics are a (mean, m2) pair
  // packed into the two-wide vector type matching T.
  using stats_t = typename TypeToVec2<T>::Type;
  // The simple Warp reducer.
  using Reducer = Reducer<T, 1, WARPS_M, 1>;

  // This specialization requires no shared memory of its own.
  enum { SMEM_BYTES = 0 };

  template <typename Params>
  inline __device__ Stats(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                          uint32_t warp_n, uint32_t lane, void *smem)
      : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) {}

  // Two-pass computation: first all-reduce the element sum and scale by rn
  // (the reciprocal of the row length) to get the mean m, then all-reduce the
  // summed squared deviations to get m2.
  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    auto sum = Sum<T>();

    T m = Zeros<T>::get();
#pragma unroll
    for (int it = 0; it < N; it++) {
      m += elts[it];
    }
    m = reducer_.allreduce(m, sum) * rn;

    T m2 = Zeros<T>::get();
#pragma unroll
    for (int it = 0; it < N; it++) {
      T diff = (elts[it] - m);
      m2 += diff * diff;
    }
    m2 = reducer_.allreduce(m2, sum);

    return {m, m2};
  }

  Reducer reducer_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Tree max-reduction over the first num_elems lanes using shfl-down.
// After the loop, lane 0 holds the maximum of lanes [0, num_elems); other
// lanes hold partial results only. The __builtin_assume hints reflect that
// inputs are amax values, i.e. non-negative.
template <int num_elems>
__device__ __forceinline__ float warp_reduce_max(const float m) {
  float running = m;
#pragma unroll
  for (int stride = num_elems / 2; stride > 0; stride /= 2) {
#ifdef __HIP_PLATFORM_AMD__
    const float neighbor = __shfl_down(running, stride, THREADS_PER_WARP);
#else
    const float neighbor = __shfl_down_sync(0xFFFFFFFF, running, stride);
#endif
    // Inputs are expected to be non-negative (amax values).
    __builtin_assume(running >= 0);
    __builtin_assume(neighbor >= 0);
    running = fmaxf(running, neighbor);
  }
  return running;
}

896
897
898
899
// Full-warp max reduction followed by a broadcast: every lane of the warp
// returns the warp-wide maximum. Inputs are expected to be non-negative
// (amax values), as encoded by the __builtin_assume hints.
__forceinline__ __device__ float warp_reduce_max_broadcast(const float val) {
  float acc = val;
#pragma unroll
  for (int stride = THREADS_PER_WARP / 2; stride > 0; stride /= 2) {
#ifdef __HIP_PLATFORM_AMD__
    const float peer = __shfl_down(acc, stride, THREADS_PER_WARP);
#else
    const float peer = __shfl_down_sync(0xFFFFFFFF, acc, stride);
#endif
    __builtin_assume(acc >= 0);
    __builtin_assume(peer >= 0);
    acc = fmaxf(acc, peer);
  }
  // Lane 0 now holds the warp maximum; broadcast it to every other lane.
  constexpr int subwarp_lane_zero = 0;
#ifdef __HIP_PLATFORM_AMD__
  acc = __shfl(acc, subwarp_lane_zero, THREADS_PER_WARP);
#else
  acc = __shfl_sync(0xFFFFFFFF, acc, subwarp_lane_zero);
#endif
  return acc;
}

Przemek Tredak's avatar
Przemek Tredak committed
919
920
// Block-wide max reduction across num_warps warps.
// Each warp first reduces its own values, warp leaders stage their partial
// maxima in shared memory, and warp 0 reduces the staged values. Only warp 0
// (in particular its lane 0) returns the true block maximum; all other warps
// return 0.f. Must be called by every thread of the block (contains a
// __syncthreads()). Inputs are expected to be non-negative (amax values).
// NOTE(review): there is no trailing __syncthreads(), so two back-to-back
// calls could let fast warps overwrite `staging` while warp 0 still reads
// it -- callers must synchronize between calls.
template <int num_warps, typename compute_t>
__device__ __forceinline__ compute_t reduce_max(const compute_t m, const int warpid) {
  __shared__ float staging[num_warps];
  // Use THREADS_PER_WARP (32 on CUDA, the platform warp width on HIP) rather
  // than a hard-coded 32, so the warp-leader test and the intra-warp
  // reduction agree with the rest of this file on AMD as well; with a literal
  // 32 on 64-wide warps, two lanes per warp would race on staging[warpid].
  const float my_max = m;
  const float my_warp_max = warp_reduce_max<THREADS_PER_WARP>(my_max);
  if (threadIdx.x % THREADS_PER_WARP == 0) {
    staging[warpid] = my_warp_max;
  }
  __syncthreads();
  compute_t result = 0.f;
  if (warpid == 0) {
    const float staged_max = threadIdx.x < num_warps ? staging[threadIdx.x] : 0;
    result = warp_reduce_max<num_warps>(staged_max);
  }
  return result;
}

937
938
939
940
941
942
943
944
945
946
947
948
/**
 * Max reduction in subwarps, with the result broadcast to every lane of the subwarp.
 * E.g., if nvec=4, each warp processes 128 elements (32 x 4), that covers four MXFP8 scaling factors.
 * To compute an actual scaling factor for 32 consecutive elements, only 8 threads need to participate,
 * thus splitting the warp into 4x smaller subwarps 8-thread width.
 * A shfl-down tree reduction is used inside each subwarp, followed by a
 * broadcast from the subwarp's lane 0. subwarp_width is expected to be a
 * power of two (the halving loop below assumes it).
 */
template <int subwarp_width>
__forceinline__ __device__ float subwarp_reduce_max_broadcast(const float val) {
  float val_tmp = val;
#pragma unroll
  for (int offset = subwarp_width / 2; offset > 0; offset /= 2) {
#ifdef __HIP_PLATFORM_AMD__
    const float val_other = __shfl_down(val_tmp, offset, subwarp_width);
#else
    const float val_other = __shfl_down_sync(0xFFFFFFFF, val_tmp, offset, subwarp_width);
#endif
    // Inputs are amax values, i.e. non-negative.
    __builtin_assume(val_tmp >= 0);
    __builtin_assume(val_other >= 0);
    val_tmp = fmaxf(val_tmp, val_other);
  }
  // Broadcast the amax to other threads of the subwarp from the zero subwarp lane_id
  constexpr int subwarp_lane_zero = 0;
#ifdef __HIP_PLATFORM_AMD__
  val_tmp = __shfl(val_tmp, subwarp_lane_zero, subwarp_width);
#else
  val_tmp = __shfl_sync(0xFFFFFFFF, val_tmp, subwarp_lane_zero, subwarp_width);
#endif
  return val_tmp;
}

Przemek Tredak's avatar
Przemek Tredak committed
968
// Works only on positive values:
// for non-negative floats the IEEE-754 bit pattern, reinterpreted as a
// signed int, is monotonically increasing, so an integer atomicMax
// implements a float max. Negative inputs would compare incorrectly.
__device__ __forceinline__ void atomicMaxFloat(float *addr, const float value) {
  atomicMax(reinterpret_cast<int *>(addr), __float_as_int(value));
}

// Works only on positive values:
// same bit trick as atomicMaxFloat -- for non-negative floats the int
// ordering of the bit patterns matches the float ordering, so an integer
// atomicMin implements a float min. Negative inputs would compare incorrectly.
__device__ __forceinline__ void atomicMinFloat(float *addr, const float value) {
  atomicMin(reinterpret_cast<int *>(addr), __float_as_int(value));
}

// Generic reciprocal: stores 1 / value into *value_inv.
// Division-by-zero behavior follows the semantics of T (inf for IEEE floats).
template <typename T>
__device__ __forceinline__ void reciprocal(T *value_inv, const T value) {
  *value_inv = 1 / value;
}

983
984
985
986
987
// float specialization: uses the round-to-nearest-even hardware reciprocal
// intrinsic instead of a full division.
template <>
__device__ __forceinline__ void reciprocal<float>(float *value_inv, const float value) {
  *value_inv = __frcp_rn(value);
}

988
989
990
991
992
////////////////////////////////////////////////////////////////////////////////////////////////////

// Short aliases for the FP8 storage types and related quantization types.
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
// E8M0 (exponent-only) scale factor storage, held as a raw byte.
using e8m0_t = uint8_t;
using int8 = int8_t;

// Which dimension(s) scaling factors are produced along.
enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENSIONAL = 2 };
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011

// Per-type numeric limits used when computing quantization scale factors.
// Only explicitly specialized types are supported (no primary definition).
template <typename T>
struct Numeric_Traits;

template <>
struct Numeric_Traits<fp8e4m3> {
  static constexpr int maxUnbiasedExponent = 8;
  static constexpr double maxNorm = 448;  // largest finite E4M3 value
};

template <>
struct Numeric_Traits<fp8e5m2> {
  static constexpr int maxUnbiasedExponent = 15;
  static constexpr double maxNorm = 57344;  // largest finite E5M2 value
};

yuguo's avatar
yuguo committed
1012
1013
1014
1015
1016
1017
template <>
struct Numeric_Traits<int8> {
  static constexpr int maxUnbiasedExponent = 0;
  static constexpr double maxNorm = 127;  // INT8 maximum magnitude
};

1018
1019
1020
1021
1022
1023
1024
1025
1026
// Convenience wrapper exposing Numeric_Traits as floats, together with
// precomputed reciprocals used when deriving quantization scale factors.
template <typename T>
struct Quantized_Limits {
  static constexpr int max_unbiased_exponent = Numeric_Traits<T>::maxUnbiasedExponent;
  static constexpr float max_norm = Numeric_Traits<T>::maxNorm;
  static constexpr float max_norm_rcp = 1.0 / max_norm;
  // emax = 2^max_unbiased_exponent
  static constexpr float emax = 1 << max_unbiased_exponent;
  static constexpr float emax_rcp = 1.0 / emax;
};

Przemek Tredak's avatar
Przemek Tredak committed
1027
1028
1029
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_UTILS_CUH_