utils.cuh 31.6 KB
Newer Older
Przemek Tredak's avatar
Przemek Tredak committed
1
/*************************************************************************
2
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
3
4
5
6
7
8
9
10
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_UTILS_CUH_
#define TRANSFORMER_ENGINE_COMMON_UTILS_CUH_

#include <cuda_fp16.h>
yuguo's avatar
yuguo committed
11
12
13
14
15

#ifdef __HIP_PLATFORM_AMD__
#ifndef __HIPCC_RTC__
#include <cstdint>
#else
16
#include <hip/amd_detail/hip_assert.h>
yuguo's avatar
yuguo committed
17
18
19
20
21
using namespace __hip_internal;
#endif
#endif

#include <cuda_bf16.h>
Tim Moon's avatar
Tim Moon committed
22
#include <cuda_fp8.h>
Przemek Tredak's avatar
Przemek Tredak committed
23

yuguo's avatar
yuguo committed
24
25
26
#ifdef __HIP_PLATFORM_AMD__
typedef uint16_t hip_bfloat16x2 __attribute__((ext_vector_type(2)));
#else
Tim Moon's avatar
Tim Moon committed
27
28
29
30
31
32
33
34
35
36
37
38
39
#if !defined(__CUDACC_RTC__)
#include <cstdint>
#else
// Importing C++ standard headers is a pain with NVRTC
using uint8_t = unsigned char;
using uint16_t = unsigned short int;  // NOLINT(*)
using uint32_t = unsigned int;
using uint64_t = unsigned long long int;  // NOLINT(*)
static_assert(sizeof(uint8_t) == 1);
static_assert(sizeof(uint16_t) == 2);
static_assert(sizeof(uint32_t) == 4);
static_assert(sizeof(uint64_t) == 8);
#endif
Przemek Tredak's avatar
Przemek Tredak committed
40

yuguo's avatar
yuguo committed
41
#endif // __HIP_PLATFORM_AMD__
Przemek Tredak's avatar
Przemek Tredak committed
42
43
44
45
46
47
////////////////////////////////////////////////////////////////////////////////////////////////////

constexpr uint32_t THREADS_PER_WARP = 32;

////////////////////////////////////////////////////////////////////////////////////////////////////

yuguo's avatar
yuguo committed
48
#if !defined(USE_HIPBLASLT) && !defined(__HIPCC_RTC__)
// Component-wise sum of two float2 values; returns {a.x + b.x, a.y + b.y}.
inline __device__ float2 operator+(const float2 &a, const float2 &b) {  // NOLINT(*)
  return {a.x + b.x, a.y + b.y};
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Component-wise in-place accumulation: adds b.x / b.y into a.x / a.y.
inline __device__ void operator+=(float2 &a, const float2 &b) {  // NOLINT(*)
  a.x += b.x;
  a.y += b.y;
}
#endif
Przemek Tredak's avatar
Przemek Tredak committed
60
61
62

////////////////////////////////////////////////////////////////////////////////////////////////////

63
// Binary addition functor. Calling it with (a, b) returns a + b using T's
// own operator+; it is passed as the `op` argument of the reducers below.
template <typename T>
struct Sum {
  inline __device__ Sum() {}
  inline __device__ T operator()(const T &a, const T &b) const { return a + b; }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

71
72
// Warp-level XOR (butterfly) shuffle of `x` with lane mask `idx`.
// On CUDA this forwards to __shfl_xor_sync with an all-ones participation
// mask; on HIP it forwards to __shfl_xor with a width of THREADS_PER_WARP.
template <typename T>
inline __device__ T warp_shuffle_xor(const T &x, uint32_t idx) {
#ifdef __HIP_PLATFORM_AMD__
  return __shfl_xor(x, idx, THREADS_PER_WARP);
#else
  // static_cast<uint32_t>(-1) == 0xffffffff: every lane of the warp participates.
  return __shfl_xor_sync(static_cast<uint32_t>(-1), x, idx);
#endif
}

80
81
82
// float2 overload of the XOR shuffle: shuffles the .x and .y components
// independently through the scalar warp_shuffle_xor.
template <>
inline __device__ float2 warp_shuffle_xor<float2>(const float2 &x, uint32_t idx) {
  return {warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx)};
}

85
86
// Warp-level "shuffle down" of `x` by `idx` lanes.
// On CUDA this forwards to __shfl_down_sync with an all-ones participation
// mask; on HIP it forwards to __shfl_down with a width of THREADS_PER_WARP.
template <typename T>
inline __device__ T warp_shuffle_down(const T &x, uint32_t idx) {
#ifdef __HIP_PLATFORM_AMD__
  return __shfl_down(x, idx, THREADS_PER_WARP);
#else
  // static_cast<uint32_t>(-1) == 0xffffffff: every lane of the warp participates.
  return __shfl_down_sync(static_cast<uint32_t>(-1), x, idx);
#endif
}

94
95
96
// float2 overload of the down shuffle: shuffles the .x and .y components
// independently through the scalar warp_shuffle_down.
template <>
inline __device__ float2 warp_shuffle_down<float2>(const float2 &x, uint32_t idx) {
  return {warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx)};
}

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace transformer_engine {

////////////////////////////////////////////////////////////////////////////////////////////////////

// Aggregate of four uint4 members. Used as the storage type selected by
// BytesToType<64>, which static_asserts that its size is 64 bytes.
struct uint16 {
  uint4 u;
  uint4 v;
  uint4 s;
  uint4 t;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Aggregate of two uint4 members. Used as the storage type selected by
// BytesToType<32>, which static_asserts that its size is 32 bytes.
struct uint8 {
  uint4 u;
  uint4 v;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

121
// Maps a byte count to a type of exactly that size (exposed as ::Type).
// Only the sizes specialized below (64, 32, 16, 8, 4, 2, 1) are supported;
// the primary template intentionally has no ::Type member, so any other
// size fails to compile at the point of use.
template <int BYTES>
struct BytesToType {};

template <>
struct BytesToType<64> {
  using Type = uint16;
  static_assert(sizeof(Type) == 64);
};

template <>
struct BytesToType<32> {
  using Type = uint8;
  static_assert(sizeof(Type) == 32);
};

template <>
struct BytesToType<16> {
  using Type = uint4;
  static_assert(sizeof(Type) == 16);
};

template <>
struct BytesToType<8> {
  using Type = uint64_t;
  static_assert(sizeof(Type) == 8);
};

template <>
struct BytesToType<4> {
  using Type = uint32_t;
  static_assert(sizeof(Type) == 4);
};

template <>
struct BytesToType<2> {
  using Type = uint16_t;
  static_assert(sizeof(Type) == 2);
};

template <>
struct BytesToType<1> {
  using Type = uint8_t;
  static_assert(sizeof(Type) == 1);
};

////////////////////////////////////////////////////////////////////////////////////////////////////

168
// Maps a scalar type to its two-element vector counterpart (::Type).
// The primary template has no ::Type member; only float, half and the
// platform's bfloat16 type are specialized.
template <typename T>
struct TypeToVec2 {};

template <>
struct TypeToVec2<float> {
  using Type = float2;
};

template <>
struct TypeToVec2<half> {
  using Type = half2;
};

#ifdef __HIP_PLATFORM_AMD__
// HIP build: bfloat16 pairs use the hip_bfloat16x2 vector typedef.
template <>
struct TypeToVec2<__hip_bfloat16> {
  using Type = hip_bfloat16x2;
};
#else
// CUDA build: bfloat16 pairs use nv_bfloat162.
template <>
struct TypeToVec2<nv_bfloat16> {
  using Type = nv_bfloat162;
};
#endif
Przemek Tredak's avatar
Przemek Tredak committed
192
193
194

////////////////////////////////////////////////////////////////////////////////////////////////////

195
196
// Plain parameter bundle of raw pointers, templated on input, second-input,
// output and compute element types. No code in this struct reads or writes
// the pointers; their exact meaning is defined by the kernels that consume it.
// NOTE(review): the per-field notes below are inferred from the member names
// only -- confirm against the kernels that use this struct.
template <typename IType, typename IType2, typename OType, typename CType>
struct CTDBiasDActParam {
  using InputType = IType;
  using InputType2 = IType2;
  using OutputType = OType;
  using ComputeType = CType;
  const IType *input;         // read-only input
  const IType2 *act_input;    // read-only second input (presumably the activation input)
  OType *output_c;            // output buffer (presumably the non-transposed result)
  OType *output_t;            // output buffer (presumably the transposed result)
  const CType *scale_ptr;     // read-only scale value(s)
  CType *amax;                // writable amax buffer
  CType *scale_inv;           // writable inverse-scale buffer
  CType *workspace;           // writable scratch buffer
  CType *warp_scales_inv;     // writable per-warp inverse-scale buffer (presumably)
};

////////////////////////////////////////////////////////////////////////////////////////////////////

214
// Compile-time indexed accessor for CUDA-style vector types.
// Get<I>::of<T, R>(vec) returns component I of `vec` (0 -> .x, 1 -> .y,
// 2 -> .z, 3 -> .w) converted to R. Only indices 0..3 are defined; any other
// index leaves `of` declared but undefined.
template <int INDEX>
struct Get {
  template <typename T, typename R>
  static inline __device__ R of(const T &vec);
};

template <>
template <typename T, typename R>
inline __device__ R Get<0>::of(const T &vec) {
  return vec.x;
}

template <>
template <typename T, typename R>
inline __device__ R Get<1>::of(const T &vec) {
  return vec.y;
}

template <>
template <typename T, typename R>
inline __device__ R Get<2>::of(const T &vec) {
  return vec.z;
}

template <>
template <typename T, typename R>
inline __device__ R Get<3>::of(const T &vec) {
  return vec.w;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

246
247
248
// Value conversion from Src to Dst. The generic version uses Dst's
// converting constructor; the specializations below use dedicated
// intrinsics for packed two-element types.
template <typename Src, typename Dst>
struct Converter {
  static inline __device__ Dst convert(const Src &from) { return Dst(from); }
};

// float2 -> half2 via the __float22half2_rn intrinsic.
template <>
struct Converter<float2, half2> {
  static inline __device__ half2 convert(const float2 &x) { return __float22half2_rn(x); }
};

#ifdef __HIP_PLATFORM_AMD__
// float2 -> hip_bfloat16x2: each component is converted separately and the
// pair is reinterpreted as the packed vector type through a union.
template <>
struct Converter<float2, hip_bfloat16x2> {
  static inline __device__ hip_bfloat16x2 convert(const float2 &x) {
    union {
      hip_bfloat16x2 raw;
      __hip_bfloat16 elt[2];
    } tmp;
    tmp.elt[0] = __hip_bfloat16(x.x);
    tmp.elt[1] = __hip_bfloat16(x.y);
    return tmp.raw;
  }
};
#else
// float2 -> nv_bfloat162. Uses the packed __float22bfloat162_rn intrinsic
// when compiling for __CUDA_ARCH__ >= 800; otherwise converts each component
// with __float2bfloat16_rn and packs the pair through a union.
template <>
struct Converter<float2, nv_bfloat162> {
  static inline __device__ nv_bfloat162 convert(const float2 &x) {
#if __CUDA_ARCH__ >= 800
    return __float22bfloat162_rn(x);
#else
    union {
      nv_bfloat162 raw;
      nv_bfloat16 elt[2];
    } tmp;
    tmp.elt[0] = __float2bfloat16_rn(x.x);
    tmp.elt[1] = __float2bfloat16_rn(x.y);
    return tmp.raw;
#endif
  }
};
#endif
Przemek Tredak's avatar
Przemek Tredak committed
287
288
289

////////////////////////////////////////////////////////////////////////////////////////////////////

290
291
292
// Produces the zero value of T. The generic version constructs T from 0.f;
// float2 is specialized because it has no such constructor.
template <typename T>
struct Zeros {
  static inline __device__ T get() { return T(0.f); }
};

template <>
struct Zeros<float2> {
  static inline __device__ float2 get() { return make_float2(0.f, 0.f); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

302
// Fixed-size vector of NUM_ELT elements of Elt_type whose storage is a union
// of an element array (`data.elt`) and a single wide type of the same byte
// size (`data.vec`, chosen via BytesToType). The wide view lets a whole Vec
// be loaded/stored with one vector-typed access.
//
// BYTES (= NUM_ELT * sizeof(Elt_type)) must be one of the sizes BytesToType
// is specialized for, otherwise Vec_type does not exist and this fails to
// compile.
template <typename Elt_type, uint32_t NUM_ELT>
struct Vec {
  enum { BYTES = NUM_ELT * sizeof(Elt_type) };

  using Vec_type = typename BytesToType<BYTES>::Type;
  using type = Elt_type;

  using Alias_type = union {
    Vec_type vec;
    Elt_type elt[NUM_ELT];
  };

  Alias_type data;
#ifdef __HIP_PLATFORM_AMD__
  // HIP build only: explicit copy assignment that copies the wide view.
  __HOST_DEVICE__ Vec& operator=(const Vec& rhs) {
    data.vec = rhs.data.vec;
    return *this;
  }
#endif

  // Element-wise conversion of this vector into `other`, using S's
  // converting constructor for each element.
  template <typename S>
  inline __device__ void to(Vec<S, NUM_ELT> &other) {  // NOLINT(*)
#pragma unroll
    for (int it = 0; it < NUM_ELT; it++) {
      other.data.elt[it] = S(this->data.elt[it]);
    }
  }

  // Sets element `it` to op(it) for every it in [0, NUM_ELT).
  template <typename Op>
  inline __device__ void assign(const Op &op) {
#pragma unroll
    for (int it = 0; it < NUM_ELT; it++) {
      this->data.elt[it] = op(it);
    }
  }

  // Pointer is cast to vector type
  // (so `idx` counts whole vectors, not elements).
  inline __device__ void load_from(const void *base_ptr, size_t idx = 0) {
    this->data.vec = static_cast<const Vec_type *>(base_ptr)[idx];
  }

  // Pointer is cast to vector type
  // (so `idx` counts whole vectors, not elements).
  inline __device__ void store_to(void *base_ptr, size_t idx = 0) const {
    static_cast<Vec_type *>(base_ptr)[idx] = this->data.vec;
  }

  // Pointer is cast to element type. Loads min(count, NUM_ELT)
  // elements and any remaining elements are set to zero.
  // `idx` is an element offset from base_ptr. The single wide load is used
  // only when all NUM_ELT elements are requested and the element pointer is
  // BYTES-aligned; otherwise elements are read one at a time.
  inline __device__ void load_from_elts(const void *base_ptr, size_t idx = 0,
                                        size_t count = NUM_ELT) {
    const Elt_type *elt_ptr = static_cast<const Elt_type *>(base_ptr) + idx;
    if (count < NUM_ELT || reinterpret_cast<uint64_t>(elt_ptr) % BYTES != 0) {
#pragma unroll
      for (int it = 0; it < NUM_ELT; it++) {
        this->data.elt[it] = (it < count ? elt_ptr[it] : Elt_type(0.f));
      }
    } else {
      this->load_from(elt_ptr);
    }
  }

  // Pointer is cast to element type. Stores min(count, NUM_ELT)
  // elements.
  // `idx` is an element offset from base_ptr. The single wide store is used
  // only when all NUM_ELT elements are requested and the element pointer is
  // BYTES-aligned; otherwise elements are written one at a time and nothing
  // past `count` elements is touched.
  inline __device__ void store_to_elts(void *base_ptr, size_t idx = 0,
                                       size_t count = NUM_ELT) const {
    Elt_type *elt_ptr = static_cast<Elt_type *>(base_ptr) + idx;
    if (count < NUM_ELT || reinterpret_cast<uint64_t>(elt_ptr) % BYTES != 0) {
#pragma unroll
      for (int it = 0; it < NUM_ELT; it++) {
        if (it < count) {
          elt_ptr[it] = this->data.elt[it];
        }
      }
    } else {
      this->store_to(elt_ptr);
    }
  }

  // Sets every element to Elt_type(0.f).
  inline __device__ void clear() {
#pragma unroll
    for (int it = 0; it < NUM_ELT; it++) {
      this->data.elt[it] = Elt_type(0.f);
    }
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Barrier across a group of CTAs, built on two integer counters in memory.
//
// Each call to sync() uses one of the two counters (alternating every call)
// and either counts it up to group_size_ or back down to 0 (direction flips
// every two calls), giving a 4-phase cycle: b0 up, b1 up, b0 down, b1 down.
// Thread 0 of the CTA performs the atomic update and spins until the counter
// reaches the expected value; the remaining threads wait in __syncthreads().
struct InterCTASync {
  // barrier:    base of the counter array; counters are assumed to start at 0.
  // group:      index of this CTA group; selects counters barrier[group] and
  //             barrier[group + num_groups].
  // num_groups: offset between the first and the second set of counters.
  // group_size: value the counter must reach on an "up" phase.
  inline __device__ InterCTASync(int *barrier, int group, int num_groups, int group_size)
      : phase_counter_(0),
        b0_(barrier + group),               // First barrier of this group of CTAs.
        b1_(barrier + group + num_groups),  // Second barrier of this group of CTAs.
        group_size_(group_size) {
    // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!
  }

#ifdef __HIP_PLATFORM_AMD__
  // Adds `step` to *barrier (release, agent scope), then spins on acquire
  // loads until the counter equals `expected`.
  inline __device__ void spin_wait_(int *barrier, int step, int expected) {
    __hip_atomic_fetch_add(barrier, step, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT);
    for (int found = -1; found != expected; ) {
      found = __hip_atomic_load(barrier, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT);
    }
  }
#else
  // Adds `step` to *barrier with a release-ordered, gpu-scope reduction, then
  // spins on acquire, gpu-scope loads until the counter equals `expected`.
  inline __device__ void spin_wait_(int *barrier, int step, int expected) {
    asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step));
    for (int found = -1; found != expected;) {
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier));
    }
  }
#endif

  // Blocks until every CTA of the group has called sync() for this phase.
  inline __device__ void sync() {
    // ALL THREADS MUST ENTER!

    // We switch barrier every iteration.
    int *barrier = phase_counter_ & 0x1 ? b1_ : b0_;
    // We decrement every other iteration.
    bool dec = phase_counter_ & 0x2;
    int step = dec ? -1 : 1;
    int expected = dec ? 0 : group_size_;
    // There are only 4 phases: up/down for b0/b1.
    phase_counter_ = (phase_counter_ + 1) & 0x3;

    if (threadIdx.x == 0) {
      spin_wait_(barrier, step, expected);
    }
    // CTA waits for thread 0
    __syncthreads();
  }

  int phase_counter_;  // Current phase in [0, 3]; bit 0 picks b0/b1, bit 1 picks up/down.
  int *b0_;
  int *b1_;
  int group_size_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

444
// Reducer across CTAS_PER_ROW CTAs. Builds on the single-CTA reducer (Base)
// and exchanges the per-CTA partial results through a global workspace,
// synchronizing the CTAs with InterCTASync.
//
// Params must provide: `barrier` (int*), `workspace` (convertible to T* via
// static_cast) and `ctas_per_col`.
template <typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
  using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
  using Type = typename Base::Type;

  enum { SMEM_BYTES = Base::SMEM_BYTES };

  enum { WS_BARRIER_BYTES = 2 * sizeof(int) };
  enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) };

  // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total)
  enum {
    WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES
  };

  // The mem-initializers are listed in member declaration order (the order
  // in which C++ actually runs them); w1_ is derived from w0_.
  template <typename Params>
  inline __device__ Reducer(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                            uint32_t warp_n, uint32_t lane, void *smem)
      : Base(params, bidm, bidn, warp_m, warp_n, lane, smem),
        inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW),
        w0_(static_cast<T *>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW),
        w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW),
        bidn_(bidn) {}  // CTA id within the group.

  // Reduces `data` with `op` across the CTA (via Base::reduce) and then
  // across the CTAS_PER_ROW CTAs of the group; every thread gets the total.
  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    data = Base::reduce(data, op);
    // We switch workspace every iteration.
    T *const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;

    // Warp leaders 0 hold the CTA-local results.
    if (this->warp_n_ == 0 && this->lane_ == 0) {
      workspace[bidn_] = data;
    }
    inter_cta_.sync();
    // One workspace slot per lane, so the group must fit in a single warp.
    static_assert(CTAS_PER_ROW <= 32);
    T total = Zeros<T>::get();
    if (this->lane_ < CTAS_PER_ROW) {
      total = workspace[this->lane_];
    }
    total = Reducer<T, 1, 1, 1>::allreduce_(total, op);

    return total;
  }

  InterCTASync inter_cta_;

  T *const w0_;  // Workspace slice used on even phases.
  T *const w1_;  // Workspace slice used on odd phases.
  int bidn_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

499
// Single-CTA reducer for the case of one warp per row (WARPS_N == 1).
// All reductions are pure warp shuffles; no shared memory or global
// workspace is required (SMEM_BYTES and WORKSPACE_BYTES_PER_GROUP are 0).
template <typename T, uint32_t WARPS_M>
struct Reducer<T, 1, WARPS_M, 1> {
  using Type = T;
  enum { SMEM_BYTES = 0 };
  enum { WORKSPACE_BYTES_PER_GROUP = 0 };

  enum { THREADS_PER_WARP = 32 };

  // Only warp_n and lane are stored; the remaining arguments are accepted so
  // that all Reducer variants share one constructor signature.
  template <typename Params>
  inline __device__ Reducer(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                            uint32_t warp_n, uint32_t lane, void *smem)
      : warp_n_(warp_n), lane_(lane) {}

  // XOR-shuffle (butterfly) reduction over the warp: after the loop every
  // lane holds the combined value.
  template <typename Op>
  static inline __device__ T allreduce_(T data, const Op &op) {
#pragma unroll
    for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
      data = op(data, warp_shuffle_xor(data, it));
    }
    return data;
  }

  // Member-function wrapper around allreduce_.
  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    return allreduce_(data, op);
  }

  // Shuffle-down reduction over the warp.
  template <typename Op>
  inline __device__ T reduce(T data, const Op &op) {
// only lane 0 holds the result!
#pragma unroll
    for (int it = THREADS_PER_WARP / 2; it > 0; it /= 2) {
      data = op(data, warp_shuffle_down(data, it));
    }
    return data;
  }
  int warp_n_;
  int lane_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

541
// Single-CTA reducer for WARPS_N warps per row. Each warp first reduces with
// shuffles (Base), the warp leaders publish their partial result to shared
// memory, and the partials are then combined after a __syncthreads().
//
// Two shared-memory buffers are used alternately (use0_ toggles on every
// call), hence the factor 2 in SMEM_BYTES.
template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> {
  using Base = Reducer<T, 1, WARPS_M, 1>;

  using Type = T;

  enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 };
  enum { WORKSPACE_BYTES_PER_GROUP = 0 };

  enum { THREADS_PER_WARP = 32 };

  // `smem` is treated as an array of T; this object uses the WARPS_N entries
  // starting at warp_m * WARPS_N in each of the two buffers.
  // The mem-initializers are listed in member declaration order (the order
  // in which C++ actually runs them); smem1_ is derived from smem0_.
  template <typename Params>
  inline __device__ Reducer(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                            uint32_t warp_n, uint32_t lane, void *smem)
      : Base(params, bidm, bidn, warp_m, warp_n, lane, smem),
        smem0_(&(static_cast<T *>(smem)[warp_m * WARPS_N])),
        smem1_(smem0_ + WARPS_M * WARPS_N),
        use0_(true) {}

  // Reduces `data` with `op` across all WARPS_N warps; every thread combines
  // the WARPS_N partials itself and returns the total.
  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    T *const smem = use0_ ? smem0_ : smem1_;
    use0_ = !use0_;
    data = Base::reduce(data, op);
    if (this->lane_ == 0) {
      smem[this->warp_n_] = data;
    }
    __syncthreads();
    T out = Zeros<T>::get();
#pragma unroll
    for (int it = 0; it < WARPS_N; it++) {
      out = op(out, smem[it]);
    }
    return out;
  }

  // Same as allreduce, but only the thread with warp_n_ == 0 and lane_ == 0
  // combines the partials; every other thread returns Zeros<T>::get().
  template <typename Op>
  inline __device__ T reduce(T data, const Op &op) {
    T *const smem = use0_ ? smem0_ : smem1_;
    use0_ = !use0_;
    // only intra-CTA group leader holds the result!
    data = Base::reduce(data, op);
    if (this->lane_ == 0) {
      smem[this->warp_n_] = data;
    }
    __syncthreads();
    T out = Zeros<T>::get();
    if (this->warp_n_ == 0 && this->lane_ == 0) {
#pragma unroll
      for (int it = 0; it < WARPS_N; it++) {
        out = op(out, smem[it]);
      }
    }
    return out;
  }

  T *const smem0_;
  T *const smem1_;
  bool use0_;  // Selects which of the two buffers the next call uses.
};

////////////////////////////////////////////////////////////////////////////////////////////////////

604
// Multi-CTA reducer whose group size (CTAs per row) is a runtime value taken
// from params.ctas_per_row instead of a template parameter.
//
// Params must provide: `barrier` (int*), `workspace` (convertible to T* via
// static_cast), `ctas_per_col` and `ctas_per_row`.
template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct DynamicReducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
  using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
  using Type = typename Base::Type;

  // The mem-initializers are listed in member declaration order (the order
  // in which C++ actually runs them); w1_ is derived from w0_.
  template <typename Params>
  inline __device__ DynamicReducer(const Params &params, uint32_t bidm, uint32_t bidn,
                                   uint32_t warp_m, uint32_t warp_n, uint32_t lane, void *smem)
      : Base(params, bidm, bidn, warp_m, warp_n, lane, smem),
        inter_cta_(params.barrier, bidm, params.ctas_per_col, params.ctas_per_row),
        w0_(static_cast<T *>(params.workspace) + (bidm * WARPS_M + warp_m) * params.ctas_per_row),
        w1_(w0_ + params.ctas_per_col * WARPS_M * params.ctas_per_row),
        bidn_(bidn) {}  // CTA id within the group.

  // Reduces `data` with `op` across the CTA and then across all CTAs of the
  // group; every thread gets the total.
  template <typename Op>
  inline __device__ T allreduce(T data, const Op &op) {
    // Trivial case
    // (a single CTA per group: no inter-CTA exchange or barrier is needed).
    if (inter_cta_.group_size_ == 1) {
      return Base::allreduce(data, op);
    }

    data = Base::reduce(data, op);
    // We switch workspace every iteration.
    T *const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;

    // Warp leaders 0 hold the CTA-local results.
    if (this->warp_n_ == 0 && this->lane_ == 0) {
      workspace[bidn_] = data;
    }
    inter_cta_.sync();
    T total = Zeros<T>::get();
    // Each lane accumulates a strided subset of the per-CTA partials, so the
    // group size is not limited to one warp's worth of entries.
    for (int it = this->lane_; it < inter_cta_.group_size_; it += THREADS_PER_WARP) {
      total = op(total, workspace[it]);
    }
    total = Reducer<T, 1, 1, 1>::allreduce_(total, op);

    return total;
  }

  // Identical to allreduce: every thread receives the total.
  template <typename Op>
  inline __device__ T reduce(T data, const Op &op) {
    return allreduce(data, op);
  }

  InterCTASync inter_cta_;

  T *const w0_;  // Workspace slice used on even phases.
  T *const w1_;  // Workspace slice used on odd phases.
  int bidn_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

658
659
660
661
662
663
664
665
666
667
668
669
670
/*
This is an implementation of the parallel Welford algorithm for incrementally computing variance

This algorithm is known as Chan's update formulae (Chan et al. '79):
http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf

An introduction is provided by Wikipedia here:
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance?section=5#Parallel_algorithm

A detailed reference on the exact version implemented (with better numerical stability) is provided here:
https://dbs.ifi.uni-heidelberg.de/files/Team/eschubert/publications/SSDBM18-covariance-authorcopy.pdf
*/

671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
// Combines per-lane running statistics (mean m_a, sum of squared deviations
// m2_a, element count n_a) across the first `num_active` lanes of a warp
// using the pairwise update implemented in the loop body, then broadcasts
// the combined m_a and m2_a from lane 0 to all lanes.
//
// All three arguments are updated in place. n_a is NOT broadcast: after the
// call only lane 0 holds the combined count.
// Precondition: num_active >= 1. Lanes at or beyond num_active must carry
// n == 0 so that they contribute nothing (n_a and n_b must not both be 0
// for a lane whose result is used).
template <typename T>
inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a,
                                             int num_active) {  // NOLINT(*)
  // Assume at least leftmost is valid and
  // init: step = next_pow2(num_active) / 2 (might get NaN otherwise)
  const int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);
  // For num_active == 1, __clz(0) == 32 gives highest_bit_set == 0; shifting
  // by -1 would be undefined, and there is nothing to combine anyway.
  const int first_step = highest_bit_set > 0 ? (1 << (highest_bit_set - 1)) : 0;

#pragma unroll
  for (int step = first_step; step > 0; step /= 2) {
    // Exchange
    T n_b = warp_shuffle_down(n_a, step);
    T m_b = warp_shuffle_down(m_a, step);
    T m2_b = warp_shuffle_down(m2_a, step);

    // Update
    const T n_ab = n_a + n_b;  // We can handle one of them being 0, not both.
    // Might have different n per thread, otherwise this would simplify :(
    const T rn_ab = 1.f / n_ab;
    const T delta = m_a - m_b;
    // Kept in T (not float) so that a wider T is not silently truncated.
    const T m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;
    const T m_ab = (n_a * m_a + n_b * m_b) * rn_ab;

    n_a = n_ab;
    m_a = m_ab;
    m2_a = m2_ab;
  }
  // Intra-warp broadcast (only lane 0 has valid stats).
#ifdef __HIP_PLATFORM_AMD__
  m_a = __shfl(m_a, 0, THREADS_PER_WARP);
  m2_a = __shfl(m2_a, 0, THREADS_PER_WARP);
#else
  m_a = __shfl_sync(static_cast<uint32_t>(-1), m_a, 0);
  m2_a = __shfl_sync(static_cast<uint32_t>(-1), m2_a, 0);
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

709
// Computes row statistics (the two fields of stats_t, read back below via
// Get<0> as the mean and Get<1> as m2) across CTAS_PER_ROW CTAs.
// Each CTA computes its block-level statistics with BlockStats, publishes
// them to a global workspace, and after an inter-CTA barrier every warp
// merges the per-CTA results with warp_chan_upd_dynamic.
//
// Params must provide: `barrier` (int*), `workspace` (convertible to
// stats_t* via static_cast) and `ctas_per_col`.
template <typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats {
  // This could be done generically with the Reducer. But then we
  // would have to exchange 3 instead of 2 fields.

  using BlockStats = Stats<T, 1, WARPS_M, WARPS_N>;
  using stats_t = typename BlockStats::stats_t;

  enum { SMEM_BYTES = BlockStats::SMEM_BYTES };

  // The mem-initializers are listed in member declaration order (the order
  // in which C++ actually runs them); w1_ is derived from w0_.
  template <typename Params>
  inline __device__ Stats(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                          uint32_t warp_n, uint32_t lane, void *smem)
      : inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW),
        block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem),
        w0_(static_cast<stats_t *>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW),
        w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW),
        bidn_(bidn),  // CTA id within the group.
        warp_n_(warp_n),
        lane_(lane) {}

  // elts: this thread's N elements of the row.
  // rn:   currently unused (see TODO below).
  // Returns {m, m2} merged over all CTAs of the group.
  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP;
    // TODO(ptredak) rn is not really needed here..
    constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA);
    stats_t block_stats = block_stats_.compute(elts, block_rn);

    // Workspace alternates with the barrier phase.
    stats_t *const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;

    // One thread per CTA publishes the block-level result.
    if (warp_n_ == 0 && lane_ == 0) {
      workspace[bidn_] = block_stats;
    }

    // Wait for all CTAS_PER_ROW CTAS in the group to have written their result.
    inter_cta_.sync();

    T n = Zeros<T>::get();
    T m = Zeros<T>::get();
    T m2 = Zeros<T>::get();

    // Assume CTA group size in N less than 32, such that we can finalize with a single warp.
    static_assert(CTAS_PER_ROW <= 32);

    // Every warp does the final reduction locally.
    // Lanes >= CTAS_PER_ROW keep n == 0 and therefore contribute nothing.
    if (lane_ < CTAS_PER_ROW) {
      stats_t result = workspace[lane_];
      n = ELTS_PER_ROW_PER_CTA;
      m = transformer_engine::Get<0>::of<stats_t, T>(result);
      m2 = transformer_engine::Get<1>::of<stats_t, T>(result);
    }

    warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW);

    return {m, m2};
  }

  InterCTASync inter_cta_;
  BlockStats block_stats_;

  stats_t *const w0_;  // Workspace slice used on even phases.
  stats_t *const w1_;  // Workspace slice used on odd phases.
  int bidn_;
  int warp_n_;
  int lane_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

779
/**
 * Statistics over one row handled by a single CTA whose row is split across WARPS_N warps.
 *
 * Each warp first computes its own {m, m2} pair with the single-warp specialization, the
 * warp leaders exchange these pairs through shared memory, and every warp then merges the
 * WARPS_N pairs with warp_chan_upd_dynamic. Two shared-memory buffers are used alternately
 * on successive compute() calls.
 */
template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats<T, 1, WARPS_M, WARPS_N> {
  using WarpStats = Stats<T, 1, WARPS_M, 1>;
  using stats_t = typename WarpStats::stats_t;

  // Two buffers, each holding one stats_t per warp of the CTA.
  enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 };

  /**
   * smem must point to at least SMEM_BYTES bytes; the slice starting at warp_m * WARPS_N
   * entries belongs to this warp row, and the second buffer follows WARPS_M * WARPS_N
   * entries later.
   */
  template <typename Params>
  inline __device__ Stats(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                          uint32_t warp_n, uint32_t lane, void *smem)
      : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem),
        smem0_(static_cast<stats_t *>(smem) + warp_m * WARPS_N),
        smem1_(smem0_ + WARPS_M * WARPS_N),
        use0_(true) {}

  /**
   * elts: the N elements held by the calling thread.
   * rn:   unused here; the per-warp normalization factor is recomputed from N.
   * Returns the merged {m, m2} pair for the row.
   */
  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    // Pick the buffer for this call and flip the selector for the next one.
    stats_t *const exchange = use0_ ? smem0_ : smem1_;
    use0_ = !use0_;

    // Warp-local statistics for each of the WARPS_N warps.
    constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP);
    stats_t local = warp_stats_.compute(elts, warp_rn);

    // The leader lane of each warp publishes that warp's statistics.
    const auto lane = warp_stats_.reducer_.lane_;
    const auto warp_n = warp_stats_.reducer_.warp_n_;
    if (lane == 0) {
      exchange[warp_n] = local;
    }
    __syncthreads();

    // A single warp finalizes the merge, so at most 32 warps may share a row.
    static_assert(WARPS_N <= 32);

    // Lanes that do not own an entry contribute an empty (all-zero) partial.
    T count = Zeros<T>::get();
    T mean = Zeros<T>::get();
    T sq_dev = Zeros<T>::get();
    if (lane < WARPS_N) {
      stats_t entry = exchange[lane];
      count = N * THREADS_PER_WARP;
      mean = transformer_engine::Get<0>::of<stats_t, T>(entry);
      sq_dev = transformer_engine::Get<1>::of<stats_t, T>(entry);
    }

    warp_chan_upd_dynamic(mean, sq_dev, count, WARPS_N);

    return {mean, sq_dev};
  }

  WarpStats warp_stats_;
  stats_t *smem0_;
  stats_t *smem1_;
  bool use0_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

835
/**
 * Statistics over one row handled by a single warp.
 *
 * compute() returns the pair {m, m2} where m is the reduced sum of all elements scaled by
 * rn, and m2 is the reduced sum of squared differences from m.
 */
template <typename T, uint32_t WARPS_M>
struct Stats<T, 1, WARPS_M, 1> {
  using stats_t = typename TypeToVec2<T>::Type;
  // The simple Warp reducer.
  using Reducer = Reducer<T, 1, WARPS_M, 1>;

  // No shared memory is requested by this specialization.
  enum { SMEM_BYTES = 0 };

  template <typename Params>
  inline __device__ Stats(const Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
                          uint32_t warp_n, uint32_t lane, void *smem)
      : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) {}

  /**
   * elts: the N elements held by the calling thread.
   * rn:   factor applied to the reduced sum to obtain m.
   */
  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    auto add = Sum<T>();

    // First pass: thread-local sum, then all-reduce and scale.
    T mean = Zeros<T>::get();
#pragma unroll
    for (uint32_t idx = 0; idx < N; ++idx) {
      mean += elts[idx];
    }
    mean = reducer_.allreduce(mean, add) * rn;

    // Second pass: thread-local sum of squared differences from the mean, then all-reduce.
    T sq_dev = Zeros<T>::get();
#pragma unroll
    for (uint32_t idx = 0; idx < N; ++idx) {
      T centered = (elts[idx] - mean);
      sq_dev += centered * centered;
    }
    sq_dev = reducer_.allreduce(sq_dev, add);

    return {mean, sq_dev};
  }

  Reducer reducer_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

/**
 * Shuffle-down max reduction over the first num_elems lanes of a warp.
 *
 * The shuffle distance starts at num_elems / 2 and is halved each step, so lane 0 ends up
 * with the maximum. The other lanes hold partial results only. Inputs are declared
 * non-negative to the compiler via __builtin_assume.
 */
template <int num_elems>
__device__ __forceinline__ float warp_reduce_max(const float m) {
  float running_max = m;
#pragma unroll
  for (int stride = num_elems / 2; stride > 0; stride /= 2) {
#ifdef __HIP_PLATFORM_AMD__
    const float peer_max = __shfl_down(running_max, stride, THREADS_PER_WARP);
#else
    const float peer_max = __shfl_down_sync(0xFFFFFFFF, running_max, stride);
#endif
    // Both operands are promised to be non-negative.
    __builtin_assume(running_max >= 0);
    __builtin_assume(peer_max >= 0);
    running_max = fmaxf(running_max, peer_max);
  }
  return running_max;
}

892
893
894
895
/**
 * Max reduction across a full warp whose result is returned in every lane.
 *
 * A shuffle-down reduction gathers the maximum in lane 0, which is then broadcast to all
 * THREADS_PER_WARP lanes. Inputs are declared non-negative to the compiler via
 * __builtin_assume.
 */
__forceinline__ __device__ float warp_reduce_max_broadcast(const float val) {
  float running_max = val;
#pragma unroll
  for (int stride = THREADS_PER_WARP / 2; stride > 0; stride /= 2) {
#ifdef __HIP_PLATFORM_AMD__
    const float peer_max = __shfl_down(running_max, stride, THREADS_PER_WARP);
#else
    const float peer_max = __shfl_down_sync(0xFFFFFFFF, running_max, stride);
#endif
    // Both operands are promised to be non-negative.
    __builtin_assume(running_max >= 0);
    __builtin_assume(peer_max >= 0);
    running_max = fmaxf(running_max, peer_max);
  }
  // Lane 0 now holds the warp-wide maximum; hand it to every other lane.
  constexpr int source_lane = 0;
#ifdef __HIP_PLATFORM_AMD__
  running_max = __shfl(running_max, source_lane, THREADS_PER_WARP);
#else
  running_max = __shfl_sync(0xFFFFFFFF, running_max, source_lane);
#endif
  return running_max;
}

Przemek Tredak's avatar
Przemek Tredak committed
915
916
/**
 * Two-stage block-wide max reduction of non-negative values.
 *
 * Stage 1: every warp reduces its 32 values; the warp's first thread stores the result in
 *          a shared staging array indexed by warpid.
 * Stage 2: warp 0 loads the num_warps partial results (lanes >= num_warps contribute 0,
 *          the identity for non-negative inputs) and reduces them over the full warp.
 *
 * The second stage always runs over all 32 lanes. A reduction width of num_warps would
 * silently skip lanes whenever num_warps is not a power of two (e.g. for 6 warps the
 * shuffle distances 3 and 1 never bring lanes 2 and 5 into lane 0).
 *
 * m:      value contributed by the calling thread.
 * warpid: index of the calling thread's warp within the block, in [0, num_warps).
 * Returns the block-wide maximum in the first thread of warp 0. Other threads of warp 0
 * receive partial results and threads of all other warps receive 0.
 *
 * All threads of the block must call this function (it contains a __syncthreads()).
 * NOTE(review): the code identifies warps through threadIdx.x only, so it presumably
 * expects a 1D thread block - confirm against callers.
 * NOTE(review): repeated calls of the same instantiation reuse one staging array without
 * a trailing barrier - confirm callers synchronize between back-to-back calls.
 */
template <int num_warps, typename compute_t>
__device__ __forceinline__ compute_t reduce_max(const compute_t m, const int warpid) {
  constexpr int warp_size = 32;
  // Stage 2 is finished by one warp with one lane per staged partial result.
  static_assert(num_warps > 0 && num_warps <= warp_size,
                "reduce_max supports between 1 and 32 warps per block");
  __shared__ float staging[num_warps];

  const float my_max = m;
  const float my_warp_max = warp_reduce_max<warp_size>(my_max);
  if (threadIdx.x % warp_size == 0) {
    staging[warpid] = my_warp_max;
  }
  __syncthreads();

  compute_t result = 0.f;
  if (warpid == 0) {
    const float staged_max = threadIdx.x < num_warps ? staging[threadIdx.x] : 0;
    // Full-warp width keeps the result correct for any num_warps, not just powers of two.
    result = warp_reduce_max<warp_size>(staged_max);
  }
  return result;
}

933
934
935
936
937
938
939
940
941
942
943
944
/**
 * Max reduction in subwarps.
 * E.g., if nvec=4, each warp processes 128 elements (32 x 4), which covers four MXFP8 scaling factors.
 * To compute the scaling factor for 32 consecutive elements, only 8 threads need to participate,
 * so the warp is split into 4 subwarps, each 8 threads wide.
 * Inside each subwarp a shuffle-down reduction gathers the maximum in the subwarp's first lane,
 * which is then broadcast to the other lanes of that subwarp.
 */
template <int subwarp_width>
__forceinline__ __device__ float subwarp_reduce_max_broadcast(const float val) {
  float running_max = val;
  // Shuffles are confined to groups of subwarp_width lanes via the width argument.
#pragma unroll
  for (int stride = subwarp_width / 2; stride > 0; stride /= 2) {
#ifdef __HIP_PLATFORM_AMD__
    const float peer_max = __shfl_down(running_max, stride, subwarp_width);
#else
    const float peer_max = __shfl_down_sync(0xFFFFFFFF, running_max, stride, subwarp_width);
#endif
    // Both operands are promised to be non-negative.
    __builtin_assume(running_max >= 0);
    __builtin_assume(peer_max >= 0);
    running_max = fmaxf(running_max, peer_max);
  }
  // The first lane of each subwarp now holds that subwarp's maximum; share it with the rest.
  constexpr int source_lane = 0;
#ifdef __HIP_PLATFORM_AMD__
  running_max = __shfl(running_max, source_lane, subwarp_width);
#else
  running_max = __shfl_sync(0xFFFFFFFF, running_max, source_lane, subwarp_width);
#endif
  return running_max;
}

Przemek Tredak's avatar
Przemek Tredak committed
964
// Works only on positive values
965
966
__device__ __forceinline__ void atomicMaxFloat(float *addr, const float value) {
  // The float is compared through its bit pattern as a signed int, which orders values
  // correctly only while they are non-negative.
  int *const addr_as_int = reinterpret_cast<int *>(addr);
  const int value_bits = __float_as_int(value);
  atomicMax(addr_as_int, value_bits);
}

// Works only on positive values
970
971
__device__ __forceinline__ void atomicMinFloat(float *addr, const float value) {
  // The float is compared through its bit pattern as a signed int, which orders values
  // correctly only while they are non-negative.
  int *const addr_as_int = reinterpret_cast<int *>(addr);
  const int value_bits = __float_as_int(value);
  atomicMin(addr_as_int, value_bits);
}

// Generic reciprocal: stores 1 / value into the location pointed to by value_inv.
template <typename T>
__device__ __forceinline__ void reciprocal(T *value_inv, const T value) {
  value_inv[0] = 1 / value;
}

979
980
981
982
983
// float specialization: uses the __frcp_rn reciprocal intrinsic instead of a division.
template <>
__device__ __forceinline__ void reciprocal<float>(float *value_inv, const float value) {
  value_inv[0] = __frcp_rn(value);
}

984
985
986
987
988
////////////////////////////////////////////////////////////////////////////////////////////////////

// Short names for the two FP8 formats.
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
// One-byte container for an E8M0 value, as produced by float_to_e8m0 and consumed by exp2f_rcp.
using e8m0_t = uint8_t;
yuguo's avatar
yuguo committed
989
using int8 = int8_t;
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010

// Width of the FP32 mantissa field; float_to_e8m0 shifts by this amount to reach the exponent.
constexpr uint32_t FP32_MANTISSA_BITS = 23;
// FP32 exponent bias; exp2f_rcp subtracts the biased exponent from it.
constexpr uint32_t FP32_EXPONENT_BIAS = 127;

// NOTE(review): presumably selects the direction(s) along which scaling factors are computed -
// confirm against users. "BIDIMENTIONAL" is misspelled but is part of the public enum.
enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENTIONAL = 2 };

template <typename T>
struct Numeric_Traits;

// Limits for fp8e4m3. The two constants are consistent with each other: 448 = 1.75 * 2^8.
template <>
struct Numeric_Traits<fp8e4m3> {
  // Exponent of the largest power of two not exceeding maxNorm.
  static constexpr int maxUnbiasedExponent = 8;
  // Largest magnitude used for this type.
  static constexpr double maxNorm = 448;
};

// Limits for fp8e5m2. The two constants are consistent with each other: 57344 = 1.75 * 2^15.
template <>
struct Numeric_Traits<fp8e5m2> {
  // Exponent of the largest power of two not exceeding maxNorm.
  static constexpr int maxUnbiasedExponent = 15;
  // Largest magnitude used for this type.
  static constexpr double maxNorm = 57344;
};

yuguo's avatar
yuguo committed
1011
1012
1013
1014
1015
1016
// Limits for the int8 alias: a maximum magnitude of 127 and an exponent of 0.
template <>
struct Numeric_Traits<int8> {
  // With this value Quantized_Limits<int8>::emax evaluates to 1.
  static constexpr int maxUnbiasedExponent = 0;
  // Largest magnitude used for this type.
  static constexpr double maxNorm = 127;
};

1017
1018
1019
1020
1021
1022
1023
1024
1025
// Compile-time limits of a quantized type T, derived from Numeric_Traits<T>.
template <typename T>
struct Quantized_Limits {
  // Copied unchanged from Numeric_Traits<T>.
  static constexpr int max_unbiased_exponent = Numeric_Traits<T>::maxUnbiasedExponent;
  // Numeric_Traits<T>::maxNorm narrowed from double to float.
  static constexpr float max_norm = Numeric_Traits<T>::maxNorm;
  // 1 / max_norm; the 1.0 literal makes the division run in double before narrowing to float.
  static constexpr float max_norm_rcp = 1.0 / max_norm;
  // 2^max_unbiased_exponent, built with an integer shift.
  static constexpr float emax = 1 << max_unbiased_exponent;
  // 1 / emax, again divided in double before narrowing to float.
  static constexpr float emax_rcp = 1.0 / emax;
};

yuguo's avatar
yuguo committed
1026
1027
1028
1029
1030
1031
1032
#ifdef __HIP_PLATFORM_AMD__
// Definition of __CUDA_ARCH_HAS_FEATURE__ for the HIP/AMD build, so that the feature tests
// used further down in this header (see float_to_e8m0) still preprocess.
// NOTE(review): presumably __CUDA_ARCH__ and the SM*_ALL names are undefined on this
// platform; undefined identifiers evaluate to 0 inside #if, which would make every
// "__CUDA_ARCH__ >= N" test false and the whole macro 0 - confirm.
#define __CUDA_ARCH_HAS_FEATURE__(FEATURE) \
    ((__CUDA_ARCH__ >= 100 && FEATURE == SM100_ALL) || \
     (__CUDA_ARCH__ >= 101 && FEATURE == SM101_ALL) || \
     (__CUDA_ARCH__ >= 120 && FEATURE == SM120_ALL))
#endif

1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
/**
 * Converts a float to an E8M0 value, i.e. the 8-bit biased FP32 exponent of the magnitude
 * rounded up to a power of two.
 *
 * NaN maps to 0xFF and infinity maps to 0xFE. When the SM100/SM101/SM120 feature test
 * succeeds, the conversion is a single `cvt.rp.satfinite.ue8m0x2.f32` instruction.
 * Otherwise it is emulated on the FP32 bit pattern: zero maps to 0x00, and the exponent is
 * incremented whenever the mantissa is non-zero, except where that would produce 0xFF or
 * where a subnormal does not exceed the value encoded by exponent 0.
 *
 * NOTE(review): the emulation discards the sign bit, so callers presumably pass a
 * non-negative magnitude such as an amax - confirm.
 */
__device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
  // TODO: nan/inf needs to be set for any value
  // of nan/inf in input not just amax.
  if (isnan(val)) {
    return 0xFF;
  }
  if (isinf(val)) {
    return 0xFE;
  }
#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \
     (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL)))
  uint16_t out;
  asm volatile(
      "{\n"
      "cvt.rp.satfinite.ue8m0x2.f32  %0, 0.0, %1;\n"
      "}"
      : "=h"(out)
      : "f"(val));
  // The converted value of `val` is read back from the first byte of the packed pair.
  return *reinterpret_cast<e8m0_t *>(&out);
#else
  if (val == 0.0f) {
    return 0x00;
  }
  // Read the raw bits with the conversion intrinsic; dereferencing a uint32_t* that points
  // at a float violates strict aliasing.
  const uint32_t val_u32 = __float_as_uint(val);
  // Narrowing to 8 bits keeps the biased exponent and drops the sign bit.
  e8m0_t exponent = static_cast<e8m0_t>(val_u32 >> FP32_MANTISSA_BITS);
  const uint32_t mantissa = val_u32 & 0x7FFFFF;
  // Round up exponent and deal with satfinite.
  if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) {
    ++exponent;
  }
  return exponent;
#endif
}

/**
 * Returns 2^(FP32_EXPONENT_BIAS - biased_exp), the reciprocal of 2^(biased_exp - 127).
 * A biased exponent of 0 is special-cased and yields 1 rather than 2^127.
 */
__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) {
  if (biased_exp == 0) {
    return 1;
  }
  return exp2f(FP32_EXPONENT_BIAS - static_cast<float>(biased_exp));
}

Przemek Tredak's avatar
Przemek Tredak committed
1071
1072
1073
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_UTILS_CUH_