/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_UTILS_CUH_
#define TRANSFORMER_ENGINE_COMMON_UTILS_CUH_

#include <cuda_fp16.h>
yuguo's avatar
yuguo committed
11
12
13
14
15

#ifdef __HIP_PLATFORM_AMD__
#ifndef __HIPCC_RTC__
#include <cstdint>
#else
16
#include <hip/amd_detail/hip_assert.h>
yuguo's avatar
yuguo committed
17
18
19
20
21
using namespace __hip_internal;
#endif
#endif

#include <cuda_bf16.h>
Tim Moon's avatar
Tim Moon committed
22
#include <cuda_fp8.h>
Przemek Tredak's avatar
Przemek Tredak committed
23

24
25
26
27
#if CUDA_VERSION >= 12080
#include <cuda_fp4.h>
#endif

yuguo's avatar
yuguo committed
28
29
30
#ifdef __HIP_PLATFORM_AMD__
typedef uint16_t hip_bfloat16x2 __attribute__((ext_vector_type(2)));
#else
Tim Moon's avatar
Tim Moon committed
31
32
33
34
35
36
37
38
39
40
41
42
43
#if !defined(__CUDACC_RTC__)
#include <cstdint>
#else
// Importing C++ standard headers is a pain with NVRTC
using uint8_t = unsigned char;
using uint16_t = unsigned short int;  // NOLINT(*)
using uint32_t = unsigned int;
using uint64_t = unsigned long long int;  // NOLINT(*)
static_assert(sizeof(uint8_t) == 1);
static_assert(sizeof(uint16_t) == 2);
static_assert(sizeof(uint32_t) == 4);
static_assert(sizeof(uint64_t) == 8);
#endif
Przemek Tredak's avatar
Przemek Tredak committed
44

yuguo's avatar
yuguo committed
45
#endif // __HIP_PLATFORM_AMD__
Przemek Tredak's avatar
Przemek Tredak committed
46
47
48
49
50
51
////////////////////////////////////////////////////////////////////////////////////////////////////

// Number of threads per warp assumed by the shuffle-based reducers below.
// NOTE(review): fixed at 32 even on the HIP path, where wavefronts can be 64
// wide — confirm this matches the launch configuration of the kernels using it.
constexpr uint32_t THREADS_PER_WARP = 32;

////////////////////////////////////////////////////////////////////////////////////////////////////

52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
// Device-side error: prints file/line/function plus the coordinates of the
// failing thread and block, then traps via assert(0). Every thread that
// reaches this macro prints, so output may be repeated.
#define NVTE_DEVICE_ERROR(message)                                                                 \
  do {                                                                                             \
    printf("%s:%d in function %s (thread (%d,%d,%d), block (%d,%d,%d)): %s\n", __FILE__, __LINE__, \
           __func__, threadIdx.x, threadIdx.y, threadIdx.z, blockIdx.x, blockIdx.y, blockIdx.z,    \
           (message));                                                                             \
    assert(0);                                                                                     \
  } while (false)

// Device-side error reported only by thread (0,0,0) of block (0,0,0),
// avoiding output floods when many threads hit the same failure.
#define NVTE_DEVICE_THREAD0_ERROR(message)                                           \
  do {                                                                               \
    if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0 && \
        threadIdx.y == 0 && threadIdx.z == 0) {                                      \
      NVTE_DEVICE_ERROR(message);                                                    \
    }                                                                                \
  } while (false)

////////////////////////////////////////////////////////////////////////////////////////////////////

yuguo's avatar
yuguo committed
72
#if !defined(USE_HIPBLASLT) && !defined(__HIPCC_RTC__)
// Component-wise sum of two float2 values.
inline __device__ float2 operator+(const float2 &a, const float2 &b) {  // NOLINT(*)
  return make_float2(a.x + b.x, a.y + b.y);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Component-wise in-place accumulation for float2.
inline __device__ void operator+=(float2 &a, const float2 &b) {  // NOLINT(*)
  a.x += b.x;
  a.y += b.y;
}
#endif
Przemek Tredak's avatar
Przemek Tredak committed
84
85
86

////////////////////////////////////////////////////////////////////////////////////////////////////

87
// Binary-sum functor, usable as the reduction op for the reducers below.
template <typename T>
struct Sum {
  inline __device__ Sum() {}
  inline __device__ T operator()(const T &lhs, const T &rhs) const { return lhs + rhs; }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

95
96
// Butterfly (XOR-pattern) warp shuffle; all lanes are assumed to participate.
template <typename T>
inline __device__ T warp_shuffle_xor(const T &x, uint32_t idx) {
#ifdef __HIP_PLATFORM_AMD__
  return __shfl_xor(x, idx, THREADS_PER_WARP);
#else
  return __shfl_xor_sync(0xffffffffU, x, idx);
#endif
}

104
105
106
// float2 specialization: shuffle the two components independently.
template <>
inline __device__ float2 warp_shuffle_xor<float2>(const float2 &x, uint32_t idx) {
  float2 result;
  result.x = warp_shuffle_xor(x.x, idx);
  result.y = warp_shuffle_xor(x.y, idx);
  return result;
}

109
110
// Shuffle-down within the warp; all lanes are assumed to participate.
template <typename T>
inline __device__ T warp_shuffle_down(const T &x, uint32_t idx) {
#ifdef __HIP_PLATFORM_AMD__
  return __shfl_down(x, idx, THREADS_PER_WARP);
#else
  return __shfl_down_sync(0xffffffffU, x, idx);
#endif
}

118
119
120
// float2 specialization: shuffle the two components independently.
template <>
inline __device__ float2 warp_shuffle_down<float2>(const float2 &x, uint32_t idx) {
  float2 result;
  result.x = warp_shuffle_down(x.x, idx);
  result.y = warp_shuffle_down(x.y, idx);
  return result;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace transformer_engine {

////////////////////////////////////////////////////////////////////////////////////////////////////

// 64-byte POD aggregate (four uint4 members) used as the widest vector
// load/store type. NOTE: despite the name, this is NOT a 16-bit integer.
struct uint16 {
  uint4 u;
  uint4 v;
  uint4 s;
  uint4 t;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// 32-byte POD aggregate (two uint4 members) used for vector loads/stores.
// NOTE: despite the name, this is NOT an 8-bit integer.
struct uint8 {
  uint4 u;
  uint4 v;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

145
// Maps a byte width (1/2/4/8/16/32/64) to an unsigned POD type of exactly that
// size; used to choose the access type for vectorized loads/stores.
template <int BYTES>
struct BytesToType {};

template <>
struct BytesToType<64> {
  using Type = uint16;
  static_assert(sizeof(Type) == 64);
};

template <>
struct BytesToType<32> {
  using Type = uint8;
  static_assert(sizeof(Type) == 32);
};

template <>
struct BytesToType<16> {
  using Type = uint4;
  static_assert(sizeof(Type) == 16);
};

template <>
struct BytesToType<8> {
  using Type = uint64_t;
  static_assert(sizeof(Type) == 8);
};

template <>
struct BytesToType<4> {
  using Type = uint32_t;
  static_assert(sizeof(Type) == 4);
};

template <>
struct BytesToType<2> {
  using Type = uint16_t;
  static_assert(sizeof(Type) == 2);
};

template <>
struct BytesToType<1> {
  using Type = uint8_t;
  static_assert(sizeof(Type) == 1);
};

////////////////////////////////////////////////////////////////////////////////////////////////////

192
// Maps a scalar type to its packed 2-element vector counterpart.
template <typename T>
struct TypeToVec2 {};

template <>
struct TypeToVec2<float> {
  using Type = float2;
};

template <>
struct TypeToVec2<half> {
  using Type = half2;
};

#ifdef __HIP_PLATFORM_AMD__
template <>
struct TypeToVec2<__hip_bfloat16> {
  using Type = hip_bfloat16x2;
};
#else
template <>
struct TypeToVec2<nv_bfloat16> {
  using Type = nv_bfloat162;
};
#endif
Przemek Tredak's avatar
Przemek Tredak committed
216
217
218

////////////////////////////////////////////////////////////////////////////////////////////////////

219
220
// Parameter pack for the fused cast-transpose + dbias + dact kernels: groups
// the input/output pointers and quantization-related buffers into one struct.
// NOTE(review): pointer roles are inferred from their names — confirm against
// the kernels that consume this struct.
template <typename IType, typename IType2, typename OType, typename CType>
struct CTDBiasDActParam {
  using InputType = IType;
  using InputType2 = IType2;
  using OutputType = OType;
  using ComputeType = CType;
  const IType *input;
  const IType2 *act_input;
  OType *output_c;
  OType *output_t;
  const CType *scale_ptr;
  CType *amax;
  CType *scale_inv;
  CType *workspace;
  CType *warp_scales_inv;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

238
// Compile-time accessor for the INDEX-th component (.x/.y/.z/.w) of a CUDA
// built-in vector type.
template <int INDEX>
struct Get {
  template <typename T, typename R>
  static inline __device__ R of(const T &vec);
};

template <>
template <typename T, typename R>
inline __device__ R Get<0>::of(const T &vec) {
  return vec.x;
}

template <>
template <typename T, typename R>
inline __device__ R Get<1>::of(const T &vec) {
  return vec.y;
}

template <>
template <typename T, typename R>
inline __device__ R Get<2>::of(const T &vec) {
  return vec.z;
}

template <>
template <typename T, typename R>
inline __device__ R Get<3>::of(const T &vec) {
  return vec.w;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

270
271
272
// Generic numeric conversion Src -> Dst via Dst's converting constructor.
template <typename Src, typename Dst>
struct Converter {
  static inline __device__ Dst convert(const Src &from) { return Dst(from); }
};

// float2 -> half2 via the packed round-to-nearest hardware conversion.
template <>
struct Converter<float2, half2> {
  static inline __device__ half2 convert(const float2 &x) { return __float22half2_rn(x); }
};

#ifdef __HIP_PLATFORM_AMD__
// float2 -> packed bfloat16 pair: convert each lane and repack via a union.
template <>
struct Converter<float2, hip_bfloat16x2> {
  static inline __device__ hip_bfloat16x2 convert(const float2 &x) {
    union {
      hip_bfloat16x2 raw;
      __hip_bfloat16 elt[2];
    } tmp;
    tmp.elt[0] = __hip_bfloat16(x.x);
    tmp.elt[1] = __hip_bfloat16(x.y);
    return tmp.raw;
  }
};
#else
// float2 -> nv_bfloat162: packed hardware conversion on SM80+, element-wise
// fallback (through a union) on older architectures.
template <>
struct Converter<float2, nv_bfloat162> {
  static inline __device__ nv_bfloat162 convert(const float2 &x) {
#if __CUDA_ARCH__ >= 800
    return __float22bfloat162_rn(x);
#else
    union {
      nv_bfloat162 raw;
      nv_bfloat16 elt[2];
    } tmp;
    tmp.elt[0] = __float2bfloat16_rn(x.x);
    tmp.elt[1] = __float2bfloat16_rn(x.y);
    return tmp.raw;
#endif
  }
};
#endif
Przemek Tredak's avatar
Przemek Tredak committed
311
312
313

////////////////////////////////////////////////////////////////////////////////////////////////////

314
315
316
// Additive identity for a type; specialized where T(0.f) is not constructible.
template <typename T>
struct Zeros {
  static inline __device__ T get() { return T(0.f); }
};

template <>
struct Zeros<float2> {
  static inline __device__ float2 get() { return make_float2(0.f, 0.f); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

326
// Fixed-size vector of NUM_ELT elements of Elt_type, aliased onto a single
// machine word (Vec_type) so memory accesses can be vectorized.
template <typename Elt_type, uint32_t NUM_ELT>
struct Vec {
  enum { BYTES = NUM_ELT * sizeof(Elt_type) };

  using Vec_type = typename BytesToType<BYTES>::Type;
  using type = Elt_type;

  // Union view: either one vector word or the individual elements.
  using Alias_type = union {
    Vec_type vec;
    Elt_type elt[NUM_ELT];
  };

  Alias_type data;

#ifdef __HIP_PLATFORM_AMD__
  // Explicit copy through the vector member (required on the HIP path).
  __HOST_DEVICE__ Vec &operator=(const Vec &rhs) {
    data.vec = rhs.data.vec;
    return *this;
  }
#endif

  // Element-wise conversion into a Vec of another element type S.
  template <typename S>
  inline __device__ void to(Vec<S, NUM_ELT> &other) {  // NOLINT(*)
#pragma unroll
    for (int it = 0; it < NUM_ELT; it++) {
      other.data.elt[it] = S(this->data.elt[it]);
    }
  }

  // Fill each element from a generator functor: elt[it] = op(it).
  template <typename Op>
  inline __device__ void assign(const Op &op) {
#pragma unroll
    for (int it = 0; it < NUM_ELT; it++) {
      this->data.elt[it] = op(it);
    }
  }

  // Vectorized load; pointer is cast to the vector type.
  inline __device__ void load_from(const void *base_ptr, size_t idx = 0) {
    this->data.vec = static_cast<const Vec_type *>(base_ptr)[idx];
  }

  // Vectorized store; pointer is cast to the vector type.
  inline __device__ void store_to(void *base_ptr, size_t idx = 0) const {
    static_cast<Vec_type *>(base_ptr)[idx] = this->data.vec;
  }

  // Pointer is cast to element type. Loads min(count, NUM_ELT) elements and
  // zero-fills the rest; takes the element-wise path when the access is
  // partial or not aligned for a full vector load.
  inline __device__ void load_from_elts(const void *base_ptr, size_t idx = 0,
                                        size_t count = NUM_ELT) {
    const Elt_type *elt_ptr = static_cast<const Elt_type *>(base_ptr) + idx;
    if (count < NUM_ELT || reinterpret_cast<uint64_t>(elt_ptr) % BYTES != 0) {
#pragma unroll
      for (int it = 0; it < NUM_ELT; it++) {
        this->data.elt[it] = (it < count ? elt_ptr[it] : Elt_type(0.f));
      }
    } else {
      this->load_from(elt_ptr);
    }
  }

  // Pointer is cast to element type. Stores min(count, NUM_ELT) elements,
  // using the vectorized path when aligned and complete.
  inline __device__ void store_to_elts(void *base_ptr, size_t idx = 0,
                                       size_t count = NUM_ELT) const {
    Elt_type *elt_ptr = static_cast<Elt_type *>(base_ptr) + idx;
    if (count < NUM_ELT || reinterpret_cast<uint64_t>(elt_ptr) % BYTES != 0) {
#pragma unroll
      for (int it = 0; it < NUM_ELT; it++) {
        if (it < count) {
          elt_ptr[it] = this->data.elt[it];
        }
      }
    } else {
      this->store_to(elt_ptr);
    }
  }

  // Zero all elements.
  inline __device__ void clear() {
#pragma unroll
    for (int it = 0; it < NUM_ELT; it++) {
      this->data.elt[it] = Elt_type(0.f);
    }
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Inter-CTA barrier for a group of cooperating CTAs. Two counters (b0_/b1_)
// alternate between calls, each counting up then back down, giving 4 phases
// so that consecutive sync() calls cannot interfere.
struct InterCTASync {
  inline __device__ InterCTASync(int *barrier, int group, int num_groups, int group_size)
      : phase_counter_(0),
        b0_(barrier + group),               // Barrier counter for this group of CTAs.
        b1_(barrier + group + num_groups),  // Second counter, used on odd phases.
        group_size_(group_size) {
    // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!
  }

#ifdef __HIP_PLATFORM_AMD__
  // Atomically add `step` (release), then spin on acquire loads until the
  // counter reaches `expected`.
  inline __device__ void spin_wait_(int *barrier, int step, int expected) {
    __hip_atomic_fetch_add(barrier, step, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT);
    for (int found = -1; found != expected;) {
      found = __hip_atomic_load(barrier, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT);
    }
  }
#else
  // Atomically add `step` (release semantics), then spin on acquire loads
  // until the counter reaches `expected`.
  inline __device__ void spin_wait_(int *barrier, int step, int expected) {
    asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step));
    for (int found = -1; found != expected;) {
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier));
    }
  }
#endif

  inline __device__ void sync() {
    // ALL THREADS MUST ENTER!

    // We switch barrier every iteration.
    int *barrier = phase_counter_ & 0x1 ? b1_ : b0_;
    // We decrement every other iteration.
    bool dec = phase_counter_ & 0x2;
    int step = dec ? -1 : 1;
    int expected = dec ? 0 : group_size_;
    // There are only 4 phases: up/down for b0/b1.
    phase_counter_ = (phase_counter_ + 1) & 0x3;

    // Only thread 0 talks to the global counter ...
    if (threadIdx.x == 0) {
      spin_wait_(barrier, step, expected);
    }
    // ... and the rest of the CTA waits for it here.
    __syncthreads();
  }

  int phase_counter_;
  int *b0_;
  int *b1_;
  int group_size_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

468
// Multi-CTA reducer: reduces within each CTA (via the base class), then
// combines the CTAS_PER_ROW per-CTA partials through a global-memory
// workspace guarded by InterCTASync.
template <typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
  using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
  using Type = typename Base::Type;

  enum { SMEM_BYTES = Base::SMEM_BYTES };

  enum { WS_BARRIER_BYTES = 2 * sizeof(int) };
  enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) };

  // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total)
  enum {
    WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES
  };

  // Fix: the initializer list below follows the member declaration order
  // (inter_cta_, w0_, w1_, bidn_) — the previous order listed bidn_ early,
  // which triggers -Wreorder and misleads readers about the actual
  // initialization sequence (behavior is unchanged).
  template <typename Params>
  inline __device__ Reducer(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                            uint32_t warp_n, uint32_t lane, void *smem)
      : Base(params, bidm, bidn, warp_m, warp_n, lane, smem),
        inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW),
        w0_(static_cast<T *>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW),
        w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW),
        bidn_(bidn)  // CTA id within the group.
  {}

  // All-reduce across the CTA group: CTA-local reduce, publish the partial to
  // the per-phase workspace, inter-CTA sync, then one warp reduces the
  // per-CTA partials (hence CTAS_PER_ROW <= 32). Every thread gets the total.
  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    data = Base::reduce(data, op);
    // We switch workspace every iteration.
    T *const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;

    // Warp leaders 0 hold the CTA-local results.
    if (this->warp_n_ == 0 && this->lane_ == 0) {
      workspace[bidn_] = data;
    }
    inter_cta_.sync();
    static_assert(CTAS_PER_ROW <= 32);
    T total = Zeros<T>::get();
    if (this->lane_ < CTAS_PER_ROW) {
      total = workspace[this->lane_];
    }
    total = Reducer<T, 1, 1, 1>::allreduce_(total, op);

    return total;
  }

  InterCTASync inter_cta_;

  T *const w0_;
  T *const w1_;
  int bidn_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

523
// Single-warp reducer (one warp per row in the N dimension): no shared memory
// or global workspace needed; everything goes through warp shuffles.
template <typename T, uint32_t WARPS_M>
struct Reducer<T, 1, WARPS_M, 1> {
  using Type = T;
  enum { SMEM_BYTES = 0 };
  enum { WORKSPACE_BYTES_PER_GROUP = 0 };

  enum { THREADS_PER_WARP = 32 };

  template <typename Params>
  inline __device__ Reducer(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                            uint32_t warp_n, uint32_t lane, void *smem)
      : warp_n_(warp_n), lane_(lane) {}

  // Butterfly all-reduce: after log2(32) exchanges every lane holds the total.
  template <typename Op>
  static inline __device__ T allreduce_(T data, const Op &op) {
#pragma unroll
    for (int offset = 1; offset < THREADS_PER_WARP; offset *= 2) {
      data = op(data, warp_shuffle_xor(data, offset));
    }
    return data;
  }

  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    return allreduce_(data, op);
  }

  // Tree reduce via shuffle-down; only lane 0 holds the result!
  template <typename Op>
  inline __device__ T reduce(T data, const Op &op) {
#pragma unroll
    for (int offset = THREADS_PER_WARP / 2; offset > 0; offset /= 2) {
      data = op(data, warp_shuffle_down(data, offset));
    }
    return data;
  }
  int warp_n_;
  int lane_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

565
// Intra-CTA reducer for WARPS_N > 1: each warp reduces locally, then the
// per-warp partials are combined through double-buffered shared memory
// (use0_ toggles between smem0_ and smem1_ each call).
template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> {
  using Base = Reducer<T, 1, WARPS_M, 1>;

  using Type = T;

  // Two buffers of WARPS_M * WARPS_N partials.
  enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 };
  enum { WORKSPACE_BYTES_PER_GROUP = 0 };

  enum { THREADS_PER_WARP = 32 };

  // Initializer list follows member declaration order (smem0_, smem1_, use0_).
  template <typename Params>
  inline __device__ Reducer(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                            uint32_t warp_n, uint32_t lane, void *smem)
      : Base(params, bidm, bidn, warp_m, warp_n, lane, smem),
        smem0_(&(static_cast<T *>(smem)[warp_m * WARPS_N])),
        smem1_(smem0_ + WARPS_M * WARPS_N),
        use0_(true) {}

  // Every thread of the warp group receives the reduced value.
  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    T *const smem = use0_ ? smem0_ : smem1_;
    use0_ = !use0_;
    data = Base::reduce(data, op);
    if (this->lane_ == 0) {
      smem[this->warp_n_] = data;
    }
    __syncthreads();
    T out = Zeros<T>::get();
#pragma unroll
    for (int w = 0; w < WARPS_N; w++) {
      out = op(out, smem[w]);
    }
    return out;
  }

  // Only the intra-CTA group leader (warp_n_ == 0 && lane_ == 0) holds the result!
  template <typename Op>
  inline __device__ T reduce(T data, const Op &op) {
    T *const smem = use0_ ? smem0_ : smem1_;
    use0_ = !use0_;
    data = Base::reduce(data, op);
    if (this->lane_ == 0) {
      smem[this->warp_n_] = data;
    }
    __syncthreads();
    T out = Zeros<T>::get();
    if (this->warp_n_ == 0 && this->lane_ == 0) {
#pragma unroll
      for (int w = 0; w < WARPS_N; w++) {
        out = op(out, smem[w]);
      }
    }
    return out;
  }

  T *const smem0_;
  T *const smem1_;
  bool use0_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

628
// Like Reducer<T, CTAS_PER_ROW, ...>, but the group size is a runtime value
// (params.ctas_per_row) rather than a compile-time constant.
template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct DynamicReducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
  using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
  using Type = typename Base::Type;

  // Initializer list follows member declaration order (inter_cta_, w0_, w1_, bidn_).
  template <typename Params>
  inline __device__ DynamicReducer(const Params &params, uint32_t bidm, uint32_t bidn,
                                   uint32_t warp_m, uint32_t warp_n, uint32_t lane, void *smem)
      : Base(params, bidm, bidn, warp_m, warp_n, lane, smem),
        inter_cta_(params.barrier, bidm, params.ctas_per_col, params.ctas_per_row),
        w0_(static_cast<T *>(params.workspace) + (bidm * WARPS_M + warp_m) * params.ctas_per_row),
        w1_(w0_ + params.ctas_per_col * WARPS_M * params.ctas_per_row),
        bidn_(bidn)  // CTA id within the group.
  {}

  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    // Trivial case: a single CTA in the group reduces entirely locally.
    if (inter_cta_.group_size_ == 1) {
      return Base::allreduce(data, op);
    }

    data = Base::reduce(data, op);
    // We switch workspace every iteration.
    T *const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;

    // Warp leaders 0 hold the CTA-local results.
    if (this->warp_n_ == 0 && this->lane_ == 0) {
      workspace[bidn_] = data;
    }
    inter_cta_.sync();
    // Lane-strided accumulation over the per-CTA partials (group size may
    // exceed the warp width), followed by a warp all-reduce.
    T total = Zeros<T>::get();
    for (int it = this->lane_; it < inter_cta_.group_size_; it += THREADS_PER_WARP) {
      total = op(total, workspace[it]);
    }
    total = Reducer<T, 1, 1, 1>::allreduce_(total, op);

    return total;
  }

  // NOTE: reduce() delegates to allreduce() — every thread gets the result.
  template <typename Op>
  inline __device__ T reduce(T data, const Op &op) {
    return allreduce(data, op);
  }

  InterCTASync inter_cta_;

  T *const w0_;
  T *const w1_;
  int bidn_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

682
683
684
685
686
687
688
689
690
691
692
693
694
/*
This is an implementation of the parallel Welford algorithm for incrementally computing variance

This algorithm is known as Chan's update formulae (Chan et al. '79):
http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf

An introduction is provided by Wikipedia here:
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance?section=5#Parallel_algorithm

A detailed reference on the exact version implemented (with better numerical stability) is provided here:
https://dbs.ifi.uni-heidelberg.de/files/Team/eschubert/publications/SSDBM18-covariance-authorcopy.pdf
*/

695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
// Parallel Welford/Chan combine across the warp: merges the per-lane partial
// statistics (mean m_a, sum of squared deviations m2_a, count n_a) of the
// num_active leftmost lanes. On exit all lanes hold the combined mean/M2,
// broadcast from lane 0.
template <typename T>
inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a,
                                             int num_active) {  // NOLINT(*)
  // Assume at least leftmost is valid and
  // init: step = next_pow2(num_active) / 2 (might get NaN otherwise)
  int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);

#pragma unroll
  for (int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2) {
    // Exchange
    T n_b = warp_shuffle_down(n_a, step);
    T m_b = warp_shuffle_down(m_a, step);
    T m2_b = warp_shuffle_down(m2_a, step);

    // Update
    const T n_ab = n_a + n_b;  // We can handle one of them being 0, not both.
    // Might have different n per thread, otherwise this would simplify :(
    const T rn_ab = 1.f / n_ab;
    const T delta = m_a - m_b;
    // Fix: accumulate the intermediates in T (previously hard-coded float),
    // so instantiations with a wider compute type do not silently truncate.
    const T m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;
    const T m_ab = (n_a * m_a + n_b * m_b) * rn_ab;

    n_a = n_ab;
    m_a = m_ab;
    m2_a = m2_ab;
  }
  // Intra-warp broadcast (only lane 0 has valid stats).
#ifdef __HIP_PLATFORM_AMD__
  m_a = __shfl(m_a, 0, THREADS_PER_WARP);
  m2_a = __shfl(m2_a, 0, THREADS_PER_WARP);
#else
  m_a = __shfl_sync(static_cast<uint32_t>(-1), m_a, 0);
  m2_a = __shfl_sync(static_cast<uint32_t>(-1), m2_a, 0);
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Statistics (mean / sum of squared deviations) for rows that are split
// across CTAS_PER_ROW cooperating CTAs. Delegates the intra-CTA part to
// BlockStats, then combines the per-CTA partials through a global-memory
// workspace guarded by an inter-CTA barrier.
template <typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats {
  // This could be done generically with the Reducer. But then we
  // would have to exchange 3 instead of 2 fields.

  // Intra-CTA statistics (CTAS_PER_ROW == 1 specialization below).
  using BlockStats = Stats<T, 1, WARPS_M, WARPS_N>;
  // Packed {mean, m2} pair.
  using stats_t = typename BlockStats::stats_t;

  // Shared memory is only needed by the intra-CTA stage.
  enum { SMEM_BYTES = BlockStats::SMEM_BYTES };

  // Params must provide: barrier, ctas_per_col, workspace.
  // bidm/bidn: CTA coordinates; warp_m/warp_n: warp coordinates; lane: lane id.
  template <typename Params>
  inline __device__ Stats(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                          uint32_t warp_n, uint32_t lane, void *smem)
      : inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW),
        block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem),
        bidn_(bidn)  // CTA id within the group.
        ,
        w0_(static_cast<stats_t *>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW),
        w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW),
        warp_n_(warp_n),
        lane_(lane) {}

  // Returns {mean, m2} over the N * WARPS_N * THREADS_PER_WARP * CTAS_PER_ROW
  // elements of a row. `rn` is accepted for interface parity with the other
  // specializations but is not used (see TODO below).
  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP;
    // TODO(ptredak) rn is not really needed here..
    constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA);
    stats_t block_stats = block_stats_.compute(elts, block_rn);

    // Workspace is double-buffered, selected by the barrier's phase counter —
    // presumably so consecutive invocations do not clobber slots that other
    // CTAs are still reading. NOTE(review): confirm against InterCTASync.
    stats_t *const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;

    // One thread per CTA publishes the CTA-local result to global memory.
    if (warp_n_ == 0 && lane_ == 0) {
      workspace[bidn_] = block_stats;
    }

    // Wait for all CTAS_PER_ROW CTAS in the group to have written their result.
    inter_cta_.sync();

    T n = Zeros<T>::get();
    T m = Zeros<T>::get();
    T m2 = Zeros<T>::get();

    // Assume CTA group size in N less than 32, such that we can finalize with a single warp.
    static_assert(CTAS_PER_ROW <= 32);

    // Every warp does the final reduction locally.
    if (lane_ < CTAS_PER_ROW) {
      stats_t result = workspace[lane_];
      n = ELTS_PER_ROW_PER_CTA;
      m = transformer_engine::Get<0>::of<stats_t, T>(result);
      m2 = transformer_engine::Get<1>::of<stats_t, T>(result);
    }

    warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW);

    return {m, m2};
  }

  InterCTASync inter_cta_;
  BlockStats block_stats_;

  // Double-buffered global-memory workspace: one stats_t slot per CTA in the group.
  stats_t *const w0_;
  stats_t *const w1_;
  int bidn_;
  int warp_n_;
  int lane_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Single-CTA statistics with WARPS_N warps per row. Each warp computes its
// local stats (WarpStats), warp leaders stage them in shared memory, and the
// staged partials are merged with warp_chan_upd_dynamic.
template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats<T, 1, WARPS_M, WARPS_N> {
  using WarpStats = Stats<T, 1, WARPS_M, 1>;
  using stats_t = typename WarpStats::stats_t;

  // Two buffers of WARPS_M * WARPS_N stats_t slots (double buffering across calls).
  enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 };

  template <typename Params>
  inline __device__ Stats(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                          uint32_t warp_n, uint32_t lane, void *smem)
      : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem), use0_(true) {
    // Each warp_m row owns a WARPS_N-slot region in both buffers.
    smem0_ = static_cast<stats_t *>(smem) + warp_m * WARPS_N;
    smem1_ = smem0_ + WARPS_M * WARPS_N;
  }

  // Returns {mean, m2} over the N * WARPS_N * THREADS_PER_WARP elements of a row.
  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    // Alternate shared-memory buffers between calls — presumably so a new call
    // cannot overwrite slots another warp is still reading (TODO confirm).
    stats_t *smem = use0_ ? smem0_ : smem1_;
    use0_ = !use0_;
    // Compute warp local for all WARPS_N
    constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP);
    stats_t warp_stats = warp_stats_.compute(elts, warp_rn);

    // Each warp warp leader stores its stats
    const auto warp_n = warp_stats_.reducer_.warp_n_;
    const auto lane = warp_stats_.reducer_.lane_;
    if (lane == 0) {
      smem[warp_n] = warp_stats;
    }
    __syncthreads();

    T n = Zeros<T>::get();
    T m = Zeros<T>::get();
    T m2 = Zeros<T>::get();

    // Assume that there are less than 32 warps, such that we can finalize with a single warp
    static_assert(WARPS_N <= 32);
    if (lane < WARPS_N) {
      stats_t result = smem[lane];
      n = N * THREADS_PER_WARP;
      m = transformer_engine::Get<0>::of<stats_t, T>(result);
      m2 = transformer_engine::Get<1>::of<stats_t, T>(result);
    }

    warp_chan_upd_dynamic(m, m2, n, WARPS_N);

    return {m, m2};
  }
  WarpStats warp_stats_;
  stats_t *smem0_;
  stats_t *smem1_;
  bool use0_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Base case: statistics computed entirely within one warp. Produces a packed
// {mean, sum of squared deviations} pair via two warp allreduces.
template <typename T, uint32_t WARPS_M>
struct Stats<T, 1, WARPS_M, 1> {
  using stats_t = typename TypeToVec2<T>::Type;
  // The simple Warp reducer.
  using Reducer = Reducer<T, 1, WARPS_M, 1>;

  // Everything stays inside a warp; no shared-memory staging required.
  enum { SMEM_BYTES = 0 };

  template <typename Params>
  inline __device__ Stats(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                          uint32_t warp_n, uint32_t lane, void *smem)
      : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) {}

  // Two-pass computation: first the mean of all N-per-thread elements across
  // the warp (scaled by rn, the reciprocal element count), then the sum of
  // squared deviations from that mean.
  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    auto add = Sum<T>();

    // Pass 1: warp-wide mean.
    T mean = Zeros<T>::get();
#pragma unroll
    for (int idx = 0; idx < N; ++idx) {
      mean += elts[idx];
    }
    mean = reducer_.allreduce(mean, add) * rn;

    // Pass 2: warp-wide sum of squared deviations from the mean.
    T sq_dev_sum = Zeros<T>::get();
#pragma unroll
    for (int idx = 0; idx < N; ++idx) {
      const T dev = elts[idx] - mean;
      sq_dev_sum += dev * dev;
    }
    sq_dev_sum = reducer_.allreduce(sq_dev_sum, add);

    return {mean, sq_dev_sum};
  }

  Reducer reducer_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Shuffle-down tree max-reduction over the first num_elems lanes of a warp.
// After the loop, lane 0 holds the maximum; other lanes hold partial maxima.
// Inputs are assumed non-negative (see the __builtin_assume hints below).
template <int num_elems>
__device__ __forceinline__ float warp_reduce_max(const float m) {
  float acc = m;
#pragma unroll
  for (int offset = num_elems / 2; offset > 0; offset /= 2) {
    // Fetch the running maximum of the lane `offset` positions higher.
#ifdef __HIP_PLATFORM_AMD__
    const float peer = __shfl_down(acc, offset, THREADS_PER_WARP);
#else
    const float peer = __shfl_down_sync(0xFFFFFFFF, acc, offset);
#endif
    // Optimizer hint: amax-style inputs are non-negative.
    __builtin_assume(acc >= 0);
    __builtin_assume(peer >= 0);
    acc = fmaxf(acc, peer);
  }
  return acc;
}

// Full-warp max reduction followed by a broadcast, so every lane of the warp
// returns the same maximum. Inputs are assumed non-negative (amax values).
__forceinline__ __device__ float warp_reduce_max_broadcast(const float val) {
  float running = val;
#pragma unroll
  for (int stride = THREADS_PER_WARP / 2; stride > 0; stride /= 2) {
#ifdef __HIP_PLATFORM_AMD__
    const float partner = __shfl_down(running, stride, THREADS_PER_WARP);
#else
    const float partner = __shfl_down_sync(0xFFFFFFFF, running, stride);
#endif
    // Optimizer hint: amax values are non-negative.
    __builtin_assume(running >= 0);
    __builtin_assume(partner >= 0);
    running = fmaxf(running, partner);
  }
  // Broadcast the amax to other threads of the subwarp from the zero subwarp lane_id
  constexpr int src_lane = 0;
#ifdef __HIP_PLATFORM_AMD__
  running = __shfl(running, src_lane, THREADS_PER_WARP);
#else
  running = __shfl_sync(0xFFFFFFFF, running, src_lane);
#endif
  return running;
}

// Block-wide max reduction over non-negative values.
// Each warp reduces its 32 lanes with shuffles, warp leaders stage the warp
// maxima in shared memory, and warp 0 reduces the staged values.
//
// Preconditions:
//   - values are non-negative (warp_reduce_max assumes this; missing warps
//     contribute 0),
//   - the block is 1-D with blockDim.x == num_warps * 32 and
//     warpid == threadIdx.x / 32.
// Returns the block maximum ONLY for the threads of warp 0; every other warp
// receives 0.f.
// NOTE(review): `staging` is reused on every call — a second call from the
// same block without an intervening __syncthreads() would race.
template <int num_warps, typename compute_t>
__device__ __forceinline__ compute_t reduce_max(const compute_t m, const int warpid) {
  __shared__ float staging[num_warps];
  constexpr int warp_size = 32;
  const float my_max = m;
  const float my_warp_max = warp_reduce_max<warp_size>(my_max);
  // Warp leaders publish their warp's maximum.
  if (threadIdx.x % warp_size == 0) {
    staging[warpid] = my_warp_max;
  }
  __syncthreads();
  compute_t result = 0.f;
  if (warpid == 0) {
    // Each lane of warp 0 picks up one staged warp maximum (0 beyond num_warps).
    const float warp_partial = threadIdx.x < num_warps ? staging[threadIdx.x] : 0;
    result = warp_reduce_max<num_warps>(warp_partial);
  }
  return result;
}

/**
 * Max reduction within subwarps, with the result broadcast to every lane of
 * the subwarp.
 * E.g., if nvec=4, each warp processes 128 elements (32 x 4), covering four
 * MXFP8 scaling factors. To compute an actual scaling factor for 32
 * consecutive elements, only 8 threads need to participate, so the warp is
 * split into four 8-thread-wide subwarps. A shuffle-down tree reduction runs
 * inside each subwarp, followed by a broadcast from the subwarp's lane 0.
 */
template <int subwarp_width>
__forceinline__ __device__ float subwarp_reduce_max_broadcast(const float val) {
  float current_max = val;
#pragma unroll
  for (int step = subwarp_width / 2; step > 0; step /= 2) {
#ifdef __HIP_PLATFORM_AMD__
    const float candidate = __shfl_down(current_max, step, subwarp_width);
#else
    const float candidate = __shfl_down_sync(0xFFFFFFFF, current_max, step, subwarp_width);
#endif
    // Optimizer hint: amax inputs are non-negative.
    __builtin_assume(current_max >= 0);
    __builtin_assume(candidate >= 0);
    current_max = fmaxf(current_max, candidate);
  }
  // Lane 0 of each subwarp now holds the subwarp amax; share it with the rest.
  constexpr int src_lane = 0;
#ifdef __HIP_PLATFORM_AMD__
  current_max = __shfl(current_max, src_lane, subwarp_width);
#else
  current_max = __shfl_sync(0xFFFFFFFF, current_max, src_lane, subwarp_width);
#endif
  return current_max;
}

Przemek Tredak's avatar
Przemek Tredak committed
// Works only on positive values
// (for non-negative IEEE-754 floats the bit pattern orders the same way as a
// signed int, so an integer atomicMax on the reinterpreted bits is equivalent;
// negative inputs or NaNs would break this).
__device__ __forceinline__ void atomicMaxFloat(float *addr, const float value) {
  atomicMax(reinterpret_cast<int *>(addr), __float_as_int(value));
}

// Works only on positive values
// (same bit-pattern trick as atomicMaxFloat: non-negative float ordering
// matches signed-int ordering; negative inputs or NaNs would break this).
__device__ __forceinline__ void atomicMinFloat(float *addr, const float value) {
  atomicMin(reinterpret_cast<int *>(addr), __float_as_int(value));
}

// Generic reciprocal: *value_inv = 1 / value. No special handling of
// value == 0 — behaves like plain division for type T.
template <typename T>
__device__ __forceinline__ void reciprocal(T *value_inv, const T value) {
  *value_inv = 1 / value;
}

// float specialization: uses the round-to-nearest reciprocal intrinsic
// __frcp_rn instead of a full division.
template <>
__device__ __forceinline__ void reciprocal<float>(float *value_inv, const float value) {
  *value_inv = __frcp_rn(value);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Short aliases for the quantized storage types.
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
// Scale-factor byte — presumably an MX-style E8M0 (exponent-only) encoding,
// judging by the name; stored as a raw uint8_t.
using e8m0_t = uint8_t;
using int8 = int8_t;

// Which dimension(s) scaling factors are computed along.
enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENSIONAL = 2 };
// Numeric limits of the quantized output types. Each specialization provides:
//   maxUnbiasedExponent - largest unbiased exponent of the format
//   maxNorm             - largest representable finite magnitude
template <typename T>
struct Numeric_Traits;

template <>
struct Numeric_Traits<fp8e4m3> {
  static constexpr int maxUnbiasedExponent = 8;
  static constexpr double maxNorm = 448;
};

template <>
struct Numeric_Traits<fp8e5m2> {
  static constexpr int maxUnbiasedExponent = 15;
  static constexpr double maxNorm = 57344;
};

// int8: integer format with no exponent; 127 is the largest positive value.
template <>
struct Numeric_Traits<int8> {
  static constexpr int maxUnbiasedExponent = 0;
  static constexpr double maxNorm = 127;
};

// Convenience wrapper over Numeric_Traits<T>: exposes the format limits as
// floats together with precomputed reciprocals (handy for scale computations).
template <typename T>
struct Quantized_Limits {
  static constexpr int max_unbiased_exponent = Numeric_Traits<T>::maxUnbiasedExponent;
  static constexpr float max_norm = Numeric_Traits<T>::maxNorm;
  static constexpr float max_norm_rcp = 1.0 / max_norm;
  // Note: emax holds the VALUE 2^max_unbiased_exponent, not the exponent itself.
  static constexpr float emax = 1 << max_unbiased_exponent;
  static constexpr float emax_rcp = 1.0 / emax;
};

Przemek Tredak's avatar
Przemek Tredak committed
1047
1048
1049
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_UTILS_CUH_