utils.cuh 29.3 KB
Newer Older
Przemek Tredak's avatar
Przemek Tredak committed
1
/*************************************************************************
2
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
3
4
5
6
7
8
9
10
11
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_UTILS_CUH_
#define TRANSFORMER_ENGINE_COMMON_UTILS_CUH_

#include <cuda_bf16.h>
#include <cuda_fp16.h>
Tim Moon's avatar
Tim Moon committed
12
#include <cuda_fp8.h>
Przemek Tredak's avatar
Przemek Tredak committed
13

14
15
16
17
#if CUDA_VERSION >= 12080
#include <cuda_fp4.h>
#endif

Tim Moon's avatar
Tim Moon committed
18
#if !defined(__CUDACC_RTC__)
19
#include <cassert>
Tim Moon's avatar
Tim Moon committed
20
21
22
23
24
25
26
27
28
29
30
31
#include <cstdint>
#else
// Importing C++ standard headers is a pain with NVRTC
using uint8_t = unsigned char;
using uint16_t = unsigned short int;  // NOLINT(*)
using uint32_t = unsigned int;
using uint64_t = unsigned long long int;  // NOLINT(*)
static_assert(sizeof(uint8_t) == 1);
static_assert(sizeof(uint16_t) == 2);
static_assert(sizeof(uint32_t) == 4);
static_assert(sizeof(uint64_t) == 8);
#endif
Przemek Tredak's avatar
Przemek Tredak committed
32
33
34
35
36
37
38

////////////////////////////////////////////////////////////////////////////////////////////////////

// Number of threads (lanes) in one warp. Used as the lane count for the
// warp-level reductions and statistics further down in this header.
constexpr uint32_t THREADS_PER_WARP = 32;

////////////////////////////////////////////////////////////////////////////////////////////////////

39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
// Device-side error.
//
// Prints the source location (file, line, enclosing function), the calling
// thread's threadIdx/blockIdx coordinates and `message` through device-side
// printf, then triggers assert(0). `message` is consumed by a "%s" conversion,
// so it must be a C string. Every thread that reaches the macro reports
// independently. When NDEBUG is defined, assert expands to nothing and only
// the message is printed.
#define NVTE_DEVICE_ERROR(message)                                                                 \
  do {                                                                                             \
    printf("%s:%d in function %s (thread (%d,%d,%d), block (%d,%d,%d)): %s\n", __FILE__, __LINE__, \
           __func__, threadIdx.x, threadIdx.y, threadIdx.z, blockIdx.x, blockIdx.y, blockIdx.z,    \
           (message));                                                                             \
    assert(0);                                                                                     \
  } while (false)

// Device-side error on thread 0.
//
// Same report as NVTE_DEVICE_ERROR, but raised only by the single thread whose
// threadIdx is (0,0,0) inside the block whose blockIdx is (0,0,0). All other
// threads fall through without printing or asserting, which avoids one message
// per thread when every thread detects the same condition.
#define NVTE_DEVICE_THREAD0_ERROR(message)                                           \
  do {                                                                               \
    if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0 && \
        threadIdx.y == 0 && threadIdx.z == 0) {                                      \
      NVTE_DEVICE_ERROR(message);                                                    \
    }                                                                                \
  } while (false)

////////////////////////////////////////////////////////////////////////////////////////////////////

59
60
// Component-wise sum of two float2 values: {a.x + b.x, a.y + b.y}.
inline __device__ float2 operator+(const float2 &a, const float2 &b) {  // NOLINT(*)
  return make_float2(a.x + b.x, a.y + b.y);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

65
66
67
// Component-wise in-place accumulation of b into a. Returns void, so the
// expression cannot be chained.
inline __device__ void operator+=(float2 &a, const float2 &b) {  // NOLINT(*)
  a.x = a.x + b.x;
  a.y = a.y + b.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

72
// Stateless binary functor that adds its two arguments with T's operator+.
template <typename T>
struct Sum {
  inline __device__ Sum() {}

  // Returns lhs + rhs.
  inline __device__ T operator()(const T &lhs, const T &rhs) const {
    return lhs + rhs;
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

80
81
82
// Warp butterfly exchange.
//
// Returns the value of `x` held by the lane obtained by XOR-ing the calling
// lane's index with `idx`. The shuffle is issued with a full 32-lane
// participation mask, so every lane of the warp must execute the call.
template <typename T>
inline __device__ T warp_shuffle_xor(const T &x, uint32_t idx) {
  constexpr uint32_t kFullWarpMask = 0xffffffffu;
  return __shfl_xor_sync(kFullWarpMask, x, idx);
}

// float2 overload: the two components are exchanged one after the other
// (x first, then y) using the scalar version.
template <>
inline __device__ float2 warp_shuffle_xor<float2>(const float2 &x, uint32_t idx) {
  const float first = warp_shuffle_xor(x.x, idx);
  const float second = warp_shuffle_xor(x.y, idx);
  return make_float2(first, second);
}

90
91
92
// Warp shift-down exchange.
//
// Returns the value of `x` held by the lane `idx` positions above the calling
// lane. The shuffle is issued with a full 32-lane participation mask, so every
// lane of the warp must execute the call.
template <typename T>
inline __device__ T warp_shuffle_down(const T &x, uint32_t idx) {
  constexpr uint32_t kFullWarpMask = 0xffffffffu;
  return __shfl_down_sync(kFullWarpMask, x, idx);
}

// float2 overload: the two components are exchanged one after the other
// (x first, then y) using the scalar version.
template <>
inline __device__ float2 warp_shuffle_down<float2>(const float2 &x, uint32_t idx) {
  const float first = warp_shuffle_down(x.x, idx);
  const float second = warp_shuffle_down(x.y, idx);
  return make_float2(first, second);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace transformer_engine {

////////////////////////////////////////////////////////////////////////////////////////////////////

// Plain aggregate of four uint4 words (64 bytes in total). Serves as the
// storage type selected by BytesToType<64>.
struct uint16 {
  uint4 u;  // bytes  0..15
  uint4 v;  // bytes 16..31
  uint4 s;  // bytes 32..47
  uint4 t;  // bytes 48..63
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Plain aggregate of two uint4 words (32 bytes in total). Serves as the
// storage type selected by BytesToType<32>.
struct uint8 {
  uint4 u;  // bytes  0..15
  uint4 v;  // bytes 16..31
};

////////////////////////////////////////////////////////////////////////////////////////////////////

122
// Maps a size in bytes to a type of exactly that size.
//
// BytesToType<B>::Type exists for B in {1, 2, 4, 8, 16, 32, 64}. The primary
// template is intentionally empty, so any other size fails to compile as soon
// as ::Type is named. Each specialization verifies its own size.
template <int BYTES>
struct BytesToType {};

template <>
struct BytesToType<64> {
  typedef uint16 Type;
  static_assert(sizeof(Type) == 64, "BytesToType<64>::Type must occupy 64 bytes");
};

template <>
struct BytesToType<32> {
  typedef uint8 Type;
  static_assert(sizeof(Type) == 32, "BytesToType<32>::Type must occupy 32 bytes");
};

template <>
struct BytesToType<16> {
  typedef uint4 Type;
  static_assert(sizeof(Type) == 16, "BytesToType<16>::Type must occupy 16 bytes");
};

template <>
struct BytesToType<8> {
  typedef uint64_t Type;
  static_assert(sizeof(Type) == 8, "BytesToType<8>::Type must occupy 8 bytes");
};

template <>
struct BytesToType<4> {
  typedef uint32_t Type;
  static_assert(sizeof(Type) == 4, "BytesToType<4>::Type must occupy 4 bytes");
};

template <>
struct BytesToType<2> {
  typedef uint16_t Type;
  static_assert(sizeof(Type) == 2, "BytesToType<2>::Type must occupy 2 bytes");
};

template <>
struct BytesToType<1> {
  typedef uint8_t Type;
  static_assert(sizeof(Type) == 1, "BytesToType<1>::Type must occupy 1 byte");
};

////////////////////////////////////////////////////////////////////////////////////////////////////

169
// Maps a scalar floating-point type to its two-element CUDA vector type.
//
// Defined for float, half and nv_bfloat16. The primary template is empty, so
// asking for ::Type of any other scalar is a compile-time error.
template <typename T>
struct TypeToVec2 {};

template <>
struct TypeToVec2<float> {
  typedef float2 Type;
};

template <>
struct TypeToVec2<half> {
  typedef half2 Type;
};

template <>
struct TypeToVec2<nv_bfloat16> {
  typedef nv_bfloat162 Type;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

189
190
// Bundle of type aliases and raw pointers handed to a kernel as one argument.
//
// NOTE(review): nothing in this header uses the struct, so the meaning of the
// individual buffers below is inferred from their names only (the struct name
// suggests a fused cast-transpose + dbias + dact kernel) -- confirm against
// the kernels that consume it.
//
// Template parameters:
//   IType   element type of `input`
//   IType2  element type of `act_input`
//   OType   element type of both output buffers
//   CType   type of the scaling factors, amax and workspace
template <typename IType, typename IType2, typename OType, typename CType>
struct CTDBiasDActParam {
  typedef IType InputType;
  typedef IType2 InputType2;
  typedef OType OutputType;
  typedef CType ComputeType;

  const IType *input;       // read-only
  const IType2 *act_input;  // read-only
  OType *output_c;          // presumably the non-transposed ("cast") output -- confirm
  OType *output_t;          // presumably the transposed output -- confirm
  const CType *scale_ptr;   // read-only
  CType *amax;
  CType *scale_inv;
  CType *workspace;
  CType *warp_scales_inv;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

208
// Compile-time component selector for CUDA-style vector types.
//
// Get<INDEX>::of<T, R>(vec) reads one member of `vec` and returns it
// implicitly converted to R:
//   INDEX 0 -> vec.x,  1 -> vec.y,  2 -> vec.z,  3 -> vec.w
// `of` is only declared in the primary template; definitions exist for the
// four indices above, so any other INDEX is left undefined.
template <int INDEX>
struct Get {
  template <typename T, typename R>
  static inline __device__ R of(const T &vec);
};

// INDEX 0: the .x component.
template <>
template <typename T, typename R>
inline __device__ R Get<0>::of(const T &vec) {
  return vec.x;
}

// INDEX 1: the .y component.
template <>
template <typename T, typename R>
inline __device__ R Get<1>::of(const T &vec) {
  return vec.y;
}

// INDEX 2: the .z component.
template <>
template <typename T, typename R>
inline __device__ R Get<2>::of(const T &vec) {
  return vec.z;
}

// INDEX 3: the .w component.
template <>
template <typename T, typename R>
inline __device__ R Get<3>::of(const T &vec) {
  return vec.w;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

240
241
242
// Value conversion from Src to Dst.
//
// The generic version performs a plain function-style cast. The float2
// specializations below convert both lanes at once with round-to-nearest
// (the `_rn` intrinsics).
template <typename Src, typename Dst>
struct Converter {
  static inline __device__ Dst convert(const Src &from) {
    return Dst(from);
  }
};

// float2 -> half2.
template <>
struct Converter<float2, half2> {
  static inline __device__ half2 convert(const float2 &x) {
    return __float22half2_rn(x);
  }
};

// float2 -> nv_bfloat162.
template <>
struct Converter<float2, nv_bfloat162> {
  static inline __device__ nv_bfloat162 convert(const float2 &x) {
#if __CUDA_ARCH__ >= 800
    // Paired conversion intrinsic.
    return __float22bfloat162_rn(x);
#else
    // Fallback for other compilation targets: convert each lane separately
    // and reassemble the pair through a union (element 0 comes from x.x,
    // element 1 from x.y).
    union {
      nv_bfloat162 pair;
      nv_bfloat16 lanes[2];
    } pack;
    pack.lanes[0] = __float2bfloat16_rn(x.x);
    pack.lanes[1] = __float2bfloat16_rn(x.y);
    return pack.pair;
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

269
270
271
// Provides the zero value of a type, used as the starting value of the
// reductions in this header.
template <typename T>
struct Zeros {
  // Generic case: T constructed from the float literal 0.
  static inline __device__ T get() {
    return T(0.f);
  }
};

// float2 has no constructor from a single float, so both components are set
// explicitly.
template <>
struct Zeros<float2> {
  static inline __device__ float2 get() {
    float2 zero;
    zero.x = 0.f;
    zero.y = 0.f;
    return zero;
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

281
// Small fixed-size vector of NUM_ELT elements of type Elt_type.
//
// The elements share storage (through a union) with a single value of
// BytesToType<BYTES>::Type, so the whole vector can be moved with one
// assignment of that type (`data.vec`) or accessed element by element
// (`data.elt[i]`). BYTES = NUM_ELT * sizeof(Elt_type) must be one of the sizes
// BytesToType supports, otherwise the struct does not compile.
template <typename Elt_type, uint32_t NUM_ELT>
struct Vec {
  enum { BYTES = NUM_ELT * sizeof(Elt_type) };

  using Vec_type = typename BytesToType<BYTES>::Type;
  using type = Elt_type;

  using Alias_type = union {
    Vec_type vec;
    Elt_type elt[NUM_ELT];
  };

  Alias_type data;

  // Element-wise conversion: writes S(elt[i]) into other.data.elt[i] for every i.
  template <typename S>
  inline __device__ void to(Vec<S, NUM_ELT> &other) {  // NOLINT(*)
#pragma unroll
    for (int i = 0; i < static_cast<int>(NUM_ELT); ++i) {
      other.data.elt[i] = S(this->data.elt[i]);
    }
  }

  // Fills the vector from a generator: elt[i] = op(i), with i passed as int.
  template <typename Op>
  inline __device__ void assign(const Op &op) {
#pragma unroll
    for (int i = 0; i < static_cast<int>(NUM_ELT); ++i) {
      this->data.elt[i] = op(i);
    }
  }

  // Pointer is cast to vector type: reads the idx-th Vec_type starting at
  // base_ptr, so idx counts whole vectors, not elements.
  inline __device__ void load_from(const void *base_ptr, size_t idx = 0) {
    const Vec_type *const src = static_cast<const Vec_type *>(base_ptr);
    this->data.vec = src[idx];
  }

  // Pointer is cast to vector type: writes the idx-th Vec_type starting at
  // base_ptr, so idx counts whole vectors, not elements.
  inline __device__ void store_to(void *base_ptr, size_t idx = 0) const {
    Vec_type *const dst = static_cast<Vec_type *>(base_ptr);
    dst[idx] = this->data.vec;
  }

  // Pointer is cast to element type (idx counts elements). Loads
  // min(count, NUM_ELT) elements and any remaining elements are set to zero.
  // The single vector-wide load is used only when all NUM_ELT elements are
  // requested and the address is a multiple of BYTES; otherwise elements are
  // read one at a time.
  inline __device__ void load_from_elts(const void *base_ptr, size_t idx = 0,
                                        size_t count = NUM_ELT) {
    const Elt_type *elt_ptr = static_cast<const Elt_type *>(base_ptr) + idx;
    const bool whole_vector = !(count < NUM_ELT);
    const bool aligned = (reinterpret_cast<uint64_t>(elt_ptr) % BYTES == 0);
    if (whole_vector && aligned) {
      this->load_from(elt_ptr);
      return;
    }
#pragma unroll
    for (int i = 0; i < static_cast<int>(NUM_ELT); ++i) {
      if (static_cast<size_t>(i) < count) {
        this->data.elt[i] = elt_ptr[i];
      } else {
        this->data.elt[i] = Elt_type(0.f);
      }
    }
  }

  // Pointer is cast to element type (idx counts elements). Stores
  // min(count, NUM_ELT) elements; memory past them is left untouched. The
  // single vector-wide store is used only when all NUM_ELT elements are
  // written and the address is a multiple of BYTES.
  inline __device__ void store_to_elts(void *base_ptr, size_t idx = 0,
                                       size_t count = NUM_ELT) const {
    Elt_type *elt_ptr = static_cast<Elt_type *>(base_ptr) + idx;
    const bool whole_vector = !(count < NUM_ELT);
    const bool aligned = (reinterpret_cast<uint64_t>(elt_ptr) % BYTES == 0);
    if (whole_vector && aligned) {
      this->store_to(elt_ptr);
      return;
    }
#pragma unroll
    for (int i = 0; i < static_cast<int>(NUM_ELT); ++i) {
      if (static_cast<size_t>(i) < count) {
        elt_ptr[i] = this->data.elt[i];
      }
    }
  }

  // Sets every element to Elt_type(0.f).
  inline __device__ void clear() {
#pragma unroll
    for (int i = 0; i < static_cast<int>(NUM_ELT); ++i) {
      this->data.elt[i] = Elt_type(0.f);
    }
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Reusable barrier across the CTAs of one group, built on two integer counters
// in global memory.
//
// Each call to sync() is one of four phases: count b0_ up, count b1_ up, count
// b0_ down, count b1_ down. Alternating between two counters lets a CTA enter
// the next barrier while slower CTAs are still leaving the previous one, and
// counting back down returns both counters to 0 so no reset is needed.
//
// Requirements visible from the code:
//   * both counters must be 0 before first use;
//   * every thread of the CTA must call sync() (it ends in __syncthreads());
//   * all `group_size` CTAs of the group must be running at the same time,
//     because each one spins until the others have arrived.
//
// NOTE(review): the scoped release/acquire PTX used below presumably needs
// sm_70 or newer -- confirm against the supported architectures.
struct InterCTASync {
  // barrier:    base of the counter array; counter `group` and counter
  //             `group + num_groups` belong to this group.
  // group:      index of this group of CTAs.
  // num_groups: number of groups sharing the counter array.
  // group_size: number of CTAs that must arrive before the barrier opens.
  inline __device__ InterCTASync(int *barrier, int group, int num_groups, int group_size)
      : phase_counter_(0),
        b0_(barrier + group),               // First counter of this group of CTAs.
        b1_(barrier + group + num_groups),  // Second counter of this group of CTAs.
        group_size_(group_size) {
    // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!
  }

  // Adds `step` to *barrier with release semantics, then polls it with acquire
  // loads until it reads `expected`.
  //
  // The "memory" clobbers tell the compiler that these statements touch memory
  // it cannot see, so it must not move ordinary loads/stores across them or
  // keep global values cached in registers over the wait. Without them only
  // the hardware ordering is expressed, not the compiler ordering.
  inline __device__ void spin_wait_(int *barrier, int step, int expected) {
    asm volatile("red.release.gpu.global.add.s32 [%0], %1;"
                 :
                 : "l"(barrier), "r"(step)
                 : "memory");
    for (int found = -1; found != expected;) {
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];"
                   : "=r"(found)
                   : "l"(barrier)
                   : "memory");
    }
  }

  // Blocks until every CTA of the group has called sync() for this phase.
  //
  // NOTE(review): only the thread with threadIdx.x == 0 performs the release
  // and there is no CTA-wide barrier before it, so global writes made by other
  // threads just before sync() are not ordered by this function -- confirm
  // that callers publish data from that thread only or synchronize first.
  inline __device__ void sync() {
    // ALL THREADS MUST ENTER!

    // We switch barrier every iteration.
    int *barrier = phase_counter_ & 0x1 ? b1_ : b0_;
    // We decrement every other iteration.
    bool dec = phase_counter_ & 0x2;
    int step = dec ? -1 : 1;
    int expected = dec ? 0 : group_size_;
    // There are only 4 phases: up/down for b0/b1.
    phase_counter_ = (phase_counter_ + 1) & 0x3;

    if (threadIdx.x == 0) {
      spin_wait_(barrier, step, expected);
    }
    // CTA waits for thread 0
    __syncthreads();
  }

  int phase_counter_;  // Current phase in [0, 3]; bit 0 picks the counter, bit 1 the direction.
  int *b0_;
  int *b1_;
  int group_size_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

408
// Reduction across the CTAS_PER_ROW CTAs that cooperate on one row.
//
// Builds on the single-CTA reducer (the base class): each CTA first reduces
// its own data, one thread per warp row publishes the CTA-local result into a
// global workspace, the CTAs meet at an InterCTASync barrier, and finally each
// warp combines the CTAS_PER_ROW partial results.
//
// Two workspace regions (w0_ / w1_) are used alternately, selected by the
// barrier's phase counter, so consecutive calls do not overwrite values other
// CTAs may still be reading.
//
// `Params` must provide `barrier`, `workspace` and `ctas_per_col`.
template <typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
  using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
  using Type = typename Base::Type;

  enum { SMEM_BYTES = Base::SMEM_BYTES };

  enum { WS_BARRIER_BYTES = 2 * sizeof(int) };
  enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) };

  // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total)
  enum {
    WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES
  };

  // bidm selects the group of CTAs (and its barrier), bidn is this CTA's index
  // within the group. The initializers are listed in member declaration order.
  template <typename Params>
  inline __device__ Reducer(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                            uint32_t warp_n, uint32_t lane, void *smem)
      : Base(params, bidm, bidn, warp_m, warp_n, lane, smem),
        inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW),
        w0_(static_cast<T *>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW),
        w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW),
        bidn_(bidn) {}  // CTA id within the group.

  // Returns the reduction of `data` over all threads of all CTAs in the group.
  // Lanes without a partial result contribute Zeros<T>, so `op` is expected to
  // treat that value as neutral.
  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    // CTA-local reduction first.
    data = Base::reduce(data, op);

    // We switch workspace every iteration. Must be chosen before sync(),
    // which advances the phase counter.
    const bool odd_phase = (inter_cta_.phase_counter_ & 0x1) != 0;
    T *const ws = odd_phase ? w1_ : w0_;

    // Warp leaders 0 hold the CTA-local results.
    const bool is_publisher = (this->warp_n_ == 0) && (this->lane_ == 0);
    if (is_publisher) {
      ws[bidn_] = data;
    }
    inter_cta_.sync();

    // One lane per CTA picks up a partial result, so the group must fit in a warp.
    static_assert(CTAS_PER_ROW <= 32, "CTAS_PER_ROW must not exceed the warp size");
    T partial = (this->lane_ < CTAS_PER_ROW) ? ws[this->lane_] : Zeros<T>::get();

    return Reducer<T, 1, 1, 1>::allreduce_(partial, op);
  }

  InterCTASync inter_cta_;

  T *const w0_;
  T *const w1_;
  int bidn_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

463
// Single-warp reducer (one CTA per row, one warp per row): everything is done
// with warp shuffles, so neither shared memory nor global workspace is needed.
template <typename T, uint32_t WARPS_M>
struct Reducer<T, 1, WARPS_M, 1> {
  using Type = T;
  enum { SMEM_BYTES = 0 };
  enum { WORKSPACE_BYTES_PER_GROUP = 0 };

  enum { THREADS_PER_WARP = 32 };

  // Only warp_n and lane are kept; the other arguments exist so that all
  // Reducer variants can be constructed the same way.
  template <typename Params>
  inline __device__ Reducer(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                            uint32_t warp_n, uint32_t lane, void *smem)
      : warp_n_(warp_n), lane_(lane) {}

  // Butterfly all-reduce: exchanges with XOR partners at distance 1, 2, 4, 8,
  // 16, after which every lane holds the reduction over the whole warp.
  template <typename Op>
  static inline __device__ T allreduce_(T data, const Op &op) {
#pragma unroll
    for (int xor_mask = 1; xor_mask < THREADS_PER_WARP; xor_mask <<= 1) {
      const T partner = warp_shuffle_xor(data, xor_mask);
      data = op(data, partner);
    }
    return data;
  }

  // Member-function form of allreduce_.
  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    return allreduce_(data, op);
  }

  // Tree reduction with shift-down exchanges at distance 16, 8, 4, 2, 1.
  // Only lane 0 holds the result!
  template <typename Op>
  inline __device__ T reduce(T data, const Op &op) {
#pragma unroll
    for (int offset = THREADS_PER_WARP / 2; offset > 0; offset >>= 1) {
      const T upper = warp_shuffle_down(data, offset);
      data = op(data, upper);
    }
    return data;
  }

  int warp_n_;
  int lane_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

505
// Single-CTA reducer for rows handled by WARPS_N warps.
//
// Every warp reduces its own values with the warp-level base class, lane 0 of
// each warp stores the warp result in shared memory, and after a
// __syncthreads() the WARPS_N warp results are combined.
//
// Two shared-memory buffers are used alternately (use0_ toggles on every
// call), so a call can write its warp results while other warps may still be
// reading those of the previous call. Because of the __syncthreads(), every
// thread of the CTA must take part in each call.
template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> {
  using Base = Reducer<T, 1, WARPS_M, 1>;

  using Type = T;

  // Two buffers of WARPS_M * WARPS_N values each.
  enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 };
  enum { WORKSPACE_BYTES_PER_GROUP = 0 };

  enum { THREADS_PER_WARP = 32 };

  // `smem` must point to at least SMEM_BYTES bytes; each warp row (warp_m)
  // owns a slice of WARPS_N values in both buffers. The initializers are
  // listed in member declaration order.
  template <typename Params>
  inline __device__ Reducer(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                            uint32_t warp_n, uint32_t lane, void *smem)
      : Base(params, bidm, bidn, warp_m, warp_n, lane, smem),
        smem0_(static_cast<T *>(smem) + warp_m * WARPS_N),
        smem1_(smem0_ + WARPS_M * WARPS_N),
        use0_(true) {}

  // Every thread receives the reduction over the whole row. The combination
  // starts from Zeros<T>, so `op` is expected to treat that value as neutral.
  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    // Pick this call's buffer and flip to the other one for the next call.
    T *const buffer = use0_ ? smem0_ : smem1_;
    use0_ = !use0_;

    data = Base::reduce(data, op);
    if (this->lane_ == 0) {
      buffer[this->warp_n_] = data;
    }
    __syncthreads();

    T acc = Zeros<T>::get();
#pragma unroll
    for (int w = 0; w < static_cast<int>(WARPS_N); ++w) {
      acc = op(acc, buffer[w]);
    }
    return acc;
  }

  // Only the intra-CTA group leader (warp_n == 0, lane == 0) holds the result!
  // All other threads get Zeros<T>.
  template <typename Op>
  inline __device__ T reduce(T data, const Op &op) {
    // Pick this call's buffer and flip to the other one for the next call.
    T *const buffer = use0_ ? smem0_ : smem1_;
    use0_ = !use0_;

    data = Base::reduce(data, op);
    if (this->lane_ == 0) {
      buffer[this->warp_n_] = data;
    }
    __syncthreads();

    T acc = Zeros<T>::get();
    const bool is_leader = (this->warp_n_ == 0) && (this->lane_ == 0);
    if (is_leader) {
#pragma unroll
      for (int w = 0; w < static_cast<int>(WARPS_N); ++w) {
        acc = op(acc, buffer[w]);
      }
    }
    return acc;
  }

  T *const smem0_;
  T *const smem1_;
  bool use0_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

568
// Multi-CTA reducer whose group size (CTAs per row) is a run-time value taken
// from `params.ctas_per_row` instead of a template parameter.
//
// Works like the compile-time multi-CTA Reducer: CTA-local reduction, publish
// the partial result to a global workspace, inter-CTA barrier, then combine
// the partial results. Because the group may be larger than a warp, each lane
// first folds the partial results at indices lane, lane + 32, lane + 64, ...
// before the final warp all-reduce.
//
// `Params` must provide `barrier`, `workspace`, `ctas_per_col` and
// `ctas_per_row`.
template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct DynamicReducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
  using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
  using Type = typename Base::Type;

  // bidm selects the group of CTAs (and its barrier), bidn is this CTA's index
  // within the group. The initializers are listed in member declaration order.
  template <typename Params>
  inline __device__ DynamicReducer(const Params &params, uint32_t bidm, uint32_t bidn,
                                   uint32_t warp_m, uint32_t warp_n, uint32_t lane, void *smem)
      : Base(params, bidm, bidn, warp_m, warp_n, lane, smem),
        inter_cta_(params.barrier, bidm, params.ctas_per_col, params.ctas_per_row),
        w0_(static_cast<T *>(params.workspace) + (bidm * WARPS_M + warp_m) * params.ctas_per_row),
        w1_(w0_ + params.ctas_per_col * WARPS_M * params.ctas_per_row),
        bidn_(bidn) {}  // CTA id within the group.

  // Returns the reduction of `data` over all threads of all CTAs in the group.
  // The accumulation starts from Zeros<T>, so `op` is expected to treat that
  // value as neutral.
  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    // Trivial case: a single CTA per row needs no workspace and no barrier.
    if (inter_cta_.group_size_ == 1) {
      return Base::allreduce(data, op);
    }

    data = Base::reduce(data, op);

    // We switch workspace every iteration. Must be chosen before sync(),
    // which advances the phase counter.
    const bool odd_phase = (inter_cta_.phase_counter_ & 0x1) != 0;
    T *const ws = odd_phase ? w1_ : w0_;

    // Warp leaders 0 hold the CTA-local results.
    const bool is_publisher = (this->warp_n_ == 0) && (this->lane_ == 0);
    if (is_publisher) {
      ws[bidn_] = data;
    }
    inter_cta_.sync();

    // Each lane folds its strided share of the partial results.
    T partial = Zeros<T>::get();
    for (int cta = this->lane_; cta < inter_cta_.group_size_; cta += THREADS_PER_WARP) {
      partial = op(partial, ws[cta]);
    }

    return Reducer<T, 1, 1, 1>::allreduce_(partial, op);
  }

  // Same as allreduce: every thread receives the result.
  template <typename Op>
  inline __device__ T reduce(T data, const Op &op) {
    return allreduce(data, op);
  }

  InterCTASync inter_cta_;

  T *const w0_;
  T *const w1_;
  int bidn_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

622
623
624
625
626
627
628
629
630
631
632
633
634
/*
This is an implementation of the parallel Welford algorithm for incrementally computing variance

This algorithm is known as Chan's update formulae (Chan et al. '79):
http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf

An introduction is provided by Wikipedia here:
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance?section=5#Parallel_algorithm

A detailed reference on the exact version implemented (with better numerical stability) is provided here:
https://dbs.ifi.uni-heidelberg.de/files/Team/eschubert/publications/SSDBM18-covariance-authorcopy.pdf
*/

635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
// Merges per-lane running statistics across a warp with Chan's parallel
// update and broadcasts the merged mean and M2 to every lane.
//
// Parameters (all per lane, updated in place):
//   m_a         mean of the lane's samples
//   m2_a        sum of squared deviations from that mean
//   n_a         number of samples; lanes that carry no data must pass 0
//   num_active  number of leading lanes that carry data (lane 0 is assumed to
//               be one of them)
//
// On return m_a and m2_a hold the merged values on every lane. n_a is NOT
// broadcast: only lane 0 ends up with the total count.
//
// All 32 lanes of the warp must call this function, since the shuffles use a
// full-warp mask.
template <typename T>
inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a,
                                             int num_active) {  // NOLINT(*)
  // Assume at least leftmost is valid and
  // init: step = next_pow2(num_active) / 2 (might get NaN otherwise, because
  // merging two empty lanes divides by n_a + n_b == 0).
  // With a single active lane there is nothing to merge; computing the step
  // the generic way would shift by -1, which is undefined behavior.
  int first_step = 0;
  if (num_active > 1) {
    const int highest_bit_set =
        static_cast<int>(8 * sizeof(num_active)) - __clz(num_active - 1);
    first_step = 1 << (highest_bit_set - 1);
  }

#pragma unroll
  for (int step = first_step; step > 0; step /= 2) {
    // Exchange
    T n_b = warp_shuffle_down(n_a, step);
    T m_b = warp_shuffle_down(m_a, step);
    T m2_b = warp_shuffle_down(m2_a, step);

    // Update
    const T n_ab = n_a + n_b;  // We can handle one of them being 0, not both.
    // Might have different n per thread, otherwise this would simplify :(
    const T rn_ab = 1.f / n_ab;
    const T delta = m_a - m_b;
    // Intermediates are kept in T (not float) so a wider T is not truncated.
    const T m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;
    const T m_ab = (n_a * m_a + n_b * m_b) * rn_ab;

    n_a = n_ab;
    m_a = m_ab;
    m2_a = m2_ab;
  }
  // Intra-warp broadcast (only lane 0 has valid stats).
  m_a = __shfl_sync(static_cast<uint32_t>(-1), m_a, 0);
  m2_a = __shfl_sync(static_cast<uint32_t>(-1), m2_a, 0);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

668
// Row statistics (mean and M2, the sum of squared deviations from the mean)
// for rows that are split across CTAS_PER_ROW CTAs.
//
// Each CTA computes the statistics of its part of the row with the single-CTA
// Stats, publishes them to a global workspace, waits at an inter-CTA barrier
// and then merges the CTAS_PER_ROW partial results with Chan's update.
//
// This could be done generically with the Reducer. But then we
// would have to exchange 3 instead of 2 fields.
//
// `Params` must provide `barrier`, `workspace` and `ctas_per_col`.
template <typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats {
  using BlockStats = Stats<T, 1, WARPS_M, WARPS_N>;
  using stats_t = typename BlockStats::stats_t;

  enum { SMEM_BYTES = BlockStats::SMEM_BYTES };

  // bidm selects the group of CTAs (and its barrier), bidn is this CTA's index
  // within the group. Two workspace regions (w0_ / w1_) are set up and used
  // alternately. The initializers are listed in member declaration order.
  template <typename Params>
  inline __device__ Stats(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                          uint32_t warp_n, uint32_t lane, void *smem)
      : inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW),
        block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem),
        w0_(static_cast<stats_t *>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW),
        w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW),
        bidn_(bidn),  // CTA id within the group.
        warp_n_(warp_n),
        lane_(lane) {}

  // Computes the statistics of the full row from the N values each thread
  // holds in `elts`. Returns {mean, M2}, identical on every thread.
  // `rn` is accepted for interface compatibility but not used.
  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    // Every CTA contributes the same number of elements.
    constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP;
    // TODO(ptredak) rn is not really needed here..
    constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA);
    stats_t block_stats = block_stats_.compute(elts, block_rn);

    // Alternate between the two workspace regions. Must be chosen before
    // sync(), which advances the phase counter.
    const bool odd_phase = (inter_cta_.phase_counter_ & 0x1) != 0;
    stats_t *const ws = odd_phase ? w1_ : w0_;

    const bool is_publisher = (warp_n_ == 0) && (lane_ == 0);
    if (is_publisher) {
      ws[bidn_] = block_stats;
    }

    // Wait for all CTAS_PER_ROW CTAS in the group to have written their result.
    inter_cta_.sync();

    // Lanes without a partial result keep count == 0, which the merge below
    // relies on.
    T count = Zeros<T>::get();
    T mean = Zeros<T>::get();
    T m2 = Zeros<T>::get();

    // Assume CTA group size in N less than 32, such that we can finalize with a single warp.
    static_assert(CTAS_PER_ROW <= 32, "CTAS_PER_ROW must not exceed the warp size");

    // Every warp does the final reduction locally: lane i picks up the result
    // of CTA i.
    if (lane_ < CTAS_PER_ROW) {
      stats_t partial = ws[lane_];
      count = ELTS_PER_ROW_PER_CTA;
      mean = transformer_engine::Get<0>::of<stats_t, T>(partial);
      m2 = transformer_engine::Get<1>::of<stats_t, T>(partial);
    }

    warp_chan_upd_dynamic(mean, m2, count, CTAS_PER_ROW);

    return {mean, m2};
  }

  InterCTASync inter_cta_;
  BlockStats block_stats_;

  stats_t *const w0_;
  stats_t *const w1_;
  int bidn_;
  int warp_n_;
  int lane_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

738
/**
 * Statistics for a row handled by a single CTA whose WARPS_N warps share the row.
 *
 * Each warp computes its own statistics through WarpStats, the warp leaders stage them in
 * shared memory, and after a block barrier every warp merges the WARPS_N partial results with
 * warp_chan_upd_dynamic.
 *
 * Shared memory holds two staging buffers of WARPS_M * WARPS_N entries each; compute()
 * alternates between them on every call (presumably so that back-to-back calls need no
 * additional barrier - confirm before changing).
 */
template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats<T, 1, WARPS_M, WARPS_N> {
  using WarpStats = Stats<T, 1, WARPS_M, 1>;
  using stats_t = typename WarpStats::stats_t;

  // Two staging buffers, one stats_t per warp of the CTA in each.
  enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 };

  // smem0_ points at the staging row owned by warp_m in the first buffer, smem1_ at the same
  // row in the second buffer. Members are initialized in declaration order.
  template <typename Params>
  inline __device__ Stats(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                          uint32_t warp_n, uint32_t lane, void *smem)
      : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem),
        smem0_(static_cast<stats_t *>(smem) + warp_m * WARPS_N),
        smem1_(smem0_ + WARPS_M * WARPS_N),
        use0_(true) {}

  /**
   * Computes the row statistics from the per-thread elements.
   *
   * elts: per-thread elements of the row.
   * rn:   unused; the per-warp normalization factor is derived from N instead.
   * Returns {m, m2} as produced by the final merge.
   */
  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    // Select the staging buffer for this call and flip the selector for the next one.
    stats_t *const staging = use0_ ? smem0_ : smem1_;
    use0_ = !use0_;

    // Compute warp local for all WARPS_N
    constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP);
    const stats_t per_warp = warp_stats_.compute(elts, warp_rn);

    // Each warp leader stores the statistics of its warp.
    const auto warp_n = warp_stats_.reducer_.warp_n_;
    const auto lane = warp_stats_.reducer_.lane_;
    if (lane == 0) {
      staging[warp_n] = per_warp;
    }
    __syncthreads();

    // The merge is done by a single warp, so at most 32 warps are supported.
    static_assert(WARPS_N <= 32);

    T count = Zeros<T>::get();
    T mean = Zeros<T>::get();
    T m2 = Zeros<T>::get();

    // Lane i picks up the partial result of warp i; the remaining lanes contribute zeros.
    if (lane < WARPS_N) {
      const stats_t partial = staging[lane];
      count = N * THREADS_PER_WARP;
      mean = transformer_engine::Get<0>::of<stats_t, T>(partial);
      m2 = transformer_engine::Get<1>::of<stats_t, T>(partial);
    }

    warp_chan_upd_dynamic(mean, m2, count, WARPS_N);

    return {mean, m2};
  }

  WarpStats warp_stats_;
  stats_t *smem0_;
  stats_t *smem1_;
  bool use0_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

794
/**
 * Statistics for a row that is handled by a single warp.
 *
 * Needs no shared memory of its own (SMEM_BYTES == 0); all cross-thread communication goes
 * through the Reducer.
 */
template <typename T, uint32_t WARPS_M>
struct Stats<T, 1, WARPS_M, 1> {
  using stats_t = typename TypeToVec2<T>::Type;
  // The simple Warp reducer.
  using Reducer = Reducer<T, 1, WARPS_M, 1>;

  enum { SMEM_BYTES = 0 };

  // All arguments are forwarded unchanged to the reducer.
  template <typename Params>
  inline __device__ Stats(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                          uint32_t warp_n, uint32_t lane, void *smem)
      : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) {}

  /**
   * Two-pass statistics over the elements held by the reducing threads.
   *
   * elts: per-thread elements.
   * rn:   factor applied to the all-reduced sum to obtain the mean (the multi-warp Stats
   *       passes the reciprocal of the element count).
   * Returns {mean, m2}, where m2 is the all-reduced sum of squared deviations from the mean;
   * m2 is not scaled by rn.
   */
  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    auto add = Sum<T>();

    // Pass 1: thread-local sum, all-reduced and scaled into the mean.
    T mean = Zeros<T>::get();
#pragma unroll
    for (uint32_t i = 0; i < N; ++i) {
      mean += elts[i];
    }
    mean = reducer_.allreduce(mean, add) * rn;

    // Pass 2: thread-local sum of squared deviations from that mean, all-reduced.
    T sq_dev = Zeros<T>::get();
#pragma unroll
    for (uint32_t i = 0; i < N; ++i) {
      const T delta = elts[i] - mean;
      sq_dev += delta * delta;
    }
    sq_dev = reducer_.allreduce(sq_dev, add);

    return {mean, sq_dev};
  }

  Reducer reducer_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

/**
 * Maximum over the first num_elems lanes of a warp, using a shuffle-down tree.
 *
 * The complete result is only guaranteed in lane 0. Inputs are assumed to be non-negative
 * (the compiler is told so via __builtin_assume). The shuffles use the full-warp mask, so all
 * 32 lanes of the warp must call this function together.
 */
template <int num_elems>
__device__ __forceinline__ float warp_reduce_max(const float m) {
  float running_max = m;
#pragma unroll
  for (int step = num_elems / 2; step > 0; step /= 2) {
    // Fetch the running maximum of the lane `step` positions above this one.
    const float peer_max = __shfl_down_sync(0xFFFFFFFF, running_max, step);
    __builtin_assume(running_max >= 0);
    __builtin_assume(peer_max >= 0);
    running_max = fmaxf(running_max, peer_max);
  }
  return running_max;
}

847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
/**
 * Maximum over all lanes of a warp, returned in every lane.
 *
 * A shuffle-down tree leaves the maximum in lane 0, which is then broadcast to the whole warp.
 * Inputs are assumed to be non-negative, and all 32 lanes must participate.
 */
__forceinline__ __device__ float warp_reduce_max_broadcast(const float val) {
  float running_max = val;
#pragma unroll
  for (int step = THREADS_PER_WARP / 2; step > 0; step /= 2) {
    const float peer_max = __shfl_down_sync(0xFFFFFFFF, running_max, step);
    __builtin_assume(running_max >= 0);
    __builtin_assume(peer_max >= 0);
    running_max = fmaxf(running_max, peer_max);
  }
  // Lane 0 now holds the warp-wide amax; hand it to every other lane of the warp.
  constexpr int source_lane = 0;
  return __shfl_sync(0xFFFFFFFF, running_max, source_lane);
}

Przemek Tredak's avatar
Przemek Tredak committed
862
863
/**
 * Block-wide maximum of non-negative values.
 *
 * Stage 1: every warp reduces its 32 values with warp_reduce_max and lane 0 of each warp
 *          stores the warp maximum into shared memory.
 * Stage 2: after a block barrier, warp 0 reduces the num_warps staged values.
 *
 * num_warps: number of warps in the block (at most 32, since stage 2 uses a single warp).
 * m:         value contributed by the calling thread; must be non-negative.
 * warpid:    index of the calling thread's warp, in [0, num_warps).
 *
 * Returns the block maximum in thread 0 only. Threads outside warp 0 return 0 and the other
 * lanes of warp 0 return partial maxima.
 *
 * Must be called by all threads of the block (it contains __syncthreads()); lanes are derived
 * from threadIdx.x.
 * NOTE(review): the staging buffer is not protected by a trailing barrier, so two back-to-back
 * calls may need a __syncthreads() in between - confirm against callers.
 */
template <int num_warps, typename compute_t>
__device__ __forceinline__ compute_t reduce_max(const compute_t m, const int warpid) {
  constexpr int warp_size = 32;
  static_assert(num_warps > 0 && num_warps <= warp_size);

  // warp_reduce_max halves its width on every step, so it only visits every lane when the
  // width is a power of two. For any other warp count reduce over the full warp instead; the
  // lanes beyond num_warps contribute 0, which is neutral for non-negative inputs.
  constexpr bool is_pow2 = (num_warps & (num_warps - 1)) == 0;
  constexpr int final_width = is_pow2 ? num_warps : warp_size;

  __shared__ float staging[num_warps];
  const float my_max = m;
  const float my_warp_max = warp_reduce_max<warp_size>(my_max);
  // Lane 0 of each warp holds the warp maximum and publishes it.
  if (threadIdx.x % warp_size == 0) {
    staging[warpid] = my_warp_max;
  }
  __syncthreads();

  compute_t result = 0.f;
  if (warpid == 0) {
    const float staged_max = threadIdx.x < num_warps ? staging[threadIdx.x] : 0;
    result = warp_reduce_max<final_width>(staged_max);
  }
  return result;
}

880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
/**
 * Max reduction in subwarps
 * E.g., if nvec=4, each warp processes 128 elements (32 x 4), which covers four MXFP8 scaling factors.
 * To compute an actual scaling factor for 32 consecutive elements, only 8 threads need to participate,
 * so the warp is split into four smaller subwarps, each 8 threads wide.
 * Inside every subwarp a shuffle-down tree leaves the maximum in the subwarp's lane 0, and that
 * value is then broadcast to the rest of the subwarp.
 * Inputs are assumed to be non-negative, and all 32 lanes of the warp must participate.
 */
template <int subwarp_width>
__forceinline__ __device__ float subwarp_reduce_max_broadcast(const float val) {
  float running_max = val;
#pragma unroll
  for (int step = subwarp_width / 2; step > 0; step /= 2) {
    // The width argument keeps the exchange inside the caller's own subwarp.
    const float peer_max = __shfl_down_sync(0xFFFFFFFF, running_max, step, subwarp_width);
    __builtin_assume(running_max >= 0);
    __builtin_assume(peer_max >= 0);
    running_max = fmaxf(running_max, peer_max);
  }
  // Lane 0 of each subwarp now holds its amax; hand it to the other lanes of that subwarp.
  constexpr int source_lane = 0;
  return __shfl_sync(0xFFFFFFFF, running_max, source_lane, subwarp_width);
}

Przemek Tredak's avatar
Przemek Tredak committed
903
// Atomic maximum on a float, done as an integer atomicMax on the raw bit patterns.
// Works only on positive values
__device__ __forceinline__ void atomicMaxFloat(float *addr, const float value) {
  int *const addr_as_int = reinterpret_cast<int *>(addr);
  const int value_bits = __float_as_int(value);
  atomicMax(addr_as_int, value_bits);
}

// Atomic minimum on a float, done as an integer atomicMin on the raw bit patterns.
// Works only on positive values
__device__ __forceinline__ void atomicMinFloat(float *addr, const float value) {
  int *const addr_as_int = reinterpret_cast<int *>(addr);
  const int value_bits = __float_as_int(value);
  atomicMin(addr_as_int, value_bits);
}

// Generic reciprocal: stores 1 / value into the location value_inv points to, using T's own
// division.
template <typename T>
__device__ __forceinline__ void reciprocal(T *value_inv, const T value) {
  value_inv[0] = 1 / value;
}

918
919
920
921
922
// float specialization: uses the __frcp_rn intrinsic instead of a generic division.
template <>
__device__ __forceinline__ void reciprocal<float>(float *value_inv, const float value) {
  const float inverse = __frcp_rn(value);
  *value_inv = inverse;
}

923
924
925
926
927
928
////////////////////////////////////////////////////////////////////////////////////////////////////

// Short names for the two CUDA 8-bit floating point types.
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
// Stored as a raw byte; the name suggests an E8M0 (exponent-only) encoding - TODO confirm
// against the code that produces and consumes it.
using e8m0_t = uint8_t;

929
// Scaling layout selector. NOTE(review): presumably chooses whether scaling factors are
// computed along rows, along columns, or along both - confirm against the kernels using it.
enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENSIONAL = 2 };
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954

// Per-type numeric constants. Only declared here; each supported type provides a
// specialization with:
//   maxUnbiasedExponent - largest unbiased exponent of the format,
//   maxNorm             - largest finite magnitude of the format.
template <typename T>
struct Numeric_Traits;

// fp8e4m3: 448 == 1.75 * 2^8.
template <>
struct Numeric_Traits<fp8e4m3> {
  static constexpr int maxUnbiasedExponent = 8;
  static constexpr double maxNorm = 448;
};

// fp8e5m2: 57344 == 1.75 * 2^15.
template <>
struct Numeric_Traits<fp8e5m2> {
  static constexpr int maxUnbiasedExponent = 15;
  static constexpr double maxNorm = 57344;
};

// Single-precision limits of the quantized type T, derived from Numeric_Traits<T>:
//   max_norm / max_norm_rcp - largest magnitude of T and its reciprocal,
//   emax / emax_rcp         - 2^max_unbiased_exponent and its reciprocal.
// The reciprocals are evaluated in double precision and then narrowed to float.
template <typename T>
struct Quantized_Limits {
  static constexpr int max_unbiased_exponent = Numeric_Traits<T>::maxUnbiasedExponent;
  static constexpr float max_norm = static_cast<float>(Numeric_Traits<T>::maxNorm);
  static constexpr float max_norm_rcp = static_cast<float>(1.0 / max_norm);
  static constexpr float emax = static_cast<float>(1 << max_unbiased_exponent);
  static constexpr float emax_rcp = static_cast<float>(1.0 / emax);
};

Przemek Tredak's avatar
Przemek Tredak committed
955
956
957
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_UTILS_CUH_