/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_COMMON_H_
#define TRANSFORMER_ENGINE_COMMON_COMMON_H_
#include "util/system.h"

#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif

#define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>

#if FP4_TYPE_SUPPORTED
#include <cuda_fp4.h>
#endif

#include <cuda_runtime_api.h>
#include <transformer_engine/transformer_engine.h>

#include <cstdint>
#include <functional>
#include <limits>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <vector>

#include "./nvtx.h"
#include "./util/cuda_driver.h"
#include "./util/logging.h"

namespace transformer_engine {

// String conversions for dtypes and scaling modes (defined out of line).
std::string to_string(const DType type);
std::string to_string(const NVTEScalingMode &mode);

/*! Block length used by blockwise FP8 scaling.
 *
 * Read through transformer_engine::getenv from the environment variable
 * NVTE_BLOCKWISE_FP8_BLOCK_LEN, with 128 passed as the default value.
 */
inline int blockwise_fp8_block_len() {
  return ::transformer_engine::getenv<int>("NVTE_BLOCKWISE_FP8_BLOCK_LEN", 128);
}
/* Scaling-mode predicates. */

// True only for NVTE_DELAYED_TENSOR_SCALING.
inline bool is_tensor_scaling(const NVTEScalingMode &mode) {
  return mode == NVTE_DELAYED_TENSOR_SCALING;
}

// Every mode that is not tensor scaling counts as block scaling
// (e.g. MXFP8, NVFP4 and the 1D/2D block-scaling modes).
inline bool is_block_scaling(const NVTEScalingMode &mode) { return !is_tensor_scaling(mode); }

// True only for NVTE_DELAYED_TENSOR_SCALING (same test as is_tensor_scaling).
inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) {
  return mode == NVTE_DELAYED_TENSOR_SCALING;
}

inline bool is_nvfp4_scaling(const NVTEScalingMode &mode) { return mode == NVTE_NVFP4_1D_SCALING; }

inline bool is_mxfp8_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; }

// Alias of is_mxfp8_scaling, kept for existing callers.
inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return is_mxfp8_scaling(mode); }

// Alias of is_nvfp4_scaling, kept for existing callers.
inline bool is_nvfp_scaling(const NVTEScalingMode &mode) { return is_nvfp4_scaling(mode); }
/*! Product of shape[begin], ..., shape[end - 1].
 *
 * An empty range (begin == end) yields 1. Fails an NVTE_CHECK when
 * [begin, end) is not a valid sub-range of `shape`.
 */
inline size_t product(const std::vector<size_t> &shape, const size_t begin, const size_t end) {
  NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ",
             end, " in a vector with ", shape.size(), " entries");
  size_t result = 1;
  size_t idx = begin;
  while (idx < end) {
    result *= shape[idx++];
  }
  return result;
}

/*! Product of every entry of `shape`; 1 for an empty shape. */
inline size_t product(const std::vector<size_t> &shape) {
  size_t result = 1;
  for (size_t idx = 0; idx < shape.size(); ++idx) {
    result *= shape[idx];
  }
  return result;
}

/*! Plain description of a buffer: raw data pointer, shape and dtype.
 *
 * The struct never frees `dptr`. It converts implicitly from and to
 * NVTEBasicTensor.
 */
struct SimpleTensor {
  void *dptr;
  std::vector<size_t> shape;
  DType dtype;

  SimpleTensor(void *dptr, const std::vector<size_t> &shape, DType dtype)
      : dptr(dptr), shape(shape), dtype(dtype) {}

  SimpleTensor(const NVTEBasicTensor &tensor)  // NOLINT
      : dptr(tensor.data_ptr),
        shape(tensor.shape.data, tensor.shape.data + tensor.shape.ndim),
        dtype(static_cast<DType>(tensor.dtype)) {}

  // Empty description: null pointer, 0-D shape, FP32 dtype.
  SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {}

  operator NVTEBasicTensor() const {
    return {dptr, static_cast<NVTEDType>(dtype),
            nvte_make_shape(this->shape.data(), this->shape.size())};
  }

  // Number of elements: product of `shape` (1 for a 0-D shape).
  size_t numel() const { return product(shape); }

  // Reset to the same state as the default constructor.
  void clear() {
    dptr = nullptr;
    shape.clear();
    dtype = DType::kFloat32;
  }
};
/*! Internal tensor representation.
 *
 * Holds a row-wise data buffer (`data`), an optional column-wise data buffer
 * (`columnwise_data`) and the amax / scale / scale_inv buffers (plus their
 * column-wise counterparts). `scaling_mode` selects how these are interpreted.
 */
struct Tensor {
 public:
  SimpleTensor data;
  SimpleTensor columnwise_data;
  SimpleTensor amax;
  SimpleTensor columnwise_amax;
  SimpleTensor scale;
  SimpleTensor scale_inv;
  SimpleTensor columnwise_scale_inv;

  NVTEScalingMode scaling_mode;
  NVTETensor nvte_tensor;

  Tensor()
      : data(),
        columnwise_data(),
        amax(nullptr, {1}, DType::kFloat32),
        columnwise_amax(nullptr, {1}, DType::kFloat32),
        scale(nullptr, {1}, DType::kFloat32),
        scale_inv(nullptr, {1}, DType::kFloat32),
        columnwise_scale_inv(nullptr, {1}, DType::kFloat32),
        scaling_mode(NVTE_DELAYED_TENSOR_SCALING),
        nvte_tensor(0) {}

  // Clear every buffer and reset the scaling mode; `nvte_tensor` is left untouched.
  void clear() {
    data.clear();
    columnwise_data.clear();
    amax.clear();
    columnwise_amax.clear();
    scale.clear();
    scale_inv.clear();
    columnwise_scale_inv.clear();
    scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
  }

  explicit operator NVTETensor() const noexcept { return nvte_tensor; }

  // Number of elements of the logical shape returned by shape().
  size_t numel() const { return product(shape()); }

  bool has_data() const noexcept { return data.dptr != nullptr; }

  // Check for size (not just pointer) for 0-dim or no token cases.
  bool has_columnwise_data() const noexcept {
    return columnwise_data.dptr != nullptr || columnwise_data.shape.size() != 0;
  }

  // dtype of the populated buffer, preferring row-wise data.
  DType dtype() const {
    if (has_data()) return data.dtype;
    if (has_columnwise_data()) return columnwise_data.dtype;
    // Fallback, used e.g. in workspace
    return data.dtype;
  }

  // Rank of the stored buffer: column-wise only when there is no row-wise data.
  size_t dim() const {
    if (!has_data() && has_columnwise_data()) {
      return columnwise_data.shape.size();
    }
    return data.shape.size();
  }

  /*! Logical shape of the tensor.
   *
   * Uses the row-wise buffer unless only column-wise data is present. For
   * delayed tensor scaling, NVFP4 and 1D/2D block scaling, the column-wise
   * buffer's leading dimension is the logical last dimension, so it is moved
   * back to the end. For MXFP8 the column-wise shape is returned unchanged.
   * Any other scaling mode fails with NVTE_ERROR.
   */
  std::vector<size_t> shape() const {
    /* Note: We sometimes experience spurious compiler errors
     * (-Wstringop-overflow) from this function. It appears that GCC
     * has some bugs with std::vector (see
     * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569).
     */
    switch (scaling_mode) {
      case NVTE_NVFP4_1D_SCALING:
      case NVTE_DELAYED_TENSOR_SCALING:
      case NVTE_BLOCK_SCALING_1D:
      case NVTE_BLOCK_SCALING_2D:
        if (!has_data() && has_columnwise_data()) {
          return rotated_columnwise_shape();
        }
        // NOTE: We may have removed the data pointer from
        // data by setting usage. In that case, we return
        // the non-null shape. It is our best guess at the most
        // recent shape.
        return data.shape;
      case NVTE_MXFP8_1D_SCALING:
        if (!has_data() && has_columnwise_data()) {
          return columnwise_data.shape;
        }
        return data.shape;
      default:
        NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode), "\"");
        return {};
    }
  }

  /*! Matrix height after tensor is flattened to 2D
   *
   * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
   * as a (D1*D2*...*D(n-1), Dn) matrix.
   */
  size_t flat_first_dim() const {
    const std::vector<size_t> full_shape = shape();
    if (full_shape.empty()) {
      return 1;
    }
    return product(full_shape, 0, full_shape.size() - 1);
  }

  /*! Matrix width after tensor is flattened to 2D
   *
   * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
   * as a (D1*D2*...*D(n-1), Dn) matrix.
   */
  size_t flat_last_dim() const {
    const std::vector<size_t> full_shape = shape();
    return full_shape.empty() ? 1 : full_shape.back();
  }

 private:
  // Column-wise shape [c0, c1, ..., c(n-1)] rotated to [c1, ..., c(n-1), c0];
  // empty when the column-wise shape is empty.
  std::vector<size_t> rotated_columnwise_shape() const {
    const std::vector<size_t> &cshape = columnwise_data.shape;
    std::vector<size_t> ret;
    if (cshape.empty()) {
      return ret;
    }
    ret.reserve(cshape.size());
    for (size_t i = 1; i < cshape.size(); ++i) {
      ret.push_back(cshape[i]);
    }
    ret.push_back(cshape.front());
    return ret;
  }
};
/*! Options passed to quantization kernels.
 *
 * `attr_sizes` lists the size in bytes of each option, in declaration order.
 */
struct QuantizationConfig {
  bool force_pow_2_scales = false;
  float amax_epsilon = 0.0f;
  NVTETensor noop_tensor = nullptr;
  Float8BlockScaleTensorFormat float8_block_scale_tensor_format =
      Float8BlockScaleTensorFormat::GEMM_READY;
  NVTETensor rng_state = nullptr;
  bool nvfp4_2d_quantization = false;
  bool stochastic_rounding = false;

  static constexpr size_t attr_sizes[] = {
      sizeof(bool),                          // force_pow_2_scales
      sizeof(float),                         // amax_epsilon
      sizeof(NVTETensor),                    // noop_tensor
      sizeof(Float8BlockScaleTensorFormat),  // float8_block_scale_tensor_format
      sizeof(NVTETensor),                    // rng_state (RNG seed and offset)
      sizeof(bool),                          // nvfp4_2d_quantization
      sizeof(bool)                           // stochastic_rounding
  };
};
// Returns the cudaDataType_t for Transformer Engine dtype `t` (defined out of line).
cudaDataType_t get_cuda_dtype(const transformer_engine::DType t);
/*! Ceiling division of x by y, computed as (x + y - 1) / y.
 *
 * Intended for x >= 0 and y > 0; the caller must ensure x + y - 1 does not
 * overflow T.
 */
template <typename T>
constexpr T DIVUP(const T &x, const T &y) {
  return (x + (y - 1)) / y;
}
/*! Round N up to the next multiple of M, computed in 64-bit unsigned arithmetic.
 *
 * Both arguments must be integral (checked at compile time).
 */
template <typename T1, typename T2>
constexpr __device__ __host__ __forceinline__ uint64_t DIVUP_TO_MULTIPLE(const T1 &N, const T2 &M) {
  static_assert(std::is_integral<T1>::value && std::is_integral<T2>::value,
                "Integral type required.");
  const uint64_t num_chunks = DIVUP(static_cast<uint64_t>(N), static_cast<uint64_t>(M));
  return num_chunks * M;
}

// Short aliases for the element types handled by Transformer Engine.
using byte = uint8_t;
using int16 = int16_t;
using int32 = int32_t;
using int64 = int64_t;
using fp32 = float;
using fp16 = half;
using int8 = int8_t;
using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
#if CUDA_VERSION >= 12080
using fp8e8m0 = __nv_fp8_e8m0;
#endif
#if FP4_TYPE_SUPPORTED
using fp4e2m1 = __nv_fp4_e2m1;
using fp4e2m1x2 = __nv_fp4x2_e2m1;
using fp4e2m1x4 = __nv_fp4x4_e2m1;
#endif
// E8M0 values are represented by this alias as raw uint8_t bytes.
using e8m0_t = uint8_t;
namespace detail {

/*! Printable name of type T; only the specializations generated below are defined. */
template <typename T>
constexpr inline const char *type_name() noexcept;

#define TRANSFORMER_ENGINE_TYPE_NAME(T)                  \
  template <>                                            \
  inline constexpr const char *type_name<T>() noexcept { \
    return #T;                                           \
  }
TRANSFORMER_ENGINE_TYPE_NAME(uint8_t)
TRANSFORMER_ENGINE_TYPE_NAME(int16_t)
TRANSFORMER_ENGINE_TYPE_NAME(int32_t)
TRANSFORMER_ENGINE_TYPE_NAME(int64_t)
TRANSFORMER_ENGINE_TYPE_NAME(float)
TRANSFORMER_ENGINE_TYPE_NAME(half)
TRANSFORMER_ENGINE_TYPE_NAME(nv_bfloat16)
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e4m3)
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2)
TRANSFORMER_ENGINE_TYPE_NAME(int8_t)
#if CUDA_VERSION >= 12080
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0)
#endif
#if FP4_TYPE_SUPPORTED
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp4_e2m1)
#endif
#undef TRANSFORMER_ENGINE_TYPE_NAME

/*! Largest finite value of T as a float (`max`). Quantization target types
 * also provide `max_inverse` = 1 / max.
 */
template <typename T>
struct TypeExtrema;

#if FP4_TYPE_SUPPORTED
template <>
struct TypeExtrema<fp4e2m1> {
  static constexpr float max = 6.0f;
  static constexpr float max_inverse = 1.0 / max;
};
#endif

template <>
struct TypeExtrema<fp8e4m3> {
  static constexpr float max = 448.0f;
  static constexpr float max_inverse = 1.0 / max;
};

template <>
struct TypeExtrema<int8> {
  static constexpr float max = 127.0f;
  // Provided for parity with the other quantization target types.
  static constexpr float max_inverse = 1.0 / max;
};

template <>
struct TypeExtrema<fp8e5m2> {
  static constexpr float max = 57344.0f;
  static constexpr float max_inverse = 1.0 / max;
};

template <>
struct TypeExtrema<bf16> {
  // Hex float format of 1.(7 bits of 1) * 2 ^ 127
  static constexpr float max = 0x1.FEp127;
};

template <>
struct TypeExtrema<fp16> {
  // Hex float format of 1.(10 bits of 1) * 2 ^ 15
  static constexpr float max = 0x1.FFCp15;
};

// Fallback for all other types: std::numeric_limits<T>::max() (requires <limits>).
template <typename T>
struct TypeExtrema {
  static constexpr float max = std::numeric_limits<T>::max();
};

}  // namespace detail
/*! Storage width of T in bits: 8 * sizeof(T), except 4 for fp4e2m1. */
template <typename T>
struct BitsNumber {
  static constexpr size_t num_bits = sizeof(T) * 8;
};

#if FP4_TYPE_SUPPORTED
// FP4 is a sub-byte type, so sizeof() would overstate its width.
template <>
struct BitsNumber<fp4e2m1> {
  static constexpr size_t num_bits = 4;
};
#endif

/*! Compile-time metadata for element type T: its DType, storage width in
 * bits, largest finite value (as float) and printable name.
 */
template <typename T>
struct TypeInfo {
  // C++ element types ordered by DType value: getType<U>() returns the DType
  // whose index in this tuple holds U.
  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8
#if CUDA_VERSION >= 12080
                           ,
                           fp8e8m0
#endif
#if FP4_TYPE_SUPPORTED
                           ,
                           fp4e2m1
#endif
                           >;

  // Linear search from `current` upward: returns the first DType whose tuple
  // entry is U, or DType::kNumTypes when no entry matches.
  // NOTE(review): the if/else is not `if constexpr`, so every index below
  // kNumTypes is instantiated and `types` needs one entry per DType value.
  template <typename U, DType current>
  struct Helper {
    constexpr static DType getType() {
      constexpr int i = static_cast<int>(current);
      if (std::is_same<U, typename std::tuple_element<i, types>::type>::value) {
        return current;
      } else {
        return Helper<U, static_cast<DType>(i + 1)>::getType();
      }
    }
  };

  template <typename U>
  struct Helper<U, DType::kNumTypes> {
    constexpr static DType getType() { return DType::kNumTypes; }
  };

  template <typename U>
  constexpr static DType getType() {
    return Helper<U, DType::kByte>::getType();
  }

  constexpr static DType dtype = getType<T>();
  // Storage width in bits (not bytes).
  constexpr static size_t size = BitsNumber<T>::num_bits;
  constexpr static float max_finite_value = detail::TypeExtrema<T>::max;
  constexpr static const char *name = detail::type_name<T>();
};
// Expands to a `case DType::kFloat4E2M1:` arm (binding `type` to fp4e2m1 and
// running __VA_ARGS__) for use inside a `switch (dtype)`. Expands to nothing
// when FP4 is not supported.
#if !FP4_TYPE_SUPPORTED
#define SWITCH_FP4_TYPE_HANDLE(type, ...)  // do nothing
#else
#define SWITCH_FP4_TYPE_HANDLE(type, ...) \
  case DType::kFloat4E2M1: {              \
    using type = fp4e2m1;                 \
    { __VA_ARGS__ }                       \
    break;                                \
  }
#endif

/*! Dispatch on `dtype`, running __VA_ARGS__ with `type` aliased to the
 * matching C++ element type.
 *
 * Covers kByte, kInt16, kInt32, kInt64, kFloat32, kFloat16, kBFloat16, FP8
 * E4M3/E5M2, kFloat8E8M0 (bound to `byte`) and, when FP4_TYPE_SUPPORTED,
 * kFloat4E2M1. Anything else fails with NVTE_ERROR("Invalid type.").
 *
 * NOTE(review): there is no DType::kInt8 arm, although TypeInfo and the
 * *_WITH_INT8 / _8BIT switches handle int8. Confirm this is intentional.
 */
#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \
  switch (dtype) {                                           \
    using namespace transformer_engine;                      \
    case DType::kByte: {                                     \
      using type = unsigned char;                            \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kInt16: {                                    \
      using type = int16_t;                                  \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kInt32: {                                    \
      using type = int32_t;                                  \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kInt64: {                                    \
      using type = int64_t;                                  \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kFloat32: {                                  \
      using type = float;                                    \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kFloat16: {                                  \
      using type = fp16;                                     \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kBFloat16: {                                 \
      using type = bf16;                                     \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kFloat8E4M3: {                               \
      using type = fp8e4m3;                                  \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kFloat8E5M2: {                               \
      using type = fp8e5m2;                                  \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kFloat8E8M0: {                               \
      using type = byte;                                     \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
      SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__)              \
    default:                                                 \
      NVTE_ERROR("Invalid type.");                           \
  }
/*! Dispatch on a floating-point `dtype` (FP32, FP16, BF16, FP8 E4M3, FP8 E5M2),
 * running __VA_ARGS__ with `type` aliased to the matching C++ type.
 * Any other dtype fails with NVTE_ERROR("Invalid type.").
 */
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(dtype, type, ...) \
  switch (dtype) {                                             \
    using namespace transformer_engine;                        \
    case DType::kFloat32: {                                    \
      using type = float;                                      \
      { __VA_ARGS__ }                                          \
      break;                                                   \
    }                                                          \
    case DType::kFloat16: {                                    \
      using type = fp16;                                       \
      { __VA_ARGS__ }                                          \
      break;                                                   \
    }                                                          \
    case DType::kBFloat16: {                                   \
      using type = bf16;                                       \
      { __VA_ARGS__ }                                          \
      break;                                                   \
    }                                                          \
    case DType::kFloat8E4M3: {                                 \
      using type = fp8e4m3;                                    \
      { __VA_ARGS__ }                                          \
      break;                                                   \
    }                                                          \
    case DType::kFloat8E5M2: {                                 \
      using type = fp8e5m2;                                    \
      { __VA_ARGS__ }                                          \
      break;                                                   \
    }                                                          \
    default: {                                                 \
      NVTE_ERROR("Invalid type.");                             \
    }                                                          \
  }

/*! Dispatch on an output `dtype` (FP32, FP16, BF16, FP8 E5M2, FP8 E4M3),
 * running __VA_ARGS__ with `type` aliased to the matching C++ type.
 * Any other dtype fails with NVTE_ERROR("Invalid type.").
 */
#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) \
  switch (dtype) {                                              \
    using namespace transformer_engine;                         \
    case DType::kFloat32: {                                     \
      using type = float;                                       \
      { __VA_ARGS__ }                                           \
    } break;                                                    \
    case DType::kFloat16: {                                     \
      using type = fp16;                                        \
      { __VA_ARGS__ }                                           \
    } break;                                                    \
    case DType::kBFloat16: {                                    \
      using type = bf16;                                        \
      { __VA_ARGS__ }                                           \
    } break;                                                    \
    case DType::kFloat8E5M2: {                                  \
      using type = fp8e5m2;                                     \
      { __VA_ARGS__ }                                           \
    } break;                                                    \
    case DType::kFloat8E4M3: {                                  \
      using type = fp8e4m3;                                     \
      { __VA_ARGS__ }                                           \
    } break;                                                    \
    default:                                                    \
      NVTE_ERROR("Invalid type.");                              \
  }
/*! Same as TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT plus an int8 arm:
 * dispatches FP32, FP16, BF16, FP8 E5M2, FP8 E4M3 and kInt8, running
 * __VA_ARGS__ with `type` aliased to the matching C++ type.
 * Any other dtype fails with NVTE_ERROR("Invalid type.").
 */
#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT_WITH_INT8(dtype, type, ...) \
  switch (dtype) {                                                        \
    using namespace transformer_engine;                                   \
    case DType::kFloat32: {                                               \
      using type = float;                                                 \
      { __VA_ARGS__ }                                                     \
      break;                                                              \
    }                                                                     \
    case DType::kFloat16: {                                               \
      using type = fp16;                                                  \
      { __VA_ARGS__ }                                                     \
      break;                                                              \
    }                                                                     \
    case DType::kBFloat16: {                                              \
      using type = bf16;                                                  \
      { __VA_ARGS__ }                                                     \
      break;                                                              \
    }                                                                     \
    case DType::kFloat8E5M2: {                                            \
      using type = fp8e5m2;                                               \
      { __VA_ARGS__ }                                                     \
      break;                                                              \
    }                                                                     \
    case DType::kFloat8E4M3: {                                            \
      using type = fp8e4m3;                                               \
      { __VA_ARGS__ }                                                     \
      break;                                                              \
    }                                                                     \
    case DType::kInt8: {                                                  \
      using type = int8;                                                  \
      { __VA_ARGS__ }                                                     \
      break;                                                              \
    }                                                                     \
    default: {                                                            \
      NVTE_ERROR("Invalid type.");                                        \
    }                                                                     \
  }

/*! Dispatch on a non-FP8 floating-point `dtype` (FP32, FP16, BF16),
 * running __VA_ARGS__ with `type` aliased to the matching C++ type.
 * Any other dtype fails with NVTE_ERROR("Invalid type.").
 */
#define TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, type, ...) \
  switch (dtype) {                                                   \
    using namespace transformer_engine;                              \
    case DType::kFloat32: {                                          \
      using type = float;                                            \
      { __VA_ARGS__ }                                                \
    } break;                                                         \
    case DType::kFloat16: {                                          \
      using type = fp16;                                             \
      { __VA_ARGS__ }                                                \
    } break;                                                         \
    case DType::kBFloat16: {                                         \
      using type = bf16;                                             \
      { __VA_ARGS__ }                                                \
    } break;                                                         \
    default:                                                         \
      NVTE_ERROR("Invalid type.");                                   \
  }

/*! Dispatch for FP4 data only.
 *
 * For DType::kFloat4E2M1, runs __VA_ARGS__ with `type` aliased to
 * __nv_fp4x2_storage_t (packed storage for two FP4 values). Any other dtype
 * fails with NVTE_ERROR("Invalid type.").
 *
 * NOTE(review): `pack_size` is meant to select the packed FP4 type, but it is
 * currently unused; `type` is always the x2 storage type.
 *
 * __nv_fp4x2_storage_t comes from <cuda_fp4.h>, which is only included when
 * FP4_TYPE_SUPPORTED. In other builds the FP4 arm is therefore compiled out,
 * as SWITCH_FP4_TYPE_HANDLE does, and every dtype reaches the error.
 */
#if FP4_TYPE_SUPPORTED
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY(dtype, pack_size, type, ...) \
  switch (dtype) {                                                             \
    using namespace transformer_engine;                                        \
    case DType::kFloat4E2M1: {                                                 \
      using type = __nv_fp4x2_storage_t;                                       \
      { __VA_ARGS__ }                                                          \
    } break;                                                                   \
    default:                                                                   \
      NVTE_ERROR("Invalid type.");                                             \
  }
#else
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY(dtype, pack_size, type, ...) \
  switch (dtype) {                                                             \
    using namespace transformer_engine;                                        \
    default:                                                                   \
      NVTE_ERROR("Invalid type.");                                             \
  }
#endif

/*! Dispatch on an FP8 `dtype` (E5M2 or E4M3), running __VA_ARGS__ with
 * `type` aliased to fp8e5m2 / fp8e4m3. Any other dtype fails with
 * NVTE_ERROR("Invalid type.").
 */
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \
  switch (dtype) {                                               \
    using namespace transformer_engine;                          \
    case DType::kFloat8E5M2: {                                   \
      using type = fp8e5m2;                                      \
      { __VA_ARGS__ }                                            \
    } break;                                                     \
    case DType::kFloat8E4M3: {                                   \
      using type = fp8e4m3;                                      \
      { __VA_ARGS__ }                                            \
    } break;                                                     \
    default:                                                     \
      NVTE_ERROR("Invalid type.");                               \
  }
/*! Dispatch on an 8-bit `dtype` (FP8 E5M2, FP8 E4M3 or kInt8), running
 * __VA_ARGS__ with `type` aliased to fp8e5m2 / fp8e4m3 / int8.
 * Any other dtype fails with NVTE_ERROR("Invalid type.").
 */
#define TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(dtype, type, ...)    \
  switch (dtype) {                                               \
    using namespace transformer_engine;                          \
    case DType::kFloat8E5M2: {                                   \
      using type = fp8e5m2;                                      \
      { __VA_ARGS__ }                                            \
      break;                                                     \
    }                                                            \
    case DType::kFloat8E4M3: {                                   \
      using type = fp8e4m3;                                      \
      { __VA_ARGS__ }                                            \
      break;                                                     \
    }                                                            \
    case DType::kInt8: {                                         \
      using type = int8;                                         \
      { __VA_ARGS__ }                                            \
      break;                                                     \
    }                                                            \
    default: {                                                   \
      NVTE_ERROR("Invalid type.");                               \
    }                                                            \
  }

/*! Dispatch on an input `dtype`.
 *
 * Only FP32, FP16 and BF16 are instantiated, with `type` aliased to the
 * matching C++ type. FP8 inputs (and FP4 inputs, when FP4_TYPE_SUPPORTED)
 * fail with a dedicated NVTE_ERROR message; any other dtype fails with
 * "Invalid type.".
 */
#if FP4_TYPE_SUPPORTED
#define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \
  switch (dtype) {                                             \
    using namespace transformer_engine;                        \
    case DType::kFloat32: {                                    \
      using type = float;                                      \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kFloat16: {                                    \
      using type = fp16;                                       \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kBFloat16: {                                   \
      using type = bf16;                                       \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kFloat8E5M2:                                   \
    case DType::kFloat8E4M3: {                                 \
      NVTE_ERROR("FP8 type not instantiated for input.");      \
    } break;                                                   \
    case DType::kFloat4E2M1: {                                 \
      NVTE_ERROR("FP4 type not instantiated for input.");      \
    } break;                                                   \
    default:                                                   \
      NVTE_ERROR("Invalid type.");                             \
  }
#else
#define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \
  switch (dtype) {                                             \
    using namespace transformer_engine;                        \
    case DType::kFloat32: {                                    \
      using type = float;                                      \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kFloat16: {                                    \
      using type = fp16;                                       \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kBFloat16: {                                   \
      using type = bf16;                                       \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kFloat8E5M2:                                   \
    case DType::kFloat8E4M3: {                                 \
      NVTE_ERROR("FP8 type not instantiated for input.");      \
    } break;                                                   \
    default:                                                   \
      NVTE_ERROR("Invalid type.");                             \
  }
#endif

// Dispatch on a 16-bit floating-point dtype: inside __VA_ARGS__, `type` names bf16 or fp16.
// Any other dtype is reported through NVTE_ERROR.
#define TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(dtype, type, ...) \
  switch (dtype) {                                             \
    using namespace transformer_engine;                        \
    case DType::kBFloat16: {                                   \
      using type = bf16;                                       \
      __VA_ARGS__;                                             \
    } break;                                                   \
    case DType::kFloat16: {                                    \
      using type = fp16;                                       \
      __VA_ARGS__;                                             \
    } break;                                                   \
    default:                                                   \
      NVTE_ERROR("Invalid type for 16 bit.");                  \
  }

// Map a runtime MX scaling-block dimension to the compile-time constant `DIM` (a
// `constexpr size_t`) visible inside __VA_ARGS__. Only 1 and 32 are accepted; any other
// value is reported through NVTE_ERROR.
#define TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(SCALE_DIM, DIM, ...) \
  switch (SCALE_DIM) {                                              \
    case 1: {                                                       \
      constexpr size_t DIM = 1;                                     \
      { __VA_ARGS__ }                                               \
    } break;                                                        \
    case 32: {                                                      \
      constexpr size_t DIM = 32;                                    \
      { __VA_ARGS__ }                                               \
    } break;                                                        \
    default: {                                                      \
      NVTE_ERROR("Invalid size of the MX scaling factor.");         \
    }                                                               \
  }

// Lift a runtime boolean to a compile-time one. __VA_ARGS__ is instantiated twice: once with
// `constexpr bool FLAG = false` (run when CONDITION is false) and once with FLAG = true (run
// when CONDITION is true). CONDITION itself is evaluated exactly once.
#define TRANSFORMER_ENGINE_SWITCH_CONDITION(CONDITION, FLAG, ...) \
  if (!(CONDITION)) {                                             \
    constexpr bool FLAG = false;                                  \
    { __VA_ARGS__ }                                               \
  } else {                                                        \
    constexpr bool FLAG = true;                                   \
    { __VA_ARGS__ }                                               \
  }

////////////////////////////////////////////////////////////////////////////////////////////////////

/*! \brief Smallest non-negative integer k such that (1 << k) >= value.
 *
 * Returns 0 for any value <= 1. The probe is shifted in 64-bit arithmetic: with a 32-bit
 * `int`, inputs above 2^30 would drive the shift count to 32 (undefined behaviour and a
 * potential endless loop); here they terminate with 31, which covers every value up to INT_MAX.
 */
inline int log2_ceil(int value) {
  int log2_value = 0;
  while ((1LL << log2_value) < value) ++log2_value;
  return log2_value;
}

/*! \brief Round x up to the nearest multiple of B (x itself when already a multiple).
 *
 * B must be non-zero; this is enforced at compile time because `x % 0` is undefined.
 * The arithmetic is unsigned, so the result wraps if x lies within B of SIZE_MAX.
 * Usable in constant expressions.
 */
template <size_t B>
inline constexpr size_t alignTo(size_t x) {
  static_assert(B > 0, "alignTo: alignment B must be non-zero");
  return (x % B == 0) ? x : x + (B - x % B);
}

/*! \brief Type trait: true exactly for the FP8 element types fp8e4m3 and fp8e5m2,
 *  false for every other type. Derives from std::true_type / std::false_type.
 */
template <typename T>
struct is_fp8
    : std::integral_constant<bool, std::is_same<T, fp8e4m3>::value ||
                                       std::is_same<T, fp8e5m2>::value> {};

/*! \brief Type trait: true exactly for the `int8` element type, false otherwise.
 *  Derives from std::true_type / std::false_type.
 */
template <typename T>
struct is_int8 : std::integral_constant<bool, std::is_same<T, int8>::value> {};

/*! \brief Type trait: true exactly for fp4e2m1 when the toolkit provides FP4
 *  (FP4_TYPE_SUPPORTED); without FP4 support it is false for every type.
 */
#if FP4_TYPE_SUPPORTED
template <typename T>
struct is_fp4 : std::integral_constant<bool, std::is_same<T, fp4e2m1>::value> {};
#else
template <typename T>
struct is_fp4 : std::false_type {};
#endif

// Alignment requirements for tensors holding scaling factors, written as [Y, X]:
// [128, 4] for the rowwise layout and [4, 128] for the colwise layout.
constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
constexpr size_t scale_tensor_alignment_X_rowwise = 4;
constexpr size_t scale_tensor_alignment_Y_colwise = 4;
constexpr size_t scale_tensor_alignment_X_colwise = 128;

// Address-alignment requirements of the Tensor Memory Accelerator (TMA).
constexpr size_t TMA_GMEM_ALIGNMENT = 16;    // global memory address alignment
constexpr size_t TMA_SHMEM_ALIGNMENT = 128;  // shared memory address alignment

/*! \brief True when the numeric address of `ptr` is a multiple of `alignment`.
 *  `alignment` must be non-zero (it is used as a modulus).
 */
inline bool is_aligned_ptr(const void *ptr, size_t alignment) {
  const uintptr_t address = reinterpret_cast<uintptr_t>(ptr);
  const uintptr_t remainder = address % alignment;
  return remainder == 0;
}

/*! \brief True when the tensor's data pointer (`t.data.dptr`) is aligned to `alignment`;
 *  delegates to is_aligned_ptr.
 */
inline bool is_aligned_tensor_data(const Tensor &t, size_t alignment) {
  const void *const data_ptr = static_cast<const void *>(t.data.dptr);
  return is_aligned_ptr(data_ptr, alignment);
}

// Per-element size queries for a DType (presumably bytes and bits respectively, per the
// names — confirm against the definitions).
size_t typeToSize(const DType type);
size_t typeToNumBits(const DType type);

// Buffer-size helpers: one overload takes an element count N, the other a 2-D extent
// (dim_first, dim_last); the result is presumably a byte count, per the name.
size_t get_buffer_size_bytes(const size_t N, const DType buffer_dtype);
size_t get_buffer_size_bytes(const size_t dim_first, const size_t dim_last,
                             const DType buffer_dtype);

// Tensor validation helpers; `name` presumably identifies the tensor in error messages.
// CheckOutputTensor's allow_empty defaults to false.
void CheckNoopTensor(const Tensor &t, const std::string &name);
void CheckInputTensor(const Tensor &t, const std::string &name);
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false);

/*! \brief Update a tensor's FP8 scale-inverse.
 *
 * Sets the FP8 scale-inverse (dequantization scaling factor) to the reciprocal of the
 * FP8 scale (quantization scaling factor). `stream` is presumably the CUDA stream the
 * update is enqueued on — confirm against the definition.
 */
void update_tensor_scale_inv(Tensor *t, cudaStream_t stream);

// Declares a local transformer_engine::nvtx::NVTXWrapper named `_<api_name>_nvtx_wrapper`,
// constructed with the stringized api_name. NOTE(review): presumably this opens an NVTX range
// that lasts until the end of the enclosing scope — confirm against nvtx.h.
#define NVTE_API_CALL(api_name) \
  transformer_engine::nvtx::NVTXWrapper _##api_name##_nvtx_wrapper(#api_name);

// NOTE(review): presumably ensures a CUDA driver context is usable for `stream` — confirm.
void checkCuDriverContext(CUstream stream);

#ifndef __HIP_PLATFORM_AMD__
// TMA descriptor helpers; compiled out on AMD/HIP platforms.

// Map a DType to the corresponding CUtensorMapDataType.
CUtensorMapDataType get_CUtensorMapDataType(DType dtype);

// Set up parameters to create a 2-D TMA descriptor (written into `tensorMap`) for `tensor`.
// globalY x globalX presumably describe the full extent and shmemY x shmemX the shared-memory
// tile — confirm against the definition. `swizzle` defaults to no swizzling.
void create_2D_tensor_map(
    CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY,
    const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX,
    const uint32_t stride_elems, const uint32_t offset_elems, const size_t type_num_bits,
    const CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
#endif

// NOTE(review): judging by the name, reports whether the device is compute capability 10.0 —
// confirm against the definition.
bool is_supported_by_CC_100();

// Convert an array of NVTETensor handles (indexed by outer_size and inner_size) into nested
// vectors of internal Tensor pointers.
std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensors,
                                                        size_t outer_size, size_t inner_size);

// Map an opaque NVTETensor handle to the internal Tensor. The *Check variant presumably
// validates the handle as well — confirm against the definitions.
Tensor *convertNVTETensor(const NVTETensor tensor);
Tensor *convertNVTETensorCheck(const NVTETensor tensor);

}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_COMMON_H_