utils.cuh 31.6 KB
Newer Older
Przemek Tredak's avatar
Przemek Tredak committed
1
/*************************************************************************
2
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
3
4
5
6
7
8
9
10
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_UTILS_CUH_
#define TRANSFORMER_ENGINE_COMMON_UTILS_CUH_

#include <cuda_fp16.h>
yuguo's avatar
yuguo committed
11
12
13
14
15

#ifdef __HIP_PLATFORM_AMD__
#ifndef __HIPCC_RTC__
#include <cstdint>
#else
16
#include <hip/amd_detail/hip_assert.h>
yuguo's avatar
yuguo committed
17
18
19
20
21
using namespace __hip_internal;
#endif
#endif

#include <cuda_bf16.h>
Tim Moon's avatar
Tim Moon committed
22
#include <cuda_fp8.h>
Przemek Tredak's avatar
Przemek Tredak committed
23

24
25
26
27
#if CUDA_VERSION >= 12080
#include <cuda_fp4.h>
#endif

yuguo's avatar
yuguo committed
28
29
30
#ifdef __HIP_PLATFORM_AMD__
typedef uint16_t hip_bfloat16x2 __attribute__((ext_vector_type(2)));
#else
Tim Moon's avatar
Tim Moon committed
31
#if !defined(__CUDACC_RTC__)
32
#include <cassert>
Tim Moon's avatar
Tim Moon committed
33
34
35
36
37
38
39
40
41
42
43
44
#include <cstdint>
#else
// Importing C++ standard headers is a pain with NVRTC
using uint8_t = unsigned char;
using uint16_t = unsigned short int;  // NOLINT(*)
using uint32_t = unsigned int;
using uint64_t = unsigned long long int;  // NOLINT(*)
static_assert(sizeof(uint8_t) == 1);
static_assert(sizeof(uint16_t) == 2);
static_assert(sizeof(uint32_t) == 4);
static_assert(sizeof(uint64_t) == 8);
#endif
Przemek Tredak's avatar
Przemek Tredak committed
45

yuguo's avatar
yuguo committed
46
#endif // __HIP_PLATFORM_AMD__
Przemek Tredak's avatar
Przemek Tredak committed
47
48
49
50
51
52
////////////////////////////////////////////////////////////////////////////////////////////////////

// Warp width assumed by the helpers in this header; it is also passed as the `width`
// argument of the HIP shuffle intrinsics used further down.
constexpr uint32_t THREADS_PER_WARP = 32;

////////////////////////////////////////////////////////////////////////////////////////////////////

53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
// Device-side error.
// Prints the source location (__FILE__, __LINE__, __func__), the calling thread's
// threadIdx/blockIdx coordinates and `message` with device-side printf, then evaluates
// assert(0). `message` is consumed by a "%s" conversion, so it has to be a C string.
// Wrapped in do { ... } while (false) so the macro behaves like a single statement.
#define NVTE_DEVICE_ERROR(message)                                                                 \
  do {                                                                                             \
    printf("%s:%d in function %s (thread (%d,%d,%d), block (%d,%d,%d)): %s\n", __FILE__, __LINE__, \
           __func__, threadIdx.x, threadIdx.y, threadIdx.z, blockIdx.x, blockIdx.y, blockIdx.z,    \
           (message));                                                                             \
    assert(0);                                                                                     \
  } while (false)

// Device-side error on thread 0.
// Same as NVTE_DEVICE_ERROR, but only the thread whose blockIdx and threadIdx are all zero
// reports the error; every other thread falls through without printing or asserting.
#define NVTE_DEVICE_THREAD0_ERROR(message)                                           \
  do {                                                                               \
    if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0 && \
        threadIdx.y == 0 && threadIdx.z == 0) {                                      \
      NVTE_DEVICE_ERROR(message);                                                    \
    }                                                                                \
  } while (false)

////////////////////////////////////////////////////////////////////////////////////////////////////

yuguo's avatar
yuguo committed
73
#if !defined(USE_HIPBLASLT) && !defined(__HIPCC_RTC__)
74
75
// Component-wise sum of two float2 values.
inline __device__ float2 operator+(const float2 &a, const float2 &b) {  // NOLINT(*)
  const float sum_x = a.x + b.x;
  const float sum_y = a.y + b.y;
  return make_float2(sum_x, sum_y);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

80
81
82
// Component-wise in-place accumulation of `b` into `a`.
inline __device__ void operator+=(float2 &a, const float2 &b) {  // NOLINT(*)
  a.x = a.x + b.x;
  a.y = a.y + b.y;
}
yuguo's avatar
yuguo committed
84
#endif
Przemek Tredak's avatar
Przemek Tredak committed
85
86
87

////////////////////////////////////////////////////////////////////////////////////////////////////

88
// Binary functor that adds its two arguments with T's operator+.
// It is the kind of `op` the reducers in this header accept.
template <typename T>
struct Sum {
  inline __device__ Sum() {}

  inline __device__ T operator()(const T &a, const T &b) const {
    return a + b;
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

96
97
// Returns the value of `x` held by the lane whose index is (own lane index XOR idx).
// On CUDA every lane of the warp is named in the participation mask; on HIP the shuffle
// width is THREADS_PER_WARP.
template <typename T>
inline __device__ T warp_shuffle_xor(const T &x, uint32_t idx) {
#ifndef __HIP_PLATFORM_AMD__
  constexpr uint32_t kFullWarpMask = 0xffffffffu;
  return __shfl_xor_sync(kFullWarpMask, x, idx);
#else
  return __shfl_xor(x, idx, THREADS_PER_WARP);
#endif
}

105
106
107
// float2 overload of the XOR shuffle: each component is exchanged separately.
template <>
inline __device__ float2 warp_shuffle_xor<float2>(const float2 &x, uint32_t idx) {
  float2 exchanged;
  exchanged.x = warp_shuffle_xor(x.x, idx);
  exchanged.y = warp_shuffle_xor(x.y, idx);
  return exchanged;
}

110
111
// Returns the value of `x` held by the lane `idx` positions above the calling lane.
// On CUDA every lane of the warp is named in the participation mask; on HIP the shuffle
// width is THREADS_PER_WARP.
template <typename T>
inline __device__ T warp_shuffle_down(const T &x, uint32_t idx) {
#ifndef __HIP_PLATFORM_AMD__
  constexpr uint32_t kFullWarpMask = 0xffffffffu;
  return __shfl_down_sync(kFullWarpMask, x, idx);
#else
  return __shfl_down(x, idx, THREADS_PER_WARP);
#endif
}

119
120
121
// float2 overload of the down shuffle: each component is exchanged separately.
template <>
inline __device__ float2 warp_shuffle_down<float2>(const float2 &x, uint32_t idx) {
  float2 exchanged;
  exchanged.x = warp_shuffle_down(x.x, idx);
  exchanged.y = warp_shuffle_down(x.y, idx);
  return exchanged;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace transformer_engine {

////////////////////////////////////////////////////////////////////////////////////////////////////

// 64-byte POD made of four uint4 words; serves as the storage type of BytesToType<64>.
struct uint16 {
  uint4 u, v, s, t;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// 32-byte POD made of two uint4 words; serves as the storage type of BytesToType<32>.
struct uint8 {
  uint4 u, v;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

146
// Maps a byte count to a type of exactly that size via the nested `Type` alias.
// The primary template is intentionally empty: only the specialized sizes are usable.
template <int BYTES>
struct BytesToType {
};

149
// 64 bytes -> uint16 (four uint4 words).
template <>
struct BytesToType<64> {
  typedef uint16 Type;
  static_assert(sizeof(Type) == 64, "BytesToType<64>::Type must be 64 bytes");
};

155
// 32 bytes -> uint8 (two uint4 words).
template <>
struct BytesToType<32> {
  typedef uint8 Type;
  static_assert(sizeof(Type) == 32, "BytesToType<32>::Type must be 32 bytes");
};

161
// 16 bytes -> uint4.
template <>
struct BytesToType<16> {
  typedef uint4 Type;
  static_assert(sizeof(Type) == 16, "BytesToType<16>::Type must be 16 bytes");
};

167
// 8 bytes -> uint64_t.
template <>
struct BytesToType<8> {
  typedef uint64_t Type;
  static_assert(sizeof(Type) == 8, "BytesToType<8>::Type must be 8 bytes");
};

173
// 4 bytes -> uint32_t.
template <>
struct BytesToType<4> {
  typedef uint32_t Type;
  static_assert(sizeof(Type) == 4, "BytesToType<4>::Type must be 4 bytes");
};

179
// 2 bytes -> uint16_t.
template <>
struct BytesToType<2> {
  typedef uint16_t Type;
  static_assert(sizeof(Type) == 2, "BytesToType<2>::Type must be 2 bytes");
};

185
// 1 byte -> uint8_t.
template <>
struct BytesToType<1> {
  typedef uint8_t Type;
  static_assert(sizeof(Type) == 1, "BytesToType<1>::Type must be 1 byte");
};

////////////////////////////////////////////////////////////////////////////////////////////////////

193
// Maps a scalar type to its two-element vector type via the nested `Type` alias.
// The primary template is intentionally empty: only the specialized scalars are usable.
template <typename T>
struct TypeToVec2 {
};

196
// float -> float2.
template <>
struct TypeToVec2<float> {
  typedef float2 Type;
};

201
// half -> half2.
template <>
struct TypeToVec2<half> {
  typedef half2 Type;
};

yuguo's avatar
yuguo committed
206
207
208
209
210
211
#ifdef __HIP_PLATFORM_AMD__
// __hip_bfloat16 -> hip_bfloat16x2 (HIP build).
template <>
struct TypeToVec2<__hip_bfloat16> {
  typedef hip_bfloat16x2 Type;
};
#else
// nv_bfloat16 -> nv_bfloat162 (CUDA build).
template <>
struct TypeToVec2<nv_bfloat16> {
  typedef nv_bfloat162 Type;
};
#endif
Przemek Tredak's avatar
Przemek Tredak committed
217
218
219

////////////////////////////////////////////////////////////////////////////////////////////////////

220
221
// Plain bundle of typed pointers plus aliases for the four element types involved.
// NOTE(review): the role of each pointer is inferred from its name only; confirm against the
// kernels that consume this struct. Field order is kept as-is so aggregate initialization
// continues to work.
template <typename IType, typename IType2, typename OType, typename CType>
struct CTDBiasDActParam {
  typedef IType InputType;
  typedef IType2 InputType2;
  typedef OType OutputType;
  typedef CType ComputeType;

  // Read-only inputs.
  const IType *input;
  const IType2 *act_input;

  // Writable outputs.
  OType *output_c;
  OType *output_t;

  // Pointers of the compute type.
  const CType *scale_ptr;
  CType *amax;
  CType *scale_inv;
  CType *workspace;
  CType *warp_scales_inv;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

239
// Compile-time component selector: Get<I>::of<T, R>(vec) returns component I of `vec`
// converted to R. Only the explicit specializations for indices 0..3 are defined.
template <int INDEX>
struct Get {
  template <typename T, typename R>
  static inline __device__ R of(const T &vec);
};

245
246
// Index 0 selects the `.x` member.
template <>
template <typename T, typename R>
inline __device__ R Get<0>::of(const T &vec) {
  return vec.x;
}

251
252
// Index 1 selects the `.y` member.
template <>
template <typename T, typename R>
inline __device__ R Get<1>::of(const T &vec) {
  return vec.y;
}

257
258
// Index 2 selects the `.z` member.
template <>
template <typename T, typename R>
inline __device__ R Get<2>::of(const T &vec) {
  return vec.z;
}

263
264
// Index 3 selects the `.w` member.
template <>
template <typename T, typename R>
inline __device__ R Get<3>::of(const T &vec) {
  return vec.w;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

271
272
273
// Generic conversion: constructs a Dst from a Src with a functional-style cast.
// Specializations below provide dedicated conversions for packed pairs.
template <typename Src, typename Dst>
struct Converter {
  static inline __device__ Dst convert(const Src &from) {
    return Dst(from);
  }
};

276
277
278
// float2 -> half2 through the __float22half2_rn intrinsic.
template <>
struct Converter<float2, half2> {
  static inline __device__ half2 convert(const float2 &x) {
    return __float22half2_rn(x);
  }
};

yuguo's avatar
yuguo committed
281
282
283
284
285
286
#ifdef __HIP_PLATFORM_AMD__
// float2 -> packed bfloat16 pair (HIP build): each component is converted on its own and
// the two results are packed through a union.
template <>
struct Converter<float2, hip_bfloat16x2> {
  static inline __device__ hip_bfloat16x2 convert(const float2 &x) {
    union {
      hip_bfloat16x2 raw;
      __hip_bfloat16 elt[2];
    } packed;
    packed.elt[0] = __hip_bfloat16(x.x);
    packed.elt[1] = __hip_bfloat16(x.y);
    return packed.raw;
  }
};
#else
// float2 -> nv_bfloat162 (CUDA build). When __CUDA_ARCH__ >= 800 the packed intrinsic is
// used; otherwise each component is converted on its own and packed through a union.
template <>
struct Converter<float2, nv_bfloat162> {
  static inline __device__ nv_bfloat162 convert(const float2 &x) {
#if __CUDA_ARCH__ >= 800
    return __float22bfloat162_rn(x);
#else
    union {
      nv_bfloat162 raw;
      nv_bfloat16 elt[2];
    } packed;
    packed.elt[0] = __float2bfloat16_rn(x.x);
    packed.elt[1] = __float2bfloat16_rn(x.y);
    return packed.raw;
#endif
  }
};
#endif
Przemek Tredak's avatar
Przemek Tredak committed
312
313
314

////////////////////////////////////////////////////////////////////////////////////////////////////

315
316
317
// Produces the zero value of T by constructing T from 0.f.
template <typename T>
struct Zeros {
  static inline __device__ T get() {
    return T(0.f);
  }
};

320
321
322
// Zero value for float2: both components are 0.f.
template <>
struct Zeros<float2> {
  static inline __device__ float2 get() {
    float2 zero;
    zero.x = 0.f;
    zero.y = 0.f;
    return zero;
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

327
// Fixed-size vector of NUM_ELT elements stored in a union with a single wide word
// (BytesToType<BYTES>::Type), so the whole vector can be moved with one access of that type
// while elements stay individually addressable through `data.elt`.
template <typename Elt_type, uint32_t NUM_ELT>
struct Vec {
  enum { BYTES = NUM_ELT * sizeof(Elt_type) };

  using Vec_type = typename BytesToType<BYTES>::Type;
  using type = Elt_type;

  using Alias_type = union {
    Vec_type vec;
    Elt_type elt[NUM_ELT];
  };

  Alias_type data;

#ifdef __HIP_PLATFORM_AMD__
  // Copy assignment that copies the wide word (HIP build only).
  __HOST_DEVICE__ Vec& operator=(const Vec& rhs) {
    data.vec = rhs.data.vec;
    return *this;
  }
#endif

  // Element-wise conversion of this vector into `other`, using S(element).
  template <typename S>
  inline __device__ void to(Vec<S, NUM_ELT> &other) {  // NOLINT(*)
#pragma unroll
    for (int i = 0; i < NUM_ELT; ++i) {
      other.data.elt[i] = S(data.elt[i]);
    }
  }

  // Sets element i to op(i) for every i in [0, NUM_ELT).
  template <typename Op>
  inline __device__ void assign(const Op &op) {
#pragma unroll
    for (int i = 0; i < NUM_ELT; ++i) {
      data.elt[i] = op(i);
    }
  }

  // Pointer is cast to vector type; `idx` counts whole vectors.
  inline __device__ void load_from(const void *base_ptr, size_t idx = 0) {
    const Vec_type *const src = static_cast<const Vec_type *>(base_ptr);
    data.vec = src[idx];
  }

  // Pointer is cast to vector type; `idx` counts whole vectors.
  inline __device__ void store_to(void *base_ptr, size_t idx = 0) const {
    Vec_type *const dst = static_cast<Vec_type *>(base_ptr);
    dst[idx] = data.vec;
  }

  // Pointer is cast to element type; `idx` counts elements. Loads min(count, NUM_ELT)
  // elements and any remaining elements are set to zero.
  inline __device__ void load_from_elts(const void *base_ptr, size_t idx = 0,
                                        size_t count = NUM_ELT) {
    const Elt_type *elt_ptr = static_cast<const Elt_type *>(base_ptr) + idx;
    // The single wide load is taken only for a full vector at a BYTES-aligned address.
    if (count >= NUM_ELT && reinterpret_cast<uint64_t>(elt_ptr) % BYTES == 0) {
      this->load_from(elt_ptr);
      return;
    }
#pragma unroll
    for (int i = 0; i < NUM_ELT; ++i) {
      if (i < count) {
        data.elt[i] = elt_ptr[i];
      } else {
        data.elt[i] = Elt_type(0.f);
      }
    }
  }

  // Pointer is cast to element type; `idx` counts elements. Stores min(count, NUM_ELT)
  // elements.
  inline __device__ void store_to_elts(void *base_ptr, size_t idx = 0,
                                       size_t count = NUM_ELT) const {
    Elt_type *elt_ptr = static_cast<Elt_type *>(base_ptr) + idx;
    // The single wide store is taken only for a full vector at a BYTES-aligned address.
    if (count >= NUM_ELT && reinterpret_cast<uint64_t>(elt_ptr) % BYTES == 0) {
      this->store_to(elt_ptr);
      return;
    }
#pragma unroll
    for (int i = 0; i < NUM_ELT; ++i) {
      if (i < count) {
        elt_ptr[i] = data.elt[i];
      }
    }
  }

  // Sets every element to Elt_type(0.f).
  inline __device__ void clear() {
#pragma unroll
    for (int i = 0; i < NUM_ELT; ++i) {
      data.elt[i] = Elt_type(0.f);
    }
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Barrier shared by the CTAs of one group, built on two integer counters in memory.
//
// Each call to sync() uses one of the two counters and either counts it up to group_size_ or
// back down to 0, cycling through four phases (up/down for b0_/b1_). Thread 0 (threadIdx.x)
// of each CTA updates the counter and spins until the expected value is observed; the other
// threads of the CTA wait for it at __syncthreads().
//
// Preconditions visible in the code:
//  * both counters must be zero before first use;
//  * every thread of the CTA must call sync().
struct InterCTASync {
  // barrier:    base of the counter array.
  // group:      index of this CTA group; selects the counter pair.
  // num_groups: offset between the first and the second counter of a group.
  // group_size: value the counter has to reach in an "up" phase.
  inline __device__ InterCTASync(int *barrier, int group, int num_groups, int group_size)
      : phase_counter_(0),
        b0_(barrier + group),               // First counter of this group of CTAs.
        b1_(barrier + group + num_groups),  // Second counter of this group of CTAs.
        group_size_(group_size) {
    // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!
  }

  // Adds `step` to *barrier with release semantics, then polls *barrier with acquire
  // semantics until it equals `expected`.
#ifdef __HIP_PLATFORM_AMD__
  inline __device__ void spin_wait_(int *barrier, int step, int expected) {
    __hip_atomic_fetch_add(barrier, step, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT);
    for (int found = -1; found != expected;) {
      found = __hip_atomic_load(barrier, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT);
    }
  }
#else
  inline __device__ void spin_wait_(int *barrier, int step, int expected) {
    // The "memory" clobbers keep the compiler from moving ordinary loads/stores across the
    // release/acquire instructions; `volatile` alone does not order them.
    asm volatile("red.release.gpu.global.add.s32 [%0], %1;"
                 :
                 : "l"(barrier), "r"(step)
                 : "memory");
    for (int found = -1; found != expected;) {
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];"
                   : "=r"(found)
                   : "l"(barrier)
                   : "memory");
    }
  }
#endif

  inline __device__ void sync() {
    // ALL THREADS MUST ENTER!

    // We switch barrier every iteration.
    int *barrier = phase_counter_ & 0x1 ? b1_ : b0_;
    // We decrement every other iteration.
    bool dec = phase_counter_ & 0x2;
    int step = dec ? -1 : 1;
    int expected = dec ? 0 : group_size_;
    // There are only 4 phases: up/down for b0/b1.
    phase_counter_ = (phase_counter_ + 1) & 0x3;

    if (threadIdx.x == 0) {
      spin_wait_(barrier, step, expected);
    }
    // CTA waits for thread 0
    __syncthreads();
  }

  int phase_counter_;  // Current phase in [0, 3]; bit 0 picks the counter, bit 1 the direction.
  int *b0_;
  int *b1_;
  int group_size_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

469
// Reducer spanning CTAS_PER_ROW CTAs. Each CTA first reduces locally through the
// single-CTA base class, publishes its partial result to a workspace slot, waits on the
// inter-CTA barrier and finally combines the per-CTA partials within a warp.
template <typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
  using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
  using Type = typename Base::Type;

  enum { SMEM_BYTES = Base::SMEM_BYTES };

  enum { WS_BARRIER_BYTES = 2 * sizeof(int) };
  enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) };

  // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total)
  enum {
    WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES
  };

  // Reads params.barrier, params.ctas_per_col and params.workspace. The two workspace
  // buffers w0_/w1_ are params.ctas_per_col * WARPS_M * CTAS_PER_ROW elements apart.
  // Initializers are listed in member declaration order.
  template <typename Params>
  inline __device__ Reducer(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                            uint32_t warp_n, uint32_t lane, void *smem)
      : Base(params, bidm, bidn, warp_m, warp_n, lane, smem),
        inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW),
        w0_(static_cast<T *>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW),
        w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW),
        bidn_(bidn) {}  // CTA id within the group.

  // Returns the reduction of `data` over the whole CTA group.
  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    static_assert(CTAS_PER_ROW <= 32);

    data = Base::reduce(data, op);

    // The workspace alternates with the barrier phase; it is selected before sync()
    // advances the phase counter.
    const bool odd_phase = (inter_cta_.phase_counter_ & 0x1) != 0;
    T *const workspace = odd_phase ? w1_ : w0_;

    // Warp leaders 0 hold the CTA-local results.
    const bool is_cta_leader = (this->warp_n_ == 0) && (this->lane_ == 0);
    if (is_cta_leader) {
      workspace[bidn_] = data;
    }
    inter_cta_.sync();

    // Lane i picks up the partial of CTA i; the remaining lanes contribute zero.
    T total = Zeros<T>::get();
    if (this->lane_ < CTAS_PER_ROW) {
      total = workspace[this->lane_];
    }
    return Reducer<T, 1, 1, 1>::allreduce_(total, op);
  }

  InterCTASync inter_cta_;

  T *const w0_;
  T *const w1_;
  int bidn_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

524
// Single-CTA reducer for one warp per row (WARPS_N == 1): everything is done with warp
// shuffles, so neither shared memory nor workspace is required.
template <typename T, uint32_t WARPS_M>
struct Reducer<T, 1, WARPS_M, 1> {
  using Type = T;
  enum { SMEM_BYTES = 0 };
  enum { WORKSPACE_BYTES_PER_GROUP = 0 };

  enum { THREADS_PER_WARP = 32 };

  // Only warp_n and lane are stored; the other arguments are accepted so all Reducer
  // variants can be constructed the same way.
  template <typename Params>
  inline __device__ Reducer(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                            uint32_t warp_n, uint32_t lane, void *smem)
      : warp_n_(warp_n), lane_(lane) {}

  // XOR-shuffle reduction with lane distances 1, 2, 4, ...; every lane ends up with the
  // combined value.
  template <typename Op>
  static inline __device__ T allreduce_(T data, const Op &op) {
#pragma unroll
    for (int distance = 1; distance < THREADS_PER_WARP; distance <<= 1) {
      const T peer = warp_shuffle_xor(data, distance);
      data = op(data, peer);
    }
    return data;
  }

  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    return allreduce_(data, op);
  }

  // Down-shuffle reduction with lane distances THREADS_PER_WARP / 2, ..., 1.
  template <typename Op>
  inline __device__ T reduce(T data, const Op &op) {
// only lane 0 holds the result!
#pragma unroll
    for (int distance = THREADS_PER_WARP / 2; distance > 0; distance >>= 1) {
      const T peer = warp_shuffle_down(data, distance);
      data = op(data, peer);
    }
    return data;
  }

  int warp_n_;
  int lane_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

566
// Single-CTA reducer for WARPS_N warps per row. Each warp reduces with shuffles, warp
// leaders write their partial to shared memory, and after __syncthreads() the partials are
// combined. Two shared-memory buffers are used alternately from call to call.
template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> {
  using Base = Reducer<T, 1, WARPS_M, 1>;

  using Type = T;

  enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 };
  enum { WORKSPACE_BYTES_PER_GROUP = 0 };

  enum { THREADS_PER_WARP = 32 };

  // `smem` is interpreted as an array of T; this row's slots start at warp_m * WARPS_N and
  // the second buffer lies WARPS_M * WARPS_N elements further.
  // Initializers are listed in member declaration order.
  template <typename Params>
  inline __device__ Reducer(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                            uint32_t warp_n, uint32_t lane, void *smem)
      : Base(params, bidm, bidn, warp_m, warp_n, lane, smem),
        smem0_(static_cast<T *>(smem) + warp_m * WARPS_N),
        smem1_(smem0_ + WARPS_M * WARPS_N),
        use0_(true) {}

  // Shared first half of allreduce()/reduce(): picks the current buffer, flips the buffer
  // selector, reduces within the warp, lets lane 0 store the warp partial and synchronizes
  // the CTA. Returns the buffer that now holds the WARPS_N partials.
  template <typename Op>
  inline __device__ T *exchange_warp_partials_(T data, const Op &op) {
    T *const buffer = use0_ ? smem0_ : smem1_;
    use0_ = !use0_;
    data = Base::reduce(data, op);
    if (this->lane_ == 0) {
      buffer[this->warp_n_] = data;
    }
    __syncthreads();
    return buffer;
  }

  // Every thread returns the combination of all warp partials.
  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    T *const partials = exchange_warp_partials_(data, op);
    T result = Zeros<T>::get();
#pragma unroll
    for (int w = 0; w < WARPS_N; ++w) {
      result = op(result, partials[w]);
    }
    return result;
  }

  // Only the thread with warp_n_ == 0 and lane_ == 0 combines the partials; all other
  // threads return Zeros<T>::get().
  template <typename Op>
  inline __device__ T reduce(T data, const Op &op) {
    // only intra-CTA group leader holds the result!
    T *const partials = exchange_warp_partials_(data, op);
    T result = Zeros<T>::get();
    const bool is_leader = (this->warp_n_ == 0) && (this->lane_ == 0);
    if (is_leader) {
#pragma unroll
      for (int w = 0; w < WARPS_N; ++w) {
        result = op(result, partials[w]);
      }
    }
    return result;
  }

  T *const smem0_;
  T *const smem1_;
  bool use0_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

629
// Multi-CTA reducer whose group size (params.ctas_per_row) is only known at run time.
// Mirrors the compile-time multi-CTA Reducer, except that each lane may have to fold
// several per-CTA partials.
template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct DynamicReducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
  using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
  using Type = typename Base::Type;

  // Reads params.barrier, params.ctas_per_col, params.ctas_per_row and params.workspace.
  // Initializers are listed in member declaration order.
  template <typename Params>
  inline __device__ DynamicReducer(const Params &params, uint32_t bidm, uint32_t bidn,
                                   uint32_t warp_m, uint32_t warp_n, uint32_t lane, void *smem)
      : Base(params, bidm, bidn, warp_m, warp_n, lane, smem),
        inter_cta_(params.barrier, bidm, params.ctas_per_col, params.ctas_per_row),
        w0_(static_cast<T *>(params.workspace) + (bidm * WARPS_M + warp_m) * params.ctas_per_row),
        w1_(w0_ + params.ctas_per_col * WARPS_M * params.ctas_per_row),
        bidn_(bidn) {}  // CTA id within the group.

  // Returns the reduction of `data` over the whole CTA group.
  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    // Trivial case: a group of one CTA needs no inter-CTA exchange.
    const int num_ctas = inter_cta_.group_size_;
    if (num_ctas == 1) {
      return Base::allreduce(data, op);
    }

    data = Base::reduce(data, op);

    // The workspace alternates with the barrier phase; it is selected before sync()
    // advances the phase counter.
    const bool odd_phase = (inter_cta_.phase_counter_ & 0x1) != 0;
    T *const workspace = odd_phase ? w1_ : w0_;

    // Warp leaders 0 hold the CTA-local results.
    const bool is_cta_leader = (this->warp_n_ == 0) && (this->lane_ == 0);
    if (is_cta_leader) {
      workspace[bidn_] = data;
    }
    inter_cta_.sync();

    // Each lane folds the partials at lane_, lane_ + THREADS_PER_WARP, ...
    T total = Zeros<T>::get();
    for (int cta = this->lane_; cta < inter_cta_.group_size_; cta += THREADS_PER_WARP) {
      total = op(total, workspace[cta]);
    }
    return Reducer<T, 1, 1, 1>::allreduce_(total, op);
  }

  // Same as allreduce(): every thread receives the full result.
  template <typename Op>
  inline __device__ T reduce(T data, const Op &op) {
    return allreduce(data, op);
  }

  InterCTASync inter_cta_;

  T *const w0_;
  T *const w1_;
  int bidn_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

683
684
685
686
687
688
689
690
691
692
693
694
695
/*
This is an implementation of the parallel Welford algorithm for incrementally computing variance

This algorithm is known as Chan's update formulae (Chan et al. '79):
http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf

An introduction is provided by Wikipedia here:
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance?section=5#Parallel_algorithm

A detailed reference on the exact version implemented (with better numerical stability) is provided here:
https://dbs.ifi.uni-heidelberg.de/files/Team/eschubert/publications/SSDBM18-covariance-authorcopy.pdf
*/

696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
// Merges per-lane Welford/Chan statistics across a warp with down-shuffles and broadcasts
// the merged mean and M2 from lane 0 to all lanes.
//
// m_a, m2_a, n_a: in/out per-lane mean, sum of squared deviations and element count.
//                 After the call m_a and m2_a hold lane 0's merged values on every lane;
//                 n_a is NOT broadcast and keeps each lane's own merged count.
// num_active:     number of leading lanes carrying valid statistics; the first shuffle
//                 distance is next_pow2(num_active) / 2. With num_active <= 1 no merge step
//                 runs and only the broadcast is performed.
template <typename T>
inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a,
                                             int num_active) {  // NOLINT(*)
  // Assume at least leftmost is valid and
  // init: step = next_pow2(num_active) / 2 (might get NaN otherwise)
  const int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);
  // For num_active == 1, highest_bit_set is 0 and `1 << (highest_bit_set - 1)` would shift
  // by a negative amount, which is undefined. There is nothing to merge in that case, so
  // start directly at 0.
  const int first_step = (num_active > 1) ? (1 << (highest_bit_set - 1)) : 0;

#pragma unroll
  for (int step = first_step; step > 0; step /= 2) {
    // Exchange
    T n_b = warp_shuffle_down(n_a, step);
    T m_b = warp_shuffle_down(m_a, step);
    T m2_b = warp_shuffle_down(m2_a, step);

    // Update
    const T n_ab = n_a + n_b;  // We can handle one of them being 0, not both.
    // Might have different n per thread, otherwise this would simplify :(
    const T rn_ab = 1.f / n_ab;
    const T delta = m_a - m_b;
    // Temporaries use T (not float) so the statistics are not narrowed for wider T.
    const T m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;
    const T m_ab = (n_a * m_a + n_b * m_b) * rn_ab;

    n_a = n_ab;
    m_a = m_ab;
    m2_a = m2_ab;
  }
  // Intra-warp broadcast (only lane 0 has valid stats).
#ifdef __HIP_PLATFORM_AMD__
  m_a = __shfl(m_a, 0, THREADS_PER_WARP);
  m2_a = __shfl(m2_a, 0, THREADS_PER_WARP);
#else
  m_a = __shfl_sync(static_cast<uint32_t>(-1), m_a, 0);
  m2_a = __shfl_sync(static_cast<uint32_t>(-1), m2_a, 0);
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

734
// Mean / M2 statistics of a row that is split across CTAS_PER_ROW CTAs.
// Each CTA first computes its own statistics with the single-CTA specialization, publishes them
// to a global workspace, and after an inter-CTA barrier every warp merges the CTAS_PER_ROW
// partial results.
template <typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats {
  // This could be done generically with the Reducer. But then we
  // would have to exchange 3 instead of 2 fields.

  // Statistics of a single CTA.
  using BlockStats = Stats<T, 1, WARPS_M, WARPS_N>;
  using stats_t = typename BlockStats::stats_t;

  // Shared memory is only needed by the per-CTA stage.
  enum { SMEM_BYTES = BlockStats::SMEM_BYTES };

  // The initializers are listed in member declaration order, which is the order in which they
  // run. w0_ and w1_ are the two workspace regions that compute() alternates between.
  template <typename Params>
  inline __device__ Stats(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                          uint32_t warp_n, uint32_t lane, void *smem)
      : inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW),
        block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem),
        w0_(static_cast<stats_t *>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW),
        w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW),
        bidn_(bidn),  // CTA id within the group.
        warp_n_(warp_n),
        lane_(lane) {}

  // Returns {mean, M2} over the elements held by all CTAs of the group.
  // `rn` is part of the common interface but is not used by this implementation.
  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    // Number of elements that one CTA contributes to the row.
    constexpr T elts_per_cta = N * WARPS_N * THREADS_PER_WARP;
    // TODO(ptredak) rn is not really needed here..
    constexpr T cta_rn = 1.f / T(elts_per_cta);
    stats_t cta_stats = block_stats_.compute(elts, cta_rn);

    // Select the workspace region from the parity of the barrier phase. The phase is read
    // before the barrier below is entered.
    const bool odd_phase = (inter_cta_.phase_counter_ & 0x1) != 0;
    stats_t *const partials = odd_phase ? w1_ : w0_;

    // One thread per row publishes the statistics of this CTA.
    if (warp_n_ == 0 && lane_ == 0) {
      partials[bidn_] = cta_stats;
    }

    // Wait for all CTAS_PER_ROW CTAS in the group to have written their result.
    inter_cta_.sync();

    T count = Zeros<T>::get();
    T mean = Zeros<T>::get();
    T m2 = Zeros<T>::get();

    // The CTA group size in N must be at most 32, such that a single warp can finalize.
    static_assert(CTAS_PER_ROW <= 32);

    // Every warp does the final reduction locally: lane i picks up the result of CTA i, the
    // remaining lanes keep a zero count.
    if (lane_ < CTAS_PER_ROW) {
      stats_t partial = partials[lane_];
      count = elts_per_cta;
      mean = transformer_engine::Get<0>::of<stats_t, T>(partial);
      m2 = transformer_engine::Get<1>::of<stats_t, T>(partial);
    }

    warp_chan_upd_dynamic(mean, m2, count, CTAS_PER_ROW);

    return {mean, m2};
  }

  InterCTASync inter_cta_;
  BlockStats block_stats_;

  stats_t *const w0_;
  stats_t *const w1_;
  int bidn_;
  int warp_n_;
  int lane_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

804
// Mean / M2 statistics of a row handled by a single CTA with WARPS_N warps per row.
// Every warp computes its own statistics, the warp leaders exchange them through shared memory,
// and each warp then merges the WARPS_N partial results.
template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats<T, 1, WARPS_M, WARPS_N> {
  // Statistics of a single warp.
  using WarpStats = Stats<T, 1, WARPS_M, 1>;
  using stats_t = typename WarpStats::stats_t;

  // Two staging buffers with one stats_t slot per warp of the CTA.
  enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 };

  // smem0_ / smem1_ point at the slots of row `warp_m` inside the first / second staging buffer.
  template <typename Params>
  inline __device__ Stats(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                          uint32_t warp_n, uint32_t lane, void *smem)
      : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem),
        smem0_(static_cast<stats_t *>(smem) + warp_m * WARPS_N),
        smem1_(smem0_ + WARPS_M * WARPS_N),
        use0_(true) {}

  // Returns {mean, M2} over the elements held by all WARPS_N warps of the row.
  // `rn` is part of the common interface but is not used by this implementation.
  // Contains a __syncthreads(), so the call has to be reached by every thread of the block.
  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    // Alternate between the two staging buffers on successive calls.
    stats_t *staging = smem1_;
    if (use0_) {
      staging = smem0_;
    }
    use0_ = !use0_;

    // Compute warp local for all WARPS_N
    constexpr T per_warp_rn = 1.f / T(N * THREADS_PER_WARP);
    stats_t per_warp_stats = warp_stats_.compute(elts, per_warp_rn);

    // Each warp leader stores the stats of its warp
    const auto my_warp_n = warp_stats_.reducer_.warp_n_;
    const auto my_lane = warp_stats_.reducer_.lane_;
    if (my_lane == 0) {
      staging[my_warp_n] = per_warp_stats;
    }
    __syncthreads();

    T count = Zeros<T>::get();
    T mean = Zeros<T>::get();
    T m2 = Zeros<T>::get();

    // There must be at most 32 warps per row, such that a single warp can finalize.
    static_assert(WARPS_N <= 32);
    // Lane i picks up the result of warp i, the remaining lanes keep a zero count.
    if (my_lane < WARPS_N) {
      stats_t partial = staging[my_lane];
      count = N * THREADS_PER_WARP;
      mean = transformer_engine::Get<0>::of<stats_t, T>(partial);
      m2 = transformer_engine::Get<1>::of<stats_t, T>(partial);
    }

    warp_chan_upd_dynamic(mean, m2, count, WARPS_N);

    return {mean, m2};
  }

  WarpStats warp_stats_;
  stats_t *smem0_;
  stats_t *smem1_;
  bool use0_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

860
// Mean / M2 statistics of a row handled by a single warp.
template <typename T, uint32_t WARPS_M>
struct Stats<T, 1, WARPS_M, 1> {
  // Pair type used to return {mean, M2}.
  using stats_t = typename TypeToVec2<T>::Type;
  // The simple Warp reducer.
  using Reducer = Reducer<T, 1, WARPS_M, 1>;

  // No shared memory is needed at this level.
  enum { SMEM_BYTES = 0 };

  // All arguments are forwarded unchanged to the reducer.
  template <typename Params>
  inline __device__ Stats(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                          uint32_t warp_n, uint32_t lane, void *smem)
      : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) {}

  // Two-pass computation over the N elements held by each thread.
  //   rn  scale applied to the reduced sum; the callers in this file pass the reciprocal of
  //       the number of reduced elements, which makes the first result the mean.
  // Returns {sum * rn, reduced sum of squared differences to that value}.
  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    auto add = Sum<T>();

    // Pass 1: thread-local sum, then reduction through the reducer and scaling.
    T thread_sum = Zeros<T>::get();
#pragma unroll
    for (uint32_t i = 0; i < N; ++i) {
      thread_sum += elts[i];
    }
    const T mean = reducer_.allreduce(thread_sum, add) * rn;

    // Pass 2: thread-local sum of squared differences, then reduction through the reducer.
    T thread_sq = Zeros<T>::get();
#pragma unroll
    for (uint32_t i = 0; i < N; ++i) {
      const T centered = elts[i] - mean;
      thread_sq += centered * centered;
    }
    const T m2 = reducer_.allreduce(thread_sq, add);

    return {mean, m2};
  }

  Reducer reducer_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Shuffle-down max reduction over the leading lanes of a warp.
// Starts with a stride of num_elems / 2 and halves it every round; for a power-of-two
// num_elems, lane 0 ends up with the maximum of lanes [0, num_elems). The other lanes only
// hold partial results.
// All values are assumed to be non-negative (see the __builtin_assume hints). In the CUDA path
// the full-warp mask is used, so every lane of the warp has to execute the call.
template <int num_elems>
__device__ __forceinline__ float warp_reduce_max(const float m) {
  float running_max = m;
#pragma unroll
  for (int stride = num_elems / 2; stride > 0; stride /= 2) {
    // Fetch the running maximum of the lane `stride` positions above.
#ifdef __HIP_PLATFORM_AMD__
    const float neighbour_max = __shfl_down(running_max, stride, THREADS_PER_WARP);
#else
    const float neighbour_max = __shfl_down_sync(0xFFFFFFFF, running_max, stride);
#endif
    // Optimizer hints: both operands are non-negative.
    __builtin_assume(running_max >= 0);
    __builtin_assume(neighbour_max >= 0);
    running_max = fmaxf(running_max, neighbour_max);
  }
  return running_max;
}

917
918
919
920
// Max reduction over a whole warp; the result is returned on every lane.
// A shuffle-down reduction leaves the maximum on lane 0, which is then broadcast to all lanes.
// All values are assumed to be non-negative (see the __builtin_assume hints). In the CUDA path
// the full-warp mask is used, so every lane of the warp has to execute the call.
__forceinline__ __device__ float warp_reduce_max_broadcast(const float val) {
  float running_max = val;
#pragma unroll
  for (int stride = THREADS_PER_WARP / 2; stride > 0; stride /= 2) {
    // Fetch the running maximum of the lane `stride` positions above.
#ifdef __HIP_PLATFORM_AMD__
    const float neighbour_max = __shfl_down(running_max, stride, THREADS_PER_WARP);
#else
    const float neighbour_max = __shfl_down_sync(0xFFFFFFFF, running_max, stride);
#endif
    // Optimizer hints: both operands are non-negative.
    __builtin_assume(running_max >= 0);
    __builtin_assume(neighbour_max >= 0);
    running_max = fmaxf(running_max, neighbour_max);
  }
  // Broadcast the amax from lane 0 to all other threads of the warp
  constexpr int source_lane = 0;
#ifdef __HIP_PLATFORM_AMD__
  running_max = __shfl(running_max, source_lane, THREADS_PER_WARP);
#else
  running_max = __shfl_sync(0xFFFFFFFF, running_max, source_lane);
#endif
  return running_max;
}

Przemek Tredak's avatar
Przemek Tredak committed
940
941
// Block-wide max reduction of non-negative values.
//   num_warps  number of warps taking part; it is also the lane count of the final warp-level
//              reduction, which halves its stride each round and therefore only covers all
//              entries when num_warps is a power of two.
//   m          value contributed by the calling thread.
//   warpid     index of the calling thread's warp, used as the slot in the staging buffer.
// Returns the reduced value on lane 0 of warp 0; threads outside warp 0 get 0 and the other
// lanes of warp 0 only hold partial results.
// Contains a __syncthreads(), so the call has to be reached by every thread of the block.
// NOTE(review): the warp width is hard-coded to 32 and lanes are derived from threadIdx.x
// only - confirm this matches the launch configuration of all callers.
// NOTE(review): the staging buffer is a static __shared__ array that is reused by every call of
// the same instantiation, and there is no barrier after it is read - confirm that callers
// synchronize between back-to-back calls.
template <int num_warps, typename compute_t>
__device__ __forceinline__ compute_t reduce_max(const compute_t m, const int warpid) {
  __shared__ float staging[num_warps];
  constexpr int warp_size = 32;

  // Stage 1: reduce inside each warp; the warp leader publishes the result.
  const float thread_max = m;
  const float warp_max = warp_reduce_max<warp_size>(thread_max);
  if (threadIdx.x % warp_size == 0) {
    staging[warpid] = warp_max;
  }
  __syncthreads();

  // Stage 2: warp 0 reduces the per-warp results, one staged value per lane.
  compute_t block_max = 0.f;
  if (warpid == 0) {
    const float staged_max = threadIdx.x < num_warps ? staging[threadIdx.x] : 0;
    block_max = warp_reduce_max<num_warps>(staged_max);
  }
  return block_max;
}

958
959
960
961
962
963
964
965
966
967
968
969
/**
 * Max reduction in subwarps; the result is returned on every lane of the subwarp.
 * E.g., if nvec=4, each warp processes 128 elements (32 x 4), which covers four MXFP8 scaling
 * factors. To compute an actual scaling factor for 32 consecutive elements, only 8 threads need
 * to participate, thus splitting the warp into 4 smaller subwarps of 8 threads each.
 * Inside each subwarp a shuffle-down reduction leaves the maximum on the first lane of the
 * subwarp, which is then broadcast to the remaining lanes of that subwarp.
 * 'subwarp_width' is passed as the width argument of the shuffle intrinsics.
 * All values are assumed to be non-negative (see the __builtin_assume hints). In the CUDA path
 * the full-warp mask is used, so every lane of the warp has to execute the call.
 */
template <int subwarp_width>
__forceinline__ __device__ float subwarp_reduce_max_broadcast(const float val) {
  float running_max = val;
#pragma unroll
  for (int stride = subwarp_width / 2; stride > 0; stride /= 2) {
    // Fetch the running maximum of the lane `stride` positions above within the subwarp.
#ifdef __HIP_PLATFORM_AMD__
    const float neighbour_max = __shfl_down(running_max, stride, subwarp_width);
#else
    const float neighbour_max = __shfl_down_sync(0xFFFFFFFF, running_max, stride, subwarp_width);
#endif
    // Optimizer hints: both operands are non-negative.
    __builtin_assume(running_max >= 0);
    __builtin_assume(neighbour_max >= 0);
    running_max = fmaxf(running_max, neighbour_max);
  }
  // Broadcast the amax to other threads of the subwarp from the zero subwarp lane_id
  constexpr int source_lane = 0;
#ifdef __HIP_PLATFORM_AMD__
  running_max = __shfl(running_max, source_lane, subwarp_width);
#else
  running_max = __shfl_sync(0xFFFFFFFF, running_max, source_lane, subwarp_width);
#endif
  return running_max;
}

Przemek Tredak's avatar
Przemek Tredak committed
989
// Atomic maximum on a float, implemented with the integer atomicMax on the raw bit pattern.
// Works only on positive values: the comparison is done on the bits interpreted as signed
// integers, which does not follow the float ordering once negative values are involved.
__device__ __forceinline__ void atomicMaxFloat(float *addr, const float value) {
  int *const addr_as_int = reinterpret_cast<int *>(addr);
  const int value_bits = __float_as_int(value);
  atomicMax(addr_as_int, value_bits);
}

// Atomic minimum on a float, implemented with the integer atomicMin on the raw bit pattern.
// Works only on positive values: the comparison is done on the bits interpreted as signed
// integers, which does not follow the float ordering once negative values are involved.
__device__ __forceinline__ void atomicMinFloat(float *addr, const float value) {
  int *const addr_as_int = reinterpret_cast<int *>(addr);
  const int value_bits = __float_as_int(value);
  atomicMin(addr_as_int, value_bits);
}

// Stores the reciprocal of `value` in `*value_inv`, computed with a plain division.
template <typename T>
__device__ __forceinline__ void reciprocal(T *value_inv, const T value) {
  const T inverse = 1 / value;
  *value_inv = inverse;
}

1004
1005
1006
1007
1008
// float specialization: uses the __frcp_rn reciprocal intrinsic instead of a division.
template <>
__device__ __forceinline__ void reciprocal<float>(float *value_inv, const float value) {
  const float inverse = __frcp_rn(value);
  *value_inv = inverse;
}

1009
1010
1011
1012
1013
////////////////////////////////////////////////////////////////////////////////////////////////////

// Short aliases for the two 8-bit floating-point types.
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
// Raw byte used to store an E8M0 value.
// NOTE(review): presumably the power-of-two block scaling factor of MXFP8 - confirm against users.
using e8m0_t = uint8_t;
yuguo's avatar
yuguo committed
1014
// Short alias for the signed 8-bit integer type; used as a Numeric_Traits key below.
using int8 = int8_t;
1015

1016
// Scaling mode selector with explicit numeric values.
// NOTE(review): the names suggest scaling along rows, along columns, or along both dimensions -
// confirm the exact meaning against the kernels that consume this enum.
enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENSIONAL = 2 };
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032

// Compile-time numeric limits of a quantized type. Only the explicit specializations are
// defined; the primary template is declared but intentionally left without a definition.
//   maxUnbiasedExponent  largest unbiased exponent of a finite value
//   maxNorm              largest finite magnitude
template <typename T>
struct Numeric_Traits;

// E4M3 limits: 448 = 1.75 * 2^8.
template <>
struct Numeric_Traits<fp8e4m3> {
  static constexpr int maxUnbiasedExponent = 8;
  static constexpr double maxNorm = 448;
};

// E5M2 limits: 57344 = 1.75 * 2^15.
template <>
struct Numeric_Traits<fp8e5m2> {
  static constexpr int maxUnbiasedExponent = 15;
  static constexpr double maxNorm = 57344;
};

yuguo's avatar
yuguo committed
1033
1034
1035
1036
1037
1038
// int8 limits: the largest representable magnitude used here is 127. The exponent is set to 0,
// which makes Quantized_Limits<int8>::emax evaluate to 1.
template <>
struct Numeric_Traits<int8> {
  static constexpr int maxUnbiasedExponent = 0;
  static constexpr double maxNorm = 127;
};

1039
1040
1041
1042
1043
1044
1045
1046
1047
// Limits of the quantized type T as float constants, derived from Numeric_Traits<T>.
template <typename T>
struct Quantized_Limits {
  // Largest unbiased exponent, copied from the traits.
  static constexpr int max_unbiased_exponent = Numeric_Traits<T>::maxUnbiasedExponent;
  // Largest finite magnitude, narrowed from double to float.
  static constexpr float max_norm = Numeric_Traits<T>::maxNorm;
  // Reciprocal of max_norm; the division is evaluated in double and then narrowed to float.
  static constexpr float max_norm_rcp = 1.0 / max_norm;
  // 2^max_unbiased_exponent, computed as an integer shift and converted to float.
  static constexpr float emax = 1 << max_unbiased_exponent;
  // Reciprocal of emax; the division is evaluated in double and then narrowed to float.
  static constexpr float emax_rcp = 1.0 / emax;
};

Przemek Tredak's avatar
Przemek Tredak committed
1048
1049
1050
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_UTILS_CUH_