utils.cuh 29.1 KB
Newer Older
Przemek Tredak's avatar
Przemek Tredak committed
1
/*************************************************************************
2
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
3
4
5
6
7
8
9
10
11
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_UTILS_CUH_
#define TRANSFORMER_ENGINE_COMMON_UTILS_CUH_

#include <cuda_bf16.h>
#include <cuda_fp16.h>
Tim Moon's avatar
Tim Moon committed
12
#include <cuda_fp8.h>
Przemek Tredak's avatar
Przemek Tredak committed
13

Tim Moon's avatar
Tim Moon committed
14
15
16
17
18
19
20
21
22
23
24
25
26
#if !defined(__CUDACC_RTC__)
#include <cstdint>
#else
// Importing C++ standard headers is a pain with NVRTC
using uint8_t = unsigned char;
using uint16_t = unsigned short int;  // NOLINT(*)
using uint32_t = unsigned int;
using uint64_t = unsigned long long int;  // NOLINT(*)
static_assert(sizeof(uint8_t) == 1);
static_assert(sizeof(uint16_t) == 2);
static_assert(sizeof(uint32_t) == 4);
static_assert(sizeof(uint64_t) == 8);
#endif
Przemek Tredak's avatar
Przemek Tredak committed
27
28
29
30
31
32
33

////////////////////////////////////////////////////////////////////////////////////////////////////

// Warp width in threads. Used below as a stride / element-count factor in warp-level loops.
constexpr uint32_t THREADS_PER_WARP = 32;

////////////////////////////////////////////////////////////////////////////////////////////////////

34
35
// Component-wise sum of two float2 values.
inline __device__ float2 operator+(const float2 &a, const float2 &b) {  // NOLINT(*)
  float2 sum;
  sum.x = a.x + b.x;
  sum.y = a.y + b.y;
  return sum;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

40
41
42
// Component-wise in-place accumulation: a.x += b.x and a.y += b.y.
inline __device__ void operator+=(float2 &a, const float2 &b) {  // NOLINT(*)
  const float new_x = a.x + b.x;
  const float new_y = a.y + b.y;
  a.x = new_x;
  a.y = new_y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

47
// Binary functor returning a + b; passed as the `op` argument of the reducers below.
template <typename T>
struct Sum {
  inline __device__ Sum() {}

  inline __device__ T operator()(const T &a, const T &b) const {
    T result = a + b;
    return result;
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

55
56
57
// Butterfly exchange: returns the value of x held by the lane whose id is (own lane id XOR idx).
// The participation mask has all 32 bits set, so every lane of the warp must call this.
template <typename T>
inline __device__ T warp_shuffle_xor(const T &x, uint32_t idx) {
  const uint32_t all_lanes = static_cast<uint32_t>(-1);
  return __shfl_xor_sync(all_lanes, x, idx);
}

// float2 overload: shuffles the .x and .y components independently.
template <>
inline __device__ float2 warp_shuffle_xor<float2>(const float2 &x, uint32_t idx) {
  float2 exchanged;
  exchanged.x = warp_shuffle_xor(x.x, idx);
  exchanged.y = warp_shuffle_xor(x.y, idx);
  return exchanged;
}

65
66
67
// Returns the value of x held by the lane idx positions above the caller (via __shfl_down_sync).
// The participation mask has all 32 bits set, so every lane of the warp must call this.
template <typename T>
inline __device__ T warp_shuffle_down(const T &x, uint32_t idx) {
  const uint32_t all_lanes = static_cast<uint32_t>(-1);
  return __shfl_down_sync(all_lanes, x, idx);
}

// float2 overload: shuffles the .x and .y components independently.
template <>
inline __device__ float2 warp_shuffle_down<float2>(const float2 &x, uint32_t idx) {
  float2 exchanged;
  exchanged.x = warp_shuffle_down(x.x, idx);
  exchanged.y = warp_shuffle_down(x.y, idx);
  return exchanged;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace transformer_engine {

////////////////////////////////////////////////////////////////////////////////////////////////////

// Four uint4 members back to back; BytesToType<64> uses this as its 64-byte carrier type.
struct uint16 {
  uint4 u, v, s, t;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Two uint4 members back to back; BytesToType<32> uses this as its 32-byte carrier type.
struct uint8 {
  uint4 u, v;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

97
// Maps a byte count to a type of exactly that size (exposed as ::Type).
// Only the sizes specialized below (64, 32, 16, 8, 4, 2, 1) provide ::Type.
template <int BYTES>
struct BytesToType {};

template <>
struct BytesToType<64> {
  using Type = uint16;
  static_assert(sizeof(Type) == 64, "BytesToType<64>::Type must be 64 bytes");
};

template <>
struct BytesToType<32> {
  using Type = uint8;
  static_assert(sizeof(Type) == 32, "BytesToType<32>::Type must be 32 bytes");
};

template <>
struct BytesToType<16> {
  using Type = uint4;
  static_assert(sizeof(Type) == 16, "BytesToType<16>::Type must be 16 bytes");
};

template <>
struct BytesToType<8> {
  using Type = uint64_t;
  static_assert(sizeof(Type) == 8, "BytesToType<8>::Type must be 8 bytes");
};

template <>
struct BytesToType<4> {
  using Type = uint32_t;
  static_assert(sizeof(Type) == 4, "BytesToType<4>::Type must be 4 bytes");
};

template <>
struct BytesToType<2> {
  using Type = uint16_t;
  static_assert(sizeof(Type) == 2, "BytesToType<2>::Type must be 2 bytes");
};

template <>
struct BytesToType<1> {
  using Type = uint8_t;
  static_assert(sizeof(Type) == 1, "BytesToType<1>::Type must be 1 byte");
};

////////////////////////////////////////////////////////////////////////////////////////////////////

144
// Maps a scalar type to its two-component CUDA vector type (exposed as ::Type).
// Only float, half and nv_bfloat16 are supported.
template <typename T>
struct TypeToVec2 {};

// float -> float2
template <>
struct TypeToVec2<float> {
  using Type = float2;
};

// half -> half2
template <>
struct TypeToVec2<half> {
  using Type = half2;
};

// nv_bfloat16 -> nv_bfloat162
template <>
struct TypeToVec2<nv_bfloat16> {
  using Type = nv_bfloat162;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

164
165
// Plain bundle of typed device pointers plus the type aliases describing them.
// NOTE(review): the exact role of each buffer is defined by the kernels consuming this
// struct, which are not part of this header - confirm against the callers.
template <typename IType, typename IType2, typename OType, typename CType>
struct CTDBiasDActParam {
  // Element types, re-exported so that generic code can query them.
  using InputType = IType;
  using InputType2 = IType2;
  using OutputType = OType;
  using ComputeType = CType;

  // Read-only inputs.
  const IType *input;
  const IType2 *act_input;

  // Outputs in OType.
  OType *output_c;
  OType *output_t;

  // Buffers in the compute type.
  const CType *scale_ptr;
  CType *amax;
  CType *scale_inv;
  CType *workspace;
  CType *warp_scales_inv;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

183
// Get<INDEX>::of<T, R>(vec) returns component INDEX of vec converted to R,
// where INDEX 0/1/2/3 selects the member .x/.y/.z/.w respectively.
template <int INDEX>
struct Get {
  template <typename T, typename R>
  static inline __device__ R of(const T &vec);
};

// INDEX == 0 -> vec.x
template <>
template <typename T, typename R>
inline __device__ R Get<0>::of(const T &vec) {
  const auto &component = vec.x;
  return component;
}

// INDEX == 1 -> vec.y
template <>
template <typename T, typename R>
inline __device__ R Get<1>::of(const T &vec) {
  const auto &component = vec.y;
  return component;
}

// INDEX == 2 -> vec.z
template <>
template <typename T, typename R>
inline __device__ R Get<2>::of(const T &vec) {
  const auto &component = vec.z;
  return component;
}

// INDEX == 3 -> vec.w
template <>
template <typename T, typename R>
inline __device__ R Get<3>::of(const T &vec) {
  const auto &component = vec.w;
  return component;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

215
216
217
// Converter<Src, Dst>::convert turns a Src value into a Dst value.
// The generic version relies on Dst being constructible from Src.
template <typename Src, typename Dst>
struct Converter {
  static inline __device__ Dst convert(const Src &from) {
    Dst converted = Dst(from);
    return converted;
  }
};

// float2 -> half2 using the round-to-nearest intrinsic.
template <>
struct Converter<float2, half2> {
  static inline __device__ half2 convert(const float2 &x) {
    const half2 packed = __float22half2_rn(x);
    return packed;
  }
};

// float2 -> nv_bfloat162 with round-to-nearest.
template <>
struct Converter<float2, nv_bfloat162> {
  static inline __device__ nv_bfloat162 convert(const float2 &x) {
#if __CUDA_ARCH__ >= 800
    // Single packed conversion on SM80 and newer.
    return __float22bfloat162_rn(x);
#else
    // Otherwise convert each component on its own and pack the pair.
    nv_bfloat162 packed;
    packed.x = __float2bfloat16_rn(x.x);
    packed.y = __float2bfloat16_rn(x.y);
    return packed;
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

244
245
246
// Zeros<T>::get() yields the zero value of T, used as the identity of the reductions below.
template <typename T>
struct Zeros {
  static inline __device__ T get() {
    T zero = T(0.f);
    return zero;
  }
};

// float2 needs both components zeroed explicitly.
template <>
struct Zeros<float2> {
  static inline __device__ float2 get() {
    float2 zero;
    zero.x = 0.f;
    zero.y = 0.f;
    return zero;
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

256
// Fixed-size pack of NUM_ELT elements of Elt_type.
//
// The elements share storage (through a union) with a single word of type
// BytesToType<BYTES>::Type, so the whole pack can be moved with one vector load/store
// (load_from / store_to) or handled element by element (data.elt[i]).
// BYTES must be one of the sizes BytesToType supports.
template <typename Elt_type, uint32_t NUM_ELT>
struct Vec {
  enum { BYTES = NUM_ELT * sizeof(Elt_type) };

  using Vec_type = typename BytesToType<BYTES>::Type;
  using type = Elt_type;

  using Alias_type = union {
    Vec_type vec;
    Elt_type elt[NUM_ELT];
  };

  Alias_type data;

  // Element-wise conversion of this pack into `other` (each element is converted with S(elt)).
  // Const-qualified: it only reads from *this.
  template <typename S>
  inline __device__ void to(Vec<S, NUM_ELT> &other) const {  // NOLINT(*)
#pragma unroll
    for (uint32_t it = 0; it < NUM_ELT; it++) {
      other.data.elt[it] = S(this->data.elt[it]);
    }
  }

  // Sets element i to op(i) for i in [0, NUM_ELT). The index is passed as int.
  template <typename Op>
  inline __device__ void assign(const Op &op) {
#pragma unroll
    for (int it = 0; it < static_cast<int>(NUM_ELT); it++) {
      this->data.elt[it] = op(it);
    }
  }

  // Pointer is cast to vector type: loads the idx-th Vec_type word from base_ptr.
  // The address must therefore be suitably aligned for Vec_type.
  inline __device__ void load_from(const void *base_ptr, size_t idx = 0) {
    this->data.vec = static_cast<const Vec_type *>(base_ptr)[idx];
  }

  // Pointer is cast to vector type: stores the pack as the idx-th Vec_type word of base_ptr.
  // The address must therefore be suitably aligned for Vec_type.
  inline __device__ void store_to(void *base_ptr, size_t idx = 0) const {
    static_cast<Vec_type *>(base_ptr)[idx] = this->data.vec;
  }

  // Pointer is cast to element type. Loads min(count, NUM_ELT)
  // elements starting at element index idx; any remaining elements are set to zero.
  // Falls back to element-wise loads when the pack is partial or the address is not
  // BYTES-aligned, otherwise uses a single vector load.
  inline __device__ void load_from_elts(const void *base_ptr, size_t idx = 0,
                                        size_t count = NUM_ELT) {
    const Elt_type *elt_ptr = static_cast<const Elt_type *>(base_ptr) + idx;
    if (count < NUM_ELT || reinterpret_cast<uint64_t>(elt_ptr) % BYTES != 0) {
#pragma unroll
      for (uint32_t it = 0; it < NUM_ELT; it++) {
        this->data.elt[it] = (it < count ? elt_ptr[it] : Elt_type(0.f));
      }
    } else {
      this->load_from(elt_ptr);
    }
  }

  // Pointer is cast to element type. Stores min(count, NUM_ELT)
  // elements starting at element index idx.
  // Falls back to element-wise stores when the pack is partial or the address is not
  // BYTES-aligned, otherwise uses a single vector store.
  inline __device__ void store_to_elts(void *base_ptr, size_t idx = 0,
                                       size_t count = NUM_ELT) const {
    Elt_type *elt_ptr = static_cast<Elt_type *>(base_ptr) + idx;
    if (count < NUM_ELT || reinterpret_cast<uint64_t>(elt_ptr) % BYTES != 0) {
#pragma unroll
      for (uint32_t it = 0; it < NUM_ELT; it++) {
        if (it < count) {
          elt_ptr[it] = this->data.elt[it];
        }
      }
    } else {
      this->store_to(elt_ptr);
    }
  }

  // Sets every element to zero.
  inline __device__ void clear() {
#pragma unroll
    for (uint32_t it = 0; it < NUM_ELT; it++) {
      this->data.elt[it] = Elt_type(0.f);
    }
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Barrier across a group of CTAs, implemented with two integer counters in global memory.
//
// Each call to sync() alternates between the two counters and between counting up
// (0 -> group_size) and counting down (group_size -> 0), so the counters never need to be
// reset between uses.
struct InterCTASync {
  // barrier:    base of the counter array; this group uses barrier[group] and
  //             barrier[group + num_groups].
  // group_size: number of CTAs that must arrive before the barrier opens.
  inline __device__ InterCTASync(int *barrier, int group, int num_groups, int group_size)
      : phase_counter_(0),
        b0_(barrier + group),               // First barrier counter of this group of CTAs.
        b1_(barrier + group + num_groups),  // Second barrier counter of this group of CTAs.
        group_size_(group_size) {
    // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!
  }

  // Adds `step` to *barrier with release semantics, then spins with acquire loads until the
  // counter equals `expected`.
  // The "memory" clobbers tell the compiler that these statements touch memory it cannot
  // see, so it must not move surrounding loads/stores across them.
  inline __device__ void spin_wait_(int *barrier, int step, int expected) {
    asm volatile("red.release.gpu.global.add.s32 [%0], %1;"
                 :
                 : "l"(barrier), "r"(step)
                 : "memory");
    for (int found = -1; found != expected;) {
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];"
                   : "=r"(found)
                   : "l"(barrier)
                   : "memory");
    }
  }

  // Blocks the calling CTA until all CTAs of the group have called sync().
  inline __device__ void sync() {
    // ALL THREADS MUST ENTER!

    // We switch barrier every iteration.
    int *barrier = phase_counter_ & 0x1 ? b1_ : b0_;
    // We decrement every other iteration.
    bool dec = phase_counter_ & 0x2;
    int step = dec ? -1 : 1;
    int expected = dec ? 0 : group_size_;
    // There are only 4 phases: up/down for b0/b1.
    phase_counter_ = (phase_counter_ + 1) & 0x3;

    // Only thread 0 of the CTA takes part in the inter-CTA handshake.
    if (threadIdx.x == 0) {
      spin_wait_(barrier, step, expected);
    }
    // CTA waits for thread 0
    __syncthreads();
  }

  int phase_counter_;  // Current phase in [0, 4): bit 0 selects the counter, bit 1 the direction.
  int *b0_;
  int *b1_;
  int group_size_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

383
// Reducer over a group of CTAS_PER_ROW CTAs.
// The CTA-local part is delegated to Reducer<T, 1, WARPS_M, WARPS_N>; the per-CTA partial
// results are then exchanged through a global-memory workspace guarded by InterCTASync.
template <typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
  using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
  using Type = typename Base::Type;

  // Shared memory is only needed by the CTA-local reduction.
  enum { SMEM_BYTES = Base::SMEM_BYTES };

  // Two barrier counters and one T per (warp row, CTA) pair.
  enum { WS_BARRIER_BYTES = 2 * sizeof(int) };
  enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) };

  // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total)
  enum {
    WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES
  };

  // Reads params.barrier, params.ctas_per_col and params.workspace.
  // w0_/w1_ are the two workspace slices used on alternating calls.
  template <typename Params>
  inline __device__ Reducer(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                            uint32_t warp_n, uint32_t lane, void *smem)
      : Base(params, bidm, bidn, warp_m, warp_n, lane, smem),
        inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW),
        w0_(static_cast<T *>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW),
        w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW),
        bidn_(bidn) {}  // CTA id within the group.

  // Reduces `data` with `op` across the whole CTA group and returns the result.
  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    static_assert(CTAS_PER_ROW <= 32);

    // Step 1: CTA-local reduction.
    const T cta_partial = Base::reduce(data, op);

    // We switch workspace every iteration.
    T *const slots = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;

    // Warp leaders 0 hold the CTA-local results.
    const bool is_leader = (this->warp_n_ == 0 && this->lane_ == 0);
    if (is_leader) {
      slots[bidn_] = cta_partial;
    }

    // Step 2: wait until every CTA of the group has published its partial.
    inter_cta_.sync();

    // Step 3: each warp fetches one partial per lane and finishes with a warp all-reduce.
    T gathered = Zeros<T>::get();
    if (this->lane_ < CTAS_PER_ROW) {
      gathered = slots[this->lane_];
    }
    return Reducer<T, 1, 1, 1>::allreduce_(gathered, op);
  }

  InterCTASync inter_cta_;

  T *const w0_;
  T *const w1_;
  int bidn_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

438
// Warp-level reducer (one warp per row, single CTA): everything is done with warp shuffles,
// so neither shared memory nor global workspace is needed.
template <typename T, uint32_t WARPS_M>
struct Reducer<T, 1, WARPS_M, 1> {
  using Type = T;
  enum { SMEM_BYTES = 0 };
  enum { WORKSPACE_BYTES_PER_GROUP = 0 };

  enum { THREADS_PER_WARP = 32 };

  // Only warp_n and lane are recorded; the remaining arguments are unused here.
  template <typename Params>
  inline __device__ Reducer(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                            uint32_t warp_n, uint32_t lane, void *smem)
      : warp_n_(warp_n), lane_(lane) {}

  // Butterfly (XOR) all-reduce: after the loop every lane holds the reduced value.
  template <typename Op>
  static inline __device__ T allreduce_(T data, const Op &op) {
#pragma unroll
    for (int lane_mask = 1; lane_mask < THREADS_PER_WARP; lane_mask *= 2) {
      const T peer = warp_shuffle_xor(data, lane_mask);
      data = op(data, peer);
    }
    return data;
  }

  // Member-function wrapper around allreduce_.
  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    return allreduce_(data, op);
  }

  // Shuffle-down tree reduction.
  template <typename Op>
  inline __device__ T reduce(T data, const Op &op) {
// only lane 0 holds the result!
#pragma unroll
    for (int offset = THREADS_PER_WARP / 2; offset > 0; offset /= 2) {
      const T peer = warp_shuffle_down(data, offset);
      data = op(data, peer);
    }
    return data;
  }

  int warp_n_;
  int lane_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

480
// CTA-level reducer: WARPS_N warps cooperate on one row.
// Each warp reduces with shuffles, then the per-warp partials are combined through shared
// memory. Two shared-memory buffers are used alternately on successive calls.
template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> {
  using Base = Reducer<T, 1, WARPS_M, 1>;

  using Type = T;

  // Two buffers of one T per warp.
  enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 };
  enum { WORKSPACE_BYTES_PER_GROUP = 0 };

  enum { THREADS_PER_WARP = 32 };

  // smem must provide at least SMEM_BYTES; smem0_/smem1_ point at this warp row's slice
  // of the first/second buffer.
  template <typename Params>
  inline __device__ Reducer(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                            uint32_t warp_n, uint32_t lane, void *smem)
      : Base(params, bidm, bidn, warp_m, warp_n, lane, smem),
        smem0_(&(static_cast<T *>(smem)[warp_m * WARPS_N])),
        smem1_(smem0_ + WARPS_M * WARPS_N),
        use0_(true) {}

  // Reduces across all WARPS_N warps; every thread returns the combined value.
  // Contains a __syncthreads(), so all threads of the CTA must call it.
  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    // Pick this call's buffer and flip the selector for the next call.
    T *const slots = use0_ ? smem0_ : smem1_;
    use0_ = !use0_;

    const T warp_partial = Base::reduce(data, op);
    if (this->lane_ == 0) {
      slots[this->warp_n_] = warp_partial;
    }
    __syncthreads();

    T combined = Zeros<T>::get();
#pragma unroll
    for (int w = 0; w < WARPS_N; w++) {
      combined = op(combined, slots[w]);
    }
    return combined;
  }

  // Like allreduce, but only the thread with warp_n_ == 0 and lane_ == 0 combines the
  // partials; every other thread returns Zeros<T>::get().
  // Contains a __syncthreads(), so all threads of the CTA must call it.
  template <typename Op>
  inline __device__ T reduce(T data, const Op &op) {
    // Pick this call's buffer and flip the selector for the next call.
    T *const slots = use0_ ? smem0_ : smem1_;
    use0_ = !use0_;

    // only intra-CTA group leader holds the result!
    const T warp_partial = Base::reduce(data, op);
    if (this->lane_ == 0) {
      slots[this->warp_n_] = warp_partial;
    }
    __syncthreads();

    T combined = Zeros<T>::get();
    const bool is_leader = (this->warp_n_ == 0 && this->lane_ == 0);
    if (is_leader) {
#pragma unroll
      for (int w = 0; w < WARPS_N; w++) {
        combined = op(combined, slots[w]);
      }
    }
    return combined;
  }

  T *const smem0_;
  T *const smem1_;
  bool use0_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

543
// Variant of the multi-CTA Reducer whose group size (params.ctas_per_row) is only known at
// run time instead of being a template parameter.
template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct DynamicReducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
  using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
  using Type = typename Base::Type;

  // Reads params.barrier, params.ctas_per_col, params.ctas_per_row and params.workspace.
  // w0_/w1_ are the two workspace slices used on alternating calls.
  template <typename Params>
  inline __device__ DynamicReducer(const Params &params, uint32_t bidm, uint32_t bidn,
                                   uint32_t warp_m, uint32_t warp_n, uint32_t lane, void *smem)
      : Base(params, bidm, bidn, warp_m, warp_n, lane, smem),
        inter_cta_(params.barrier, bidm, params.ctas_per_col, params.ctas_per_row),
        w0_(static_cast<T *>(params.workspace) + (bidm * WARPS_M + warp_m) * params.ctas_per_row),
        w1_(w0_ + params.ctas_per_col * WARPS_M * params.ctas_per_row),
        bidn_(bidn) {}  // CTA id within the group.

  // Reduces `data` with `op` across all CTAs of the group; every thread gets the result.
  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    // Trivial case
    const bool single_cta = (inter_cta_.group_size_ == 1);
    if (single_cta) {
      return Base::allreduce(data, op);
    }

    const T cta_partial = Base::reduce(data, op);

    // We switch workspace every iteration.
    T *const slots = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;

    // Warp leaders 0 hold the CTA-local results.
    if (this->warp_n_ == 0 && this->lane_ == 0) {
      slots[bidn_] = cta_partial;
    }
    inter_cta_.sync();

    // Each lane folds the partials lane, lane + 32, ... and the warp then combines them.
    T folded = Zeros<T>::get();
    for (int cta = this->lane_; cta < inter_cta_.group_size_; cta += THREADS_PER_WARP) {
      folded = op(folded, slots[cta]);
    }
    return Reducer<T, 1, 1, 1>::allreduce_(folded, op);
  }

  // Same as allreduce: with a dynamic group size every thread receives the result.
  template <typename Op>
  inline __device__ T reduce(T data, const Op &op) {
    return allreduce(data, op);
  }

  InterCTASync inter_cta_;

  T *const w0_;
  T *const w1_;
  int bidn_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

597
598
599
600
601
602
603
604
605
606
607
608
609
/*
This is an implementation of the parallel Welford algorithm for incrementally computing variance

This algorithm is known as Chan's update formulae (Chan et al. '79):
http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf

An introduction is provided by Wikipedia here:
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance?section=5#Parallel_algorithm

A detailed reference on the exact version implemented (with better numerical stability) is provided here:
https://dbs.ifi.uni-heidelberg.de/files/Team/eschubert/publications/SSDBM18-covariance-authorcopy.pdf
*/

610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
// Merges per-lane partial statistics (mean m_a, sum of squared deviations m2_a, count n_a)
// across a warp using Chan's pairwise update, via a shuffle-down tree.
//
// num_active: number of leading lanes holding valid statistics; the tree starts at
//             step = next_pow2(num_active) / 2.
// On return, m_a and m2_a hold lane 0's merged values in every lane (they are broadcast).
// n_a is NOT broadcast: each lane keeps its own running count.
// All 32 lanes of the warp must call this function (full-mask shuffles are used).
template <typename T>
inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a,
                                             int num_active) {  // NOLINT(*)
  // Assume at least leftmost is valid and
  // init: step = next_pow2(num_active) / 2 (might get NaN otherwise)
  // With num_active <= 1 there is nothing to merge; skipping the computation also avoids
  // the undefined shift `1 << -1` that the formula below would produce for num_active == 1.
  int first_step = 0;
  if (num_active > 1) {
    const int highest_bit_set =
        static_cast<int>(8 * sizeof(num_active)) - __clz(num_active - 1);
    first_step = 1 << (highest_bit_set - 1);
  }

#pragma unroll
  for (int step = first_step; step > 0; step /= 2) {
    // Exchange
    T n_b = warp_shuffle_down(n_a, step);
    T m_b = warp_shuffle_down(m_a, step);
    T m2_b = warp_shuffle_down(m2_a, step);

    // Update
    const T n_ab = n_a + n_b;  // We can handle one of them being 0, not both.
    // Might have different n per thread, otherwise this would simplify :(
    const T rn_ab = 1.f / n_ab;
    const T delta = m_a - m_b;
    // Intermediates are kept in T so that no precision is lost for types wider than float.
    const T m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;
    const T m_ab = (n_a * m_a + n_b * m_b) * rn_ab;

    n_a = n_ab;
    m_a = m_ab;
    m2_a = m2_ab;
  }
  // Intra-warp broadcast (only lane 0 has valid stats).
  m_a = __shfl_sync(static_cast<uint32_t>(-1), m_a, 0);
  m2_a = __shfl_sync(static_cast<uint32_t>(-1), m2_a, 0);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

643
// Statistics (mean and sum of squared deviations) over a row that is split across
// CTAS_PER_ROW CTAs. Each CTA computes block-local statistics, publishes them to a
// global-memory workspace, and after an inter-CTA barrier every warp merges them.
template <typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats {
  // This could be done generically with the Reducer. But then we
  // would have to exchange 3 instead of 2 fields.

  using BlockStats = Stats<T, 1, WARPS_M, WARPS_N>;
  using stats_t = typename BlockStats::stats_t;

  enum { SMEM_BYTES = BlockStats::SMEM_BYTES };

  // Reads params.barrier, params.ctas_per_col and params.workspace.
  // w0_/w1_ are the two workspace slices used on alternating calls.
  template <typename Params>
  inline __device__ Stats(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                          uint32_t warp_n, uint32_t lane, void *smem)
      : inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW),
        block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem),
        w0_(static_cast<stats_t *>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW),
        w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW),
        bidn_(bidn),  // CTA id within the group.
        warp_n_(warp_n),
        lane_(lane) {}

  // elts: the N elements owned by the calling thread.
  // rn:   unused here (the block-local reciprocal count is derived from N instead).
  // Returns {mean, sum of squared deviations} merged over all CTAs of the group.
  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    // Assume CTA group size in N less than 32, such that we can finalize with a single warp.
    static_assert(CTAS_PER_ROW <= 32);

    constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP;
    // TODO(ptredak) rn is not really needed here..
    constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA);
    const stats_t cta_stats = block_stats_.compute(elts, block_rn);

    // The workspace slice alternates together with the barrier phase.
    stats_t *const slots = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;

    const bool is_leader = (warp_n_ == 0 && lane_ == 0);
    if (is_leader) {
      slots[bidn_] = cta_stats;
    }

    // Wait for all CTAS_PER_ROW CTAS in the group to have written their result.
    inter_cta_.sync();

    // Lanes without a CTA to read keep a zero count, which the merge step tolerates.
    T count = Zeros<T>::get();
    T mean = Zeros<T>::get();
    T sq_dev = Zeros<T>::get();

    // Every warp does the final reduction locally.
    if (lane_ < CTAS_PER_ROW) {
      stats_t published = slots[lane_];
      count = ELTS_PER_ROW_PER_CTA;
      mean = transformer_engine::Get<0>::of<stats_t, T>(published);
      sq_dev = transformer_engine::Get<1>::of<stats_t, T>(published);
    }

    warp_chan_upd_dynamic(mean, sq_dev, count, CTAS_PER_ROW);

    return {mean, sq_dev};
  }

  InterCTASync inter_cta_;
  BlockStats block_stats_;

  stats_t *const w0_;
  stats_t *const w1_;
  int bidn_;
  int warp_n_;
  int lane_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

713
// CTA-level statistics: WARPS_N warps cooperate on one row.
// Each warp computes its own statistics, the warp leaders exchange them through shared
// memory, and every warp then merges the per-warp results.
// Two shared-memory buffers are used alternately on successive calls.
template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats<T, 1, WARPS_M, WARPS_N> {
  using WarpStats = Stats<T, 1, WARPS_M, 1>;
  using stats_t = typename WarpStats::stats_t;

  // Two buffers of one stats_t per warp.
  enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 };

  // smem must provide at least SMEM_BYTES; smem0_/smem1_ point at this warp row's slice
  // of the first/second buffer.
  template <typename Params>
  inline __device__ Stats(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                          uint32_t warp_n, uint32_t lane, void *smem)
      : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem),
        smem0_(static_cast<stats_t *>(smem) + warp_m * WARPS_N),
        smem1_(smem0_ + WARPS_M * WARPS_N),
        use0_(true) {}

  // elts: the N elements owned by the calling thread.
  // rn:   unused here (the warp-local reciprocal count is derived from N instead).
  // Returns {mean, sum of squared deviations} merged over the WARPS_N warps.
  // Contains a __syncthreads(), so all threads of the CTA must call it.
  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    // Assume that there are less than 32 warps, such that we can finalize with a single warp
    static_assert(WARPS_N <= 32);

    // Pick this call's buffer and flip the selector for the next call.
    stats_t *slots = use0_ ? smem0_ : smem1_;
    use0_ = !use0_;

    // Compute warp local for all WARPS_N
    constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP);
    const stats_t per_warp = warp_stats_.compute(elts, warp_rn);

    // Each warp leader stores its stats
    const auto warp_n = warp_stats_.reducer_.warp_n_;
    const auto lane = warp_stats_.reducer_.lane_;
    if (lane == 0) {
      slots[warp_n] = per_warp;
    }
    __syncthreads();

    // Lanes without a warp to read keep a zero count, which the merge step tolerates.
    T count = Zeros<T>::get();
    T mean = Zeros<T>::get();
    T sq_dev = Zeros<T>::get();

    if (lane < WARPS_N) {
      stats_t published = slots[lane];
      count = N * THREADS_PER_WARP;
      mean = transformer_engine::Get<0>::of<stats_t, T>(published);
      sq_dev = transformer_engine::Get<1>::of<stats_t, T>(published);
    }

    warp_chan_upd_dynamic(mean, sq_dev, count, WARPS_N);

    return {mean, sq_dev};
  }

  WarpStats warp_stats_;
  stats_t *smem0_;
  stats_t *smem1_;
  bool use0_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

769
// Warp-level statistics: a single warp owns a row and computes its mean and its sum of
// squared deviations with two warp all-reduces. No shared memory is required.
template <typename T, uint32_t WARPS_M>
struct Stats<T, 1, WARPS_M, 1> {
  using stats_t = typename TypeToVec2<T>::Type;
  // The simple Warp reducer.
  // The right-hand side is namespace-qualified so that the member alias `Reducer` does not
  // change the meaning of the name it is defined from (diagnosed by e.g. GCC's
  // -Wchanges-meaning).
  using Reducer = transformer_engine::Reducer<T, 1, WARPS_M, 1>;

  enum { SMEM_BYTES = 0 };

  // All arguments are forwarded unchanged to the warp reducer.
  template <typename Params>
  inline __device__ Stats(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                          uint32_t warp_n, uint32_t lane, void *smem)
      : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) {}

  // elts: the N elements owned by the calling thread.
  // rn:   factor applied to the warp-wide sum to obtain the mean (the reciprocal of the
  //       number of elements reduced, as passed by the callers in this file).
  // Returns {mean, sum of squared deviations from the mean}; every lane gets the result.
  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    auto sum = Sum<T>();

    // Pass 1: mean = (sum over the warp) * rn.
    T m = Zeros<T>::get();
#pragma unroll
    for (uint32_t it = 0; it < N; it++) {
      m += elts[it];
    }
    m = reducer_.allreduce(m, sum) * rn;

    // Pass 2: sum of squared deviations from that mean.
    T m2 = Zeros<T>::get();
#pragma unroll
    for (uint32_t it = 0; it < N; it++) {
      T diff = (elts[it] - m);
      m2 += diff * diff;
    }
    m2 = reducer_.allreduce(m2, sum);

    return {m, m2};
  }

  Reducer reducer_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

/**
 * Max reduction over the first `num_elems` lanes of a warp using shuffle-down steps.
 *
 * All 32 lanes must call this together (the shuffles use the full 0xFFFFFFFF mask).
 * The strides are num_elems / 2, num_elems / 4, ..., 1, so `num_elems` is expected to be a
 * power of two. Only lane 0 is guaranteed to end up with the maximum of lanes
 * [0, num_elems); other lanes hold partial results.
 *
 * Inputs must be non-negative: the compiler is told so through __builtin_assume.
 */
template <int num_elems>
__device__ __forceinline__ float warp_reduce_max(const float m) {
  float running_max = m;
#pragma unroll
  for (int stride = num_elems >> 1; stride > 0; stride >>= 1) {
    const float neighbor = __shfl_down_sync(0xFFFFFFFF, running_max, stride);
    __builtin_assume(running_max >= 0);
    __builtin_assume(neighbor >= 0);
    running_max = fmaxf(running_max, neighbor);
  }
  return running_max;
}

822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
/**
 * Max reduction over a full warp whose result is made available in every lane.
 *
 * Runs the shuffle-down reduction over all THREADS_PER_WARP lanes and then broadcasts
 * lane 0's value to the whole warp. All 32 lanes must participate and the inputs must be
 * non-negative (see warp_reduce_max).
 */
__forceinline__ __device__ float warp_reduce_max_broadcast(const float val) {
  // After the reduction only lane 0 is guaranteed to hold the warp-wide maximum.
  const float lane0_max = warp_reduce_max<static_cast<int>(THREADS_PER_WARP)>(val);
  // Broadcast the amax from lane 0 to all other lanes of the warp.
  constexpr int source_lane = 0;
  return __shfl_sync(0xFFFFFFFF, lane0_max, source_lane);
}

Przemek Tredak's avatar
Przemek Tredak committed
837
838
/**
 * Block-wide max reduction of non-negative values.
 *
 * Stage 1: every warp reduces its 32 values with warp_reduce_max and the thread with
 *          threadIdx.x % 32 == 0 stores the warp result in shared memory at `warpid`.
 * Stage 2: after a block barrier, warp 0 loads the staged values (lane i reads entry i,
 *          lanes >= num_warps contribute 0) and reduces them again.
 *
 * \tparam num_warps Number of warps in the block; must be in [1, 32] because a single warp
 *                   finalizes the reduction.
 * \param m      This thread's value. Must be non-negative (warp_reduce_max assumes it, and
 *               0 is used as the neutral element).
 * \param warpid Index of the calling thread's warp. The code indexes the staging buffer
 *               with threadIdx.x in warp 0, so it assumes a 1D block in which
 *               warpid == threadIdx.x / 32 -- confirm against callers.
 * \return The block maximum in the first lane of warp 0. Other lanes of warp 0 hold partial
 *         results and threads of all other warps return 0.
 *
 * Must be called by all threads of the block (contains __syncthreads()).
 * NOTE(review): there is no barrier after the staging buffer is read, so back-to-back calls
 * within one kernel may need a __syncthreads() in between -- confirm against callers.
 */
template <int num_warps, typename compute_t>
__device__ __forceinline__ compute_t reduce_max(const compute_t m, const int warpid) {
  constexpr int warp_size = 32;
  static_assert(num_warps >= 1 && num_warps <= warp_size,
                "reduce_max finalizes with a single warp: num_warps must be in [1, 32]");
  // warp_reduce_max halves its stride each step, so it only covers every staged value when
  // its width is a power of two. Round num_warps up; the padding lanes hold 0, which is
  // neutral for a max over non-negative values. For power-of-two num_warps this equals
  // num_warps, i.e. the instruction sequence is unchanged.
  constexpr int final_width = num_warps <= 1    ? 1
                              : num_warps <= 2  ? 2
                              : num_warps <= 4  ? 4
                              : num_warps <= 8  ? 8
                              : num_warps <= 16 ? 16
                                                : 32;

  __shared__ float staging[num_warps];

  // Stage 1: per-warp maximum, published by the first lane of each warp.
  const float my_max = m;
  const float my_warp_max = warp_reduce_max<warp_size>(my_max);
  if (threadIdx.x % warp_size == 0) {
    staging[warpid] = my_warp_max;
  }
  __syncthreads();

  // Stage 2: warp 0 reduces the staged per-warp maxima.
  compute_t result = 0.f;
  if (warpid == 0) {
    const float staged_max = threadIdx.x < num_warps ? staging[threadIdx.x] : 0;
    result = warp_reduce_max<final_width>(staged_max);
  }
  return result;
}

855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
/**
 * Max reduction inside subwarps, with the result broadcast to every lane of the subwarp.
 *
 * E.g., if nvec=4, each warp processes 128 elements (32 x 4), which covers four MXFP8
 * scaling factors. To compute one scaling factor for 32 consecutive elements only 8 threads
 * need to participate, so the warp is split into four 8-thread subwarps.
 *
 * The reduction is a shuffle-down tree confined to `subwarp_width`-lane segments (the
 * `width` argument of the shuffle intrinsics), followed by a broadcast from lane 0 of each
 * segment. All 32 lanes of the warp must call this together, and the inputs must be
 * non-negative (the compiler is told so through __builtin_assume).
 */
template <int subwarp_width>
__forceinline__ __device__ float subwarp_reduce_max_broadcast(const float val) {
  constexpr unsigned int full_mask = 0xFFFFFFFF;
  float running_max = val;
#pragma unroll
  for (int stride = subwarp_width >> 1; stride > 0; stride >>= 1) {
    const float neighbor = __shfl_down_sync(full_mask, running_max, stride, subwarp_width);
    __builtin_assume(running_max >= 0);
    __builtin_assume(neighbor >= 0);
    running_max = fmaxf(running_max, neighbor);
  }
  // Broadcast the amax from lane 0 of each subwarp to the rest of that subwarp.
  constexpr int source_lane = 0;
  return __shfl_sync(full_mask, running_max, source_lane, subwarp_width);
}

Przemek Tredak's avatar
Przemek Tredak committed
878
// Atomic max on a float, implemented as an integer atomicMax on the raw bit patterns.
// Works only on positive values: signed-integer ordering of the bits matches float ordering
// only while the sign bit is clear (this applies to both `value` and the value at `addr`).
__device__ __forceinline__ void atomicMaxFloat(float *addr, const float value) {
  int *const addr_as_int = reinterpret_cast<int *>(addr);
  const int value_bits = __float_as_int(value);
  atomicMax(addr_as_int, value_bits);
}

// Atomic min on a float, implemented as an integer atomicMin on the raw bit patterns.
// Works only on positive values: signed-integer ordering of the bits matches float ordering
// only while the sign bit is clear (this applies to both `value` and the value at `addr`).
__device__ __forceinline__ void atomicMinFloat(float *addr, const float value) {
  int *const addr_as_int = reinterpret_cast<int *>(addr);
  const int value_bits = __float_as_int(value);
  atomicMin(addr_as_int, value_bits);
}

// Generic reciprocal: stores 1 / value through `value_inv`.
// `value_inv` must point to writable storage; no check is made for value == 0.
template <typename T>
__device__ __forceinline__ void reciprocal(T *value_inv, const T value) {
  T &out = *value_inv;
  out = 1 / value;
}

893
894
895
896
897
// float specialization: uses the __frcp_rn intrinsic (round-to-nearest-even reciprocal)
// instead of a generic division.
template <>
__device__ __forceinline__ void reciprocal<float>(float *value_inv, const float value) {
  float &out = *value_inv;
  out = __frcp_rn(value);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Short aliases for the FP8 storage types from <cuda_fp8.h>.
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
// E8M0 values are carried as a raw byte; float_to_e8m0() / exp2f_rcp() in this file treat
// it as a biased power-of-two exponent.
using e8m0_t = uint8_t;

// Layout constants of IEEE-754 binary32: number of mantissa bits and exponent bias.
constexpr uint32_t FP32_MANTISSA_BITS = 23;
constexpr uint32_t FP32_EXPONENT_BIAS = 127;

// Scaling direction selector.
// NOTE(review): "BIDIMENTIONAL" looks like a misspelling of "bidimensional"; it is kept
// as-is because renaming the enumerator would break existing users.
enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENTIONAL = 2 };

// Per-type numeric limits. Only the specializations below are defined; using any other
// type is a compile error because the primary template is declared but never defined.
template <typename T>
struct Numeric_Traits;

// FP8 E4M3: largest finite value 448 = 1.75 * 2^8.
template <>
struct Numeric_Traits<fp8e4m3> {
  static constexpr int maxUnbiasedExponent = 8;
  static constexpr double maxNorm = 448;
};

// FP8 E5M2: largest finite value 57344 = 1.75 * 2^15.
template <>
struct Numeric_Traits<fp8e5m2> {
  static constexpr int maxUnbiasedExponent = 15;
  static constexpr double maxNorm = 57344;
};

// float-typed limits of a quantized type T, derived at compile time from Numeric_Traits<T>.
template <typename T>
struct Quantized_Limits {
  // Largest unbiased exponent of T.
  static constexpr int max_unbiased_exponent = Numeric_Traits<T>::maxUnbiasedExponent;
  // Largest finite value of T and its reciprocal (the division is evaluated in double and
  // then narrowed to float).
  static constexpr float max_norm = Numeric_Traits<T>::maxNorm;
  static constexpr float max_norm_rcp = 1.0 / max_norm;
  // 2^max_unbiased_exponent and its reciprocal.
  static constexpr float emax = 1 << max_unbiased_exponent;
  static constexpr float emax_rcp = 1.0 / emax;
};

/**
 * Converts a float to an E8M0 biased exponent, rounding the magnitude up to a power of two.
 *
 * Special cases handled before either code path:
 *   - NaN  -> 0xFF
 *   - +-Inf -> 0xFE (saturates to the largest finite exponent)
 *
 * On targets where the SM100_ALL / SM101_ALL / SM120_ALL feature macros are set, the
 * conversion is done by a single `cvt.rp.satfinite.ue8m0x2.f32` instruction and the low
 * byte of its packed result is returned.
 *
 * Otherwise the software path is used:
 *   - 0.0f returns 0x00;
 *   - the FP32 biased exponent is extracted (the sign bit is discarded by the narrowing to
 *     8 bits) and incremented when the mantissa is non-zero, except
 *       * when the exponent is already 0xFE (saturation), and
 *       * for subnormals with mantissa <= 0x400000, i.e. magnitudes <= 2^-127, which stay
 *         at exponent 0.
 */
__device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
  // TODO: nan/inf needs to be set for any value
  // of nan/inf in input not just amax.
  if (isnan(val)) {
    return 0xFF;
  }
  if (isinf(val)) {
    return 0xFE;
  }
#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \
     (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL)))
  uint16_t out;
  asm volatile(
      "{\n"
      "cvt.rp.satfinite.ue8m0x2.f32  %0, 0.0, %1;\n"
      "}"
      : "=h"(out)
      : "f"(val));
  // The instruction produces two packed E8M0 values; the first byte of `out` is returned.
  return *reinterpret_cast<e8m0_t *>(&out);
#else
  if (val == 0.0f) {
    return 0x00;
  }
  // Read the bit pattern through the intrinsic rather than a pointer cast, which would
  // violate strict aliasing.
  const uint32_t val_u32 = __float_as_uint(val);
  e8m0_t exponent = static_cast<e8m0_t>(val_u32 >> FP32_MANTISSA_BITS);
  const uint32_t mantissa = val_u32 & 0x7FFFFF;
  // Round up exponent and deal with satfinite.
  if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) {
    ++exponent;
  }
  return exponent;
#endif
}

// Returns 2^(FP32_EXPONENT_BIAS - biased_exp), i.e. the reciprocal of the power of two
// encoded by an E8M0 biased exponent. A biased exponent of 0 is special-cased to return 1.
__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) {
  if (biased_exp == 0) {
    return 1;
  }
  return exp2f(FP32_EXPONENT_BIAS - static_cast<float>(biased_exp));
}

Przemek Tredak's avatar
Przemek Tredak committed
971
972
973
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_UTILS_CUH_