/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_COMMON_H_
#define TRANSFORMER_ENGINE_COMMON_COMMON_H_
9
#include "util/system.h"
yuguo's avatar
yuguo committed
10
#ifndef __HIP_PLATFORM_AMD__
11
#include <cudaTypedefs.h>
wenjh's avatar
wenjh committed
12
13
#else
#define CUDA_VERSION 0
yuguo's avatar
yuguo committed
14
#endif
15

16
#define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)
17
18
19
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
20
21
22
23
#if FP4_TYPE_SUPPORTED
#include <cuda_fp4.h>
#endif

24
25
26
#include <cuda_runtime_api.h>
#include <transformer_engine/transformer_engine.h>

27
#include <cstdint>
#include <functional>
#include <limits>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <vector>
Tim Moon's avatar
Tim Moon committed
35
36

#include "./nvtx.h"
37
#include "./util/cuda_driver.h"
Tim Moon's avatar
Tim Moon committed
38
#include "./util/logging.h"
Przemek Tredak's avatar
Przemek Tredak committed
39
40
41

namespace transformer_engine {

42
43
44
// String forms of a DType and of an NVTE scaling mode (declared here, defined out of
// line). The scaling-mode overload is used in this header's error messages.
std::string to_string(const DType type);
std::string to_string(const NVTEScalingMode &mode);

45
46
47
48
/*! Block length for blockwise FP8 scaling, as configured by the
 *  NVTE_BLOCKWISE_FP8_BLOCK_LEN environment variable.
 *
 *  The variable is re-read on every call (nothing is cached). 128 is passed as
 *  the fallback argument to getenv.
 */
inline int blockwise_fp8_block_len() {
  constexpr int kDefaultBlockLen = 128;
  return ::transformer_engine::getenv<int>("NVTE_BLOCKWISE_FP8_BLOCK_LEN", kDefaultBlockLen);
}

49
50
51
52
53
54
55
56
57
58
// Scaling-mode predicates.
//
// "Tensor" scaling means delayed per-tensor scaling only. Every other mode
// (MXFP8, NVFP4, 1D/2D block scaling, ...) is classified as block scaling.
inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) {
  return NVTE_DELAYED_TENSOR_SCALING == mode;
}

inline bool is_tensor_scaling(const NVTEScalingMode &mode) {
  return is_delayed_tensor_scaling(mode);
}

inline bool is_block_scaling(const NVTEScalingMode &mode) { return !is_tensor_scaling(mode); }

inline bool is_nvfp4_scaling(const NVTEScalingMode &mode) { return NVTE_NVFP4_1D_SCALING == mode; }

inline bool is_mxfp8_scaling(const NVTEScalingMode &mode) { return NVTE_MXFP8_1D_SCALING == mode; }

// Shorter spellings of is_mxfp8_scaling / is_nvfp4_scaling.
inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return is_mxfp8_scaling(mode); }

inline bool is_nvfp_scaling(const NVTEScalingMode &mode) { return is_nvfp4_scaling(mode); }

67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
/*! Product of shape[begin], ..., shape[end - 1] (1 for an empty range).
 *
 *  Requires begin <= end <= shape.size(); otherwise NVTE_CHECK fails.
 */
inline size_t product(const std::vector<size_t> &shape, const size_t begin, const size_t end) {
  NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ",
             end, " in a vector with ", shape.size(), " entries");
  size_t total = 1;
  size_t idx = begin;
  while (idx < end) {
    total *= shape[idx++];
  }
  return total;
}

/*! Product of every entry of `shape` (1 for an empty shape). */
inline size_t product(const std::vector<size_t> &shape) {
  size_t total = 1;
  for (auto it = shape.cbegin(); it != shape.cend(); ++it) {
    total *= *it;
  }
  return total;
}

85
86
87
88
89
/*! Lightweight buffer descriptor: raw pointer, shape and element type.
 *
 *  It never frees `dptr`: there is no destructor, and clear() only nulls the
 *  pointer. The pointer presumably refers to device memory — confirm with callers.
 */
struct SimpleTensor {
  void *dptr;
  std::vector<size_t> shape;
  DType dtype;

  SimpleTensor(void *dptr, const std::vector<size_t> &shape, DType dtype)
      : dptr(dptr), shape(shape), dtype(dtype) {}

  // Implicit conversion from the C API descriptor; copies the first `ndim` shape entries.
  SimpleTensor(const NVTEBasicTensor &tensor)  // NOLINT
      : dptr(tensor.data_ptr),
        shape(tensor.shape.data, tensor.shape.data + tensor.shape.ndim),
        dtype(static_cast<DType>(tensor.dtype)) {}

  // Empty descriptor: null pointer, shape [], FP32.
  SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {}

  // Conversion back to the C API descriptor.
  operator NVTEBasicTensor() const {
    const auto nvte_shape = nvte_make_shape(shape.data(), shape.size());
    return {dptr, static_cast<NVTEDType>(dtype), nvte_shape};
  }

  // Number of elements described by `shape` (1 for shape []).
  size_t numel() const {
    size_t count = 1;
    for (size_t i = 0; i < shape.size(); ++i) {
      count *= shape[i];
    }
    return count;
  }

  // Reset to the default-constructed state (pointer dropped, not freed).
  void clear() {
    dptr = nullptr;
    shape.clear();
    dtype = DType::kFloat32;
  }
};
Przemek Tredak's avatar
Przemek Tredak committed
119

120
struct Tensor {
 public:
  // Row-wise and column-wise payload buffers; either may lack storage (null dptr).
  SimpleTensor data;
  SimpleTensor columnwise_data;
  // Auxiliary buffers. After construction each is a storage-less FP32 descriptor
  // of shape [1].
  SimpleTensor amax;
  SimpleTensor columnwise_amax;
  SimpleTensor scale;
  SimpleTensor scale_inv;
  SimpleTensor columnwise_scale_inv;

  // Recipe tag; selects how shape() interprets data / columnwise_data.
  NVTEScalingMode scaling_mode;
  // Handle returned by the explicit NVTETensor conversion below.
  NVTETensor nvte_tensor;

  Tensor()
      : data(),
        columnwise_data(),
        amax(nullptr, {1}, DType::kFloat32),
        columnwise_amax(nullptr, {1}, DType::kFloat32),
        scale(nullptr, {1}, DType::kFloat32),
        scale_inv(nullptr, {1}, DType::kFloat32),
        columnwise_scale_inv(nullptr, {1}, DType::kFloat32),
        scaling_mode(NVTE_DELAYED_TENSOR_SCALING),
        nvte_tensor(nullptr) {}

  // Reset every buffer descriptor and the scaling mode.
  // nvte_tensor is left untouched. The auxiliary buffers end up with
  // shape [] (SimpleTensor::clear), not the [1] set by the constructor.
  void clear() {
    for (SimpleTensor *buffer : {&data, &columnwise_data, &amax, &columnwise_amax, &scale,
                                 &scale_inv, &columnwise_scale_inv}) {
      buffer->clear();
    }
    scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
  }

  explicit operator NVTETensor() const noexcept { return nvte_tensor; }

  // Number of elements implied by shape().
  size_t numel() const {
    const std::vector<size_t> dims = shape();
    size_t count = 1;
    for (size_t i = 0; i < dims.size(); ++i) {
      count *= dims[i];
    }
    return count;
  }

  bool has_data() const noexcept { return nullptr != data.dptr; }

  // A non-empty shape also counts, not just a non-null pointer, so that 0-dim
  // and no-token cases are still detected.
  bool has_columnwise_data() const noexcept {
    return nullptr != columnwise_data.dptr || !columnwise_data.shape.empty();
  }

  // Element type. Uses the column-wise dtype only when the column-wise buffer is
  // the sole one present. Otherwise uses data.dtype, which is also the fallback
  // (used e.g. for workspaces).
  DType dtype() const {
    const bool columnwise_only = !has_data() && has_columnwise_data();
    return columnwise_only ? columnwise_data.dtype : data.dtype;
  }

  // Rank of the stored shape, taken from the same buffer dtype() uses.
  size_t dim() const {
    const bool columnwise_only = !has_data() && has_columnwise_data();
    return columnwise_only ? columnwise_data.shape.size() : data.shape.size();
  }

  /*! Logical shape of the tensor.
   *
   * For delayed tensor scaling, NVFP4 and 1D/2D block scaling, column-wise data
   * is transposed, so its shape is rotated back:
   * [d0, d1, ..., dn] -> [d1, ..., dn, d0].
   * For MXFP8, the column-wise shape is returned unchanged.
   * Any other scaling mode raises NVTE_ERROR.
   *
   * Note: We sometimes experience spurious compiler errors
   * (-Wstringop-overflow) from this function. It appears that GCC
   * has some bugs with std::vector (see
   * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569).
   */
  std::vector<size_t> shape() const {
    switch (scaling_mode) {
      case NVTE_DELAYED_TENSOR_SCALING:
      case NVTE_NVFP4_1D_SCALING: {
        // Buffer choice, in priority order:
        //   1. row-wise data with storage;
        //   2. column-wise data with storage;
        //   3. row-wise data with a non-empty shape;
        //   4. column-wise data with a non-empty shape;
        //   5. otherwise, row-wise.
        // Note: Uninitialized buffers currently have shape=[].
        // However, this is logically incorrect. 0-D tensors have 1
        // entry, and uninitialized tensors should have shape=[0].
        const bool use_columnwise_shape =
            data.dptr == nullptr &&
            (columnwise_data.dptr != nullptr ||
             (data.shape.empty() && !columnwise_data.shape.empty()));
        if (use_columnwise_shape) {
          return untransposed_columnwise_shape();
        }
        return data.shape;
      }
      case NVTE_MXFP8_1D_SCALING:
        if (!has_data() && has_columnwise_data()) {
          return columnwise_data.shape;
        }
        return data.shape;
      case NVTE_BLOCK_SCALING_1D:
      case NVTE_BLOCK_SCALING_2D:
        if (!has_data() && has_columnwise_data()) {
          return untransposed_columnwise_shape();
        }
        // NOTE: We may have removed the data pointer from
        // data by setting usage. In that case, we return
        // the non-null shape. It is our best guess at the most
        // recent shape.
        return data.shape;
      default:
        NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode), "\"");
        return {};
    }
  }

  /*! Matrix height after tensor is flattened to 2D
   *
   * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
   * as a (D1*D2*...*D(n-1), Dn) matrix.
   */
  size_t flat_first_dim() const {
    const std::vector<size_t> full_shape = shape();
    if (full_shape.empty()) {
      return 1;
    }
    size_t rows = 1;
    for (size_t i = 0; i + 1 < full_shape.size(); ++i) {
      rows *= full_shape[i];
    }
    return rows;
  }

  /*! Matrix width after tensor is flattened to 2D
   *
   * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
   * as a (D1*D2*...*D(n-1), Dn) matrix.
   */
  size_t flat_last_dim() const {
    const std::vector<size_t> full_shape = shape();
    return full_shape.empty() ? size_t{1} : full_shape.back();
  }

 private:
  // columnwise_data.shape with its leading extent moved to the back:
  // [d0, d1, ..., dn] -> [d1, ..., dn, d0]. An empty shape stays empty.
  std::vector<size_t> untransposed_columnwise_shape() const {
    const std::vector<size_t> &stored = columnwise_data.shape;
    std::vector<size_t> logical;
    if (!stored.empty()) {
      logical.reserve(stored.size());
      for (size_t i = 1; i < stored.size(); ++i) {
        logical.push_back(stored[i]);
      }
      logical.push_back(stored.front());
    }
    return logical;
  }
};

292
293
294
/*! Optional parameters for quantization routines.
 *
 * attr_sizes[i] is the byte size of the i-th attribute, with entries listed in
 * the same order as the fields below. It is presumably used to validate
 * attribute get/set buffers — TODO confirm against the C API implementation.
 */
struct QuantizationConfig {
  bool force_pow_2_scales{false};
  float amax_epsilon{0.0f};
  NVTETensor noop_tensor{nullptr};
  Float8BlockScaleTensorFormat float8_block_scale_tensor_format{
      Float8BlockScaleTensorFormat::GEMM_READY};
  NVTETensor rng_state{nullptr};
  bool nvfp4_2d_quantization{false};
  bool stochastic_rounding{false};

  static constexpr size_t attr_sizes[] = {
      sizeof(bool),                          // force_pow_2_scales
      sizeof(float),                         // amax_epsilon
      sizeof(NVTETensor),                    // noop_tensor
      sizeof(Float8BlockScaleTensorFormat),  // float8_block_scale_tensor_format
      sizeof(NVTETensor),                    // rng_state (RNG seed and offset)
      sizeof(bool),                          // nvfp4_2d_quantization
      sizeof(bool)                           // stochastic_rounding
  };
};

313
314
// Maps a Transformer Engine DType to a cudaDataType_t (declared here, defined out of line).
cudaDataType_t get_cuda_dtype(const transformer_engine::DType t);

Przemek Tredak's avatar
Przemek Tredak committed
315
316
/*! Ceiling division: ceil(x / y), computed as (x + (y - 1)) / y.
 *
 *  Intended for non-negative x and positive y.
 */
template <typename T>
constexpr T DIVUP(const T &x, const T &y) {
  return (x + (y - 1)) / y;
}

320
321
322
323
324
325
326
/*! Round N up to the nearest multiple of M, computed in uint64_t.
 *
 *  Both arguments must be integral (checked at compile time).
 */
template <typename T1, typename T2>
constexpr __device__ __host__ __forceinline__ uint64_t DIVUP_TO_MULTIPLE(const T1 &N, const T2 &M) {
  static_assert(std::is_integral<T1>::value && std::is_integral<T2>::value,
                "Integral type required.");
  return static_cast<uint64_t>(M) * DIVUP(static_cast<uint64_t>(N), static_cast<uint64_t>(M));
}

Przemek Tredak's avatar
Przemek Tredak committed
327
// Short names for the element types used throughout Transformer Engine.
typedef uint8_t byte;
typedef int16_t int16;
typedef int32_t int32;
typedef int64_t int64;
typedef float fp32;
typedef half fp16;
typedef int8_t int8;
typedef nv_bfloat16 bf16;
typedef __nv_fp8_e4m3 fp8e4m3;
typedef __nv_fp8_e5m2 fp8e5m2;
#if CUDA_VERSION >= 12080
typedef __nv_fp8_e8m0 fp8e8m0;
#endif
#if FP4_TYPE_SUPPORTED
typedef __nv_fp4_e2m1 fp4e2m1;
typedef __nv_fp4x2_e2m1 fp4e2m1x2;
typedef __nv_fp4x4_e2m1 fp4e2m1x4;
#endif
// Plain 8-bit integer storage, presumably for E8M0 scale factors — confirm with users.
typedef uint8_t e8m0_t;
Przemek Tredak's avatar
Przemek Tredak committed
346

Tim Moon's avatar
Tim Moon committed
347
348
349
350
namespace detail {

/*! Printable name of an element type, e.g. type_name<float>() == "float".
 *
 * Only declared here. The explicit specializations generated by
 * TRANSFORMER_ENGINE_TYPE_NAME below are the only definitions, so using an
 * unlisted type fails to build.
 */
template <typename T>
constexpr inline const char *type_name() noexcept;

// Defines type_name<T>() to return the spelling of T as written in the macro call.
#define TRANSFORMER_ENGINE_TYPE_NAME(T)                  \
  template <>                                            \
  inline constexpr const char *type_name<T>() noexcept { \
    return #T;                                           \
  }
TRANSFORMER_ENGINE_TYPE_NAME(uint8_t)
TRANSFORMER_ENGINE_TYPE_NAME(int16_t)
TRANSFORMER_ENGINE_TYPE_NAME(int32_t)
TRANSFORMER_ENGINE_TYPE_NAME(int64_t)
TRANSFORMER_ENGINE_TYPE_NAME(float)
TRANSFORMER_ENGINE_TYPE_NAME(half)
TRANSFORMER_ENGINE_TYPE_NAME(nv_bfloat16)
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e4m3)
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2)
TRANSFORMER_ENGINE_TYPE_NAME(int8_t)
#if CUDA_VERSION >= 12080
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0)
#endif
#if FP4_TYPE_SUPPORTED
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp4_e2m1)
#endif
#undef TRANSFORMER_ENGINE_TYPE_NAME

/*! Largest finite value of an element type, stored as float.
 *
 * The quantized types (FP4, FP8, INT8) also expose max_inverse = 1 / max.
 * Types without an explicit specialization fall back to
 * std::numeric_limits<T>::max() (primary template at the end; needs <limits>).
 */
template <typename T>
struct TypeExtrema;

#if FP4_TYPE_SUPPORTED
template <>
struct TypeExtrema<fp4e2m1> {
  static constexpr float max = 6.0f;
  static constexpr float max_inverse = 1.0 / max;
};
#endif

template <>
struct TypeExtrema<fp8e4m3> {
  static constexpr float max = 448.0f;
  static constexpr float max_inverse = 1.0 / max;
};

template <>
struct TypeExtrema<int8> {
  static constexpr float max = 127.0f;
  // Provided for parity with the other 8-bit quantized types (fp8e4m3 / fp8e5m2).
  static constexpr float max_inverse = 1.0 / max;
};

template <>
struct TypeExtrema<fp8e5m2> {
  static constexpr float max = 57344.0f;
  static constexpr float max_inverse = 1.0 / max;
};

template <>
struct TypeExtrema<bf16> {
  // Hex float format of 1.(7 bits of 1) * 2 ^ 127
  static constexpr float max = 0x1.FEp127;
};

template <>
struct TypeExtrema<fp16> {
  // Hex float format of 1.(10 bits of 1) * 2 ^ 15
  static constexpr float max = 0x1.FFCp15;
};

template <typename T>
struct TypeExtrema {
  static constexpr float max = std::numeric_limits<T>::max();
};

}  // namespace detail
Przemek Tredak's avatar
Przemek Tredak committed
420

421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
/*! Storage width of one element of type T, in bits.
 *
 *  Defaults to 8 * sizeof(T). fp4e2m1 is special-cased to 4 bits when FP4
 *  types are available.
 */
template <typename T>
struct BitsNumber {
  static constexpr size_t num_bits = sizeof(T) * 8;
};

#if FP4_TYPE_SUPPORTED
template <>
struct BitsNumber<fp4e2m1> {
  static constexpr size_t num_bits = 4;
};
#endif

Przemek Tredak's avatar
Przemek Tredak committed
436
/*! Compile-time traits of element type T.
 *
 *  `types` lists the element types in DType order: the entry at tuple index i
 *  is the C++ type of the DType whose integer value is i. `dtype` is found by a
 *  compile-time linear search of that list and is DType::kNumTypes when T is
 *  absent. As written, every index below kNumTypes is instantiated during the
 *  search.
 */
template <typename T>
struct TypeInfo {
  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8
#if CUDA_VERSION >= 12080
                           ,
                           fp8e8m0
#endif
#if FP4_TYPE_SUPPORTED
                           ,
                           fp4e2m1
#endif
                           >;

  // Compares U with the tuple entry at index `current`; on mismatch, moves on to
  // the next DType. The kNumTypes specialization below terminates the search.
  template <typename U, DType current>
  struct Helper {
    constexpr static DType getType() {
      return std::is_same<U, typename std::tuple_element<static_cast<int>(current),
                                                         types>::type>::value
                 ? current
                 : Helper<U, static_cast<DType>(static_cast<int>(current) + 1)>::getType();
    }
  };

  template <typename U>
  struct Helper<U, DType::kNumTypes> {
    constexpr static DType getType() { return DType::kNumTypes; }
  };

  template <typename U>
  constexpr static DType getType() {
    return Helper<U, DType::kByte>::getType();
  }

  constexpr static DType dtype = getType<T>();
  // Element width in bits (see BitsNumber).
  constexpr static size_t size = BitsNumber<T>::num_bits;
  constexpr static float max_finite_value = detail::TypeExtrema<T>::max;
  constexpr static const char *name = detail::type_name<T>();
};

484
485
486
487
488
489
490
491
492
493
// Optional `case` for DType::kFloat4E2M1, spliced into type-switch macros such as
// TRANSFORMER_ENGINE_TYPE_SWITCH_ALL.
// When FP4 types are available, it binds `type` to fp4e2m1. Otherwise it expands
// to nothing, so FP4 dtypes reach the enclosing switch's `default:` label.
#if FP4_TYPE_SUPPORTED
#define SWITCH_FP4_TYPE_HANDLE(type, ...) \
  case DType::kFloat4E2M1: {              \
    using type = fp4e2m1;                 \
    { __VA_ARGS__ }                       \
  } break;
#else
#define SWITCH_FP4_TYPE_HANDLE(type, ...)  // expands to nothing: FP4 unsupported
#endif

Przemek Tredak's avatar
Przemek Tredak committed
494
// Dispatch on any element DType: binds `type` to the matching C++ type, then
// expands the trailing code inside that case.
// Notes:
//  - DType::kFloat8E8M0 is dispatched as `byte` (raw 8-bit storage).
//  - DType::kInt8 is dispatched as `int8`, matching TypeInfo::types and the other
//    INT8-aware switches.
//  - The FP4 case is present only when FP4_TYPE_SUPPORTED.
// Any other dtype raises NVTE_ERROR("Invalid type.").
#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \
  switch (dtype) {                                           \
    using namespace transformer_engine;                      \
    case DType::kByte: {                                     \
      using type = unsigned char;                            \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kInt8: {                                     \
      using type = int8;                                     \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kInt16: {                                    \
      using type = int16_t;                                  \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kInt32: {                                    \
      using type = int32_t;                                  \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kInt64: {                                    \
      using type = int64_t;                                  \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kFloat32: {                                  \
      using type = float;                                    \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kFloat16: {                                  \
      using type = fp16;                                     \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kBFloat16: {                                 \
      using type = bf16;                                     \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kFloat8E4M3: {                               \
      using type = fp8e4m3;                                  \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kFloat8E5M2: {                               \
      using type = fp8e5m2;                                  \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kFloat8E8M0: {                               \
      using type = byte;                                     \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
      SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__)              \
    default:                                                 \
      NVTE_ERROR("Invalid type.");                           \
  }
Przemek Tredak's avatar
Przemek Tredak committed
541

542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
// Dispatch on a floating-point DType: FP8 (E4M3, E5M2), BF16, FP16 or FP32.
// Binds `type` to the matching C++ type for the trailing code.
// Any other dtype raises NVTE_ERROR("Invalid type.").
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(dtype, type, ...) \
  switch (dtype) {                                             \
    using namespace transformer_engine;                        \
    case DType::kFloat8E4M3: {                                 \
      using type = fp8e4m3;                                    \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kFloat8E5M2: {                                 \
      using type = fp8e5m2;                                    \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kBFloat16: {                                   \
      using type = bf16;                                       \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kFloat16: {                                    \
      using type = fp16;                                       \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kFloat32: {                                    \
      using type = float;                                      \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    default:                                                   \
      NVTE_ERROR("Invalid type.");                             \
  }

Przemek Tredak's avatar
Przemek Tredak committed
569
// Dispatch on an output DType: BF16, FP16, FP32 or FP8 (E4M3, E5M2).
// Binds `type` to the matching C++ type for the trailing code.
// Any other dtype raises NVTE_ERROR("Invalid type.").
#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) \
  switch (dtype) {                                              \
    using namespace transformer_engine;                         \
    case DType::kBFloat16: {                                    \
      using type = bf16;                                        \
      { __VA_ARGS__ }                                           \
    } break;                                                    \
    case DType::kFloat16: {                                     \
      using type = fp16;                                        \
      { __VA_ARGS__ }                                           \
    } break;                                                    \
    case DType::kFloat32: {                                     \
      using type = float;                                       \
      { __VA_ARGS__ }                                           \
    } break;                                                    \
    case DType::kFloat8E4M3: {                                  \
      using type = fp8e4m3;                                     \
      { __VA_ARGS__ }                                           \
    } break;                                                    \
    case DType::kFloat8E5M2: {                                  \
      using type = fp8e5m2;                                     \
      { __VA_ARGS__ }                                           \
    } break;                                                    \
    default:                                                    \
      NVTE_ERROR("Invalid type.");                              \
  }
Przemek Tredak's avatar
Przemek Tredak committed
595

596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
// Dispatch on an output DType that may be INT8: INT8, FP8 (E4M3, E5M2), BF16,
// FP16 or FP32. Binds `type` to the matching C++ type for the trailing code.
// Any other dtype raises NVTE_ERROR("Invalid type.").
#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT_WITH_INT8(dtype, type, ...) \
  switch (dtype) {                                                        \
    using namespace transformer_engine;                                   \
    case DType::kInt8: {                                                  \
      using type = int8;                                                  \
      { __VA_ARGS__ }                                                     \
    } break;                                                              \
    case DType::kFloat8E4M3: {                                            \
      using type = fp8e4m3;                                               \
      { __VA_ARGS__ }                                                     \
    } break;                                                              \
    case DType::kFloat8E5M2: {                                            \
      using type = fp8e5m2;                                               \
      { __VA_ARGS__ }                                                     \
    } break;                                                              \
    case DType::kBFloat16: {                                              \
      using type = bf16;                                                  \
      { __VA_ARGS__ }                                                     \
    } break;                                                              \
    case DType::kFloat16: {                                               \
      using type = fp16;                                                  \
      { __VA_ARGS__ }                                                     \
    } break;                                                              \
    case DType::kFloat32: {                                               \
      using type = float;                                                 \
      { __VA_ARGS__ }                                                     \
    } break;                                                              \
    default:                                                              \
      NVTE_ERROR("Invalid type.");                                        \
  }

627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
// Dispatch on a non-FP8 floating-point DType: BF16, FP16 or FP32.
// Any other dtype (including FP8) raises NVTE_ERROR("Invalid type.").
#define TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, type, ...) \
  switch (dtype) {                                                   \
    using namespace transformer_engine;                              \
    case DType::kBFloat16: {                                         \
      using type = bf16;                                             \
      { __VA_ARGS__ }                                                \
    } break;                                                         \
    case DType::kFloat16: {                                          \
      using type = fp16;                                             \
      { __VA_ARGS__ }                                                \
    } break;                                                         \
    case DType::kFloat32: {                                          \
      using type = float;                                            \
      { __VA_ARGS__ }                                                \
    } break;                                                         \
    default:                                                         \
      NVTE_ERROR("Invalid type.");                                   \
  }

646
647
648
649
650
651
652
653
654
655
656
657
// Dispatch for packed FP4 data.
// DType::kFloat4E2M1 binds `type` to __nv_fp4x2_storage_t. Any other dtype raises
// NVTE_ERROR("Invalid type.").
// NOTE(review): `pack_size` is accepted but currently unused by the expansion; the
// storage type is always the x2 variant. __nv_fp4x2_storage_t is presumably declared
// by <cuda_fp4.h>, which this header includes only when FP4_TYPE_SUPPORTED — confirm
// before expanding this macro in builds without FP4 support.
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY(dtype, pack_size, type, ...) \
  switch (dtype) {                                                             \
    using namespace transformer_engine;                                        \
    case DType::kFloat4E2M1: {                                                 \
      using type = __nv_fp4x2_storage_t;                                       \
      { __VA_ARGS__ }                                                          \
    } break;                                                                   \
    default:                                                                   \
      NVTE_ERROR("Invalid type.");                                             \
  }

Przemek Tredak's avatar
Przemek Tredak committed
658
// Dispatch on an FP8 DType only (E4M3 or E5M2).
// Any other dtype raises NVTE_ERROR("Invalid type.").
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \
  switch (dtype) {                                               \
    using namespace transformer_engine;                          \
    case DType::kFloat8E4M3: {                                   \
      using type = fp8e4m3;                                      \
      { __VA_ARGS__ }                                            \
    } break;                                                     \
    case DType::kFloat8E5M2: {                                   \
      using type = fp8e5m2;                                      \
      { __VA_ARGS__ }                                            \
    } break;                                                     \
    default:                                                     \
      NVTE_ERROR("Invalid type.");                               \
  }
Przemek Tredak's avatar
Przemek Tredak committed
672

yuguo's avatar
yuguo committed
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
// Dispatch on an 8-bit quantized DType: INT8 or FP8 (E4M3, E5M2).
// Any other dtype raises NVTE_ERROR("Invalid type.").
#define TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(dtype, type, ...)    \
  switch (dtype) {                                               \
    using namespace transformer_engine;                          \
    case DType::kInt8: {                                         \
      using type = int8;                                         \
      { __VA_ARGS__ }                                            \
    } break;                                                     \
    case DType::kFloat8E4M3: {                                   \
      using type = fp8e4m3;                                      \
      { __VA_ARGS__ }                                            \
    } break;                                                     \
    case DType::kFloat8E5M2: {                                   \
      using type = fp8e5m2;                                      \
      { __VA_ARGS__ }                                            \
    } break;                                                     \
    default:                                                     \
      NVTE_ERROR("Invalid type.");                               \
  }

692
#if FP4_TYPE_SUPPORTED
Przemek Tredak's avatar
Przemek Tredak committed
693
// Runtime-to-compile-time dispatch over the supported *input* element types
// (FP4-capable build). `type` is aliased to float, fp16 or bf16 and the
// trailing code runs in its own scope with that alias visible.
// INT8, FP8 and FP4 dtypes are recognised but not instantiated for inputs:
// each raises an NVTE_ERROR naming its own type family. Any other dtype
// raises NVTE_ERROR("Invalid type.").
#define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \
  switch (dtype) {                                             \
    using namespace transformer_engine;                        \
    case DType::kFloat32: {                                    \
      using type = float;                                      \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kFloat16: {                                    \
      using type = fp16;                                       \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kBFloat16: {                                   \
      using type = bf16;                                       \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kInt8: {                                       \
      NVTE_ERROR("INT8 type not instantiated for input.");     \
    } break;                                                   \
    case DType::kFloat8E5M2:                                   \
    case DType::kFloat8E4M3: {                                 \
      NVTE_ERROR("FP8 type not instantiated for input.");      \
    } break;                                                   \
    case DType::kFloat4E2M1: {                                 \
      NVTE_ERROR("FP4 type not instantiated for input.");      \
    } break;                                                   \
    default:                                                   \
      NVTE_ERROR("Invalid type.");                             \
  }
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
#else
// Runtime-to-compile-time dispatch over the supported *input* element types
// (build without FP4 types). `type` is aliased to float, fp16 or bf16 and the
// trailing code runs in its own scope with that alias visible.
// INT8 and FP8 dtypes are recognised but not instantiated for inputs: each
// raises an NVTE_ERROR naming its own type family. Any other dtype raises
// NVTE_ERROR("Invalid type.").
#define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \
  switch (dtype) {                                             \
    using namespace transformer_engine;                        \
    case DType::kFloat32: {                                    \
      using type = float;                                      \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kFloat16: {                                    \
      using type = fp16;                                       \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kBFloat16: {                                   \
      using type = bf16;                                       \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kInt8: {                                       \
      NVTE_ERROR("INT8 type not instantiated for input.");     \
    } break;                                                   \
    case DType::kFloat8E5M2:                                   \
    case DType::kFloat8E4M3: {                                 \
      NVTE_ERROR("FP8 type not instantiated for input.");      \
    } break;                                                   \
    default:                                                   \
      NVTE_ERROR("Invalid type.");                             \
  }
#endif
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760

// Runtime-to-compile-time dispatch over the 16-bit floating-point types.
// `type` is aliased to bf16 or fp16, and the trailing code is pasted as a
// statement (a `;` is appended, so callers may omit their own).
// Any other dtype raises NVTE_ERROR("Invalid type for 16 bit.").
#define TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(dtype, type, ...) \
  switch (dtype) {                                             \
    using namespace transformer_engine;                        \
    case DType::kBFloat16: {                                   \
      using type = bf16;                                       \
      __VA_ARGS__;                                             \
    } break;                                                   \
    case DType::kFloat16: {                                    \
      using type = fp16;                                       \
      __VA_ARGS__;                                             \
    } break;                                                   \
    default:                                                   \
      NVTE_ERROR("Invalid type for 16 bit.");                  \
  }
761

762
763
764
765
766
767
768
769
770
771
772
773
774
// Turn a runtime MX scaling-block size into a compile-time constant.
// For SCALE_DIM == 1 or 32, `DIM` is declared as a `constexpr size_t` with that
// value and the trailing code runs in its own scope. Any other value raises
// NVTE_ERROR("Invalid size of the MX scaling factor.").
#define TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(SCALE_DIM, DIM, ...) \
  switch (SCALE_DIM) {                                              \
    case 1: {                                                       \
      constexpr size_t DIM = 1;                                     \
      { __VA_ARGS__ }                                               \
    } break;                                                        \
    case 32: {                                                      \
      constexpr size_t DIM = 32;                                    \
      { __VA_ARGS__ }                                               \
    } break;                                                        \
    default: {                                                      \
      NVTE_ERROR("Invalid size of the MX scaling factor.");         \
    }                                                               \
  }
776

777
778
779
780
781
782
783
784
785
// Lift a runtime boolean into a compile-time one.
// CONDITION is evaluated once; the trailing code is instantiated twice, once
// with `constexpr bool FLAG = false` and once with `FLAG = true`, and the branch
// matching CONDITION is executed.
#define TRANSFORMER_ENGINE_SWITCH_CONDITION(CONDITION, FLAG, ...) \
  if (!(CONDITION)) {                                             \
    constexpr bool FLAG = false;                                  \
    { __VA_ARGS__ }                                               \
  } else {                                                        \
    constexpr bool FLAG = true;                                   \
    { __VA_ARGS__ }                                               \
  }

786
////////////////////////////////////////////////////////////////////////////////////////////////////
Przemek Tredak's avatar
Przemek Tredak committed
787

788
/*! \brief Smallest non-negative n such that 2^n >= value.
 *
 * Returns 0 for any value <= 1. The power of two is formed in 64-bit
 * arithmetic: with a 32-bit `1 << n`, inputs above 2^30 would require
 * `1 << 31`, which overflows a signed int (undefined behavior) and could
 * keep the loop from terminating.
 */
inline int log2_ceil(int value) {
  int log2_value = 0;
  while ((static_cast<int64_t>(1) << log2_value) < value) ++log2_value;
  return log2_value;
}

794
795
796
797
798
799
800
801
/*! \brief Round `x` up to the nearest multiple of `B`.
 *
 * Returns `x` unchanged when it is already a multiple of `B`.
 * `B` must be positive; this is enforced at compile time because `x % 0`
 * would be undefined behavior.
 */
template <size_t B>
inline size_t alignTo(size_t x) {
  static_assert(B > 0, "alignTo: alignment B must be positive");
  const size_t remainder = x % B;
  return remainder == 0 ? x : x + (B - remainder);
}

Przemek Tredak's avatar
Przemek Tredak committed
802
803
804
805
806
807
808
809
810
/*! \brief Type trait that is true exactly for the FP8 element types
 *         fp8e4m3 and fp8e5m2, and false for every other type.
 */
template <typename T>
struct is_fp8 : std::integral_constant<bool, std::is_same<T, fp8e4m3>::value ||
                                                 std::is_same<T, fp8e5m2>::value> {};

yuguo's avatar
yuguo committed
811
812
813
814
815
816
/*! \brief Type trait that is true exactly for the `int8` element type. */
template <typename T>
struct is_int8 : std::integral_constant<bool, std::is_same<T, int8>::value> {};

817
818
819
820
821
822
823
824
/*! \brief Type trait that is true exactly for the FP4 element type fp4e2m1.
 *
 * When the toolkit provides no FP4 type (FP4_TYPE_SUPPORTED is false) the
 * trait is false for every type.
 */
#if FP4_TYPE_SUPPORTED
template <typename T>
struct is_fp4 : std::integral_constant<bool, std::is_same<T, fp4e2m1>::value> {};
#else
template <typename T>
struct is_fp4 : std::false_type {};
#endif

825
826
827
828
829
830
// Alignment requirements for tensors holding scaling factors, given as [Y, X]:
// rowwise scales align to [128, 4], colwise scales align to [4, 128].
constexpr size_t scale_tensor_alignment_Y_rowwise{128};
constexpr size_t scale_tensor_alignment_X_rowwise{4};
constexpr size_t scale_tensor_alignment_Y_colwise{4};
constexpr size_t scale_tensor_alignment_X_colwise{128};

831
// Alignment requirements for the Tensor Memory Accelerator (TMA)
832
833
constexpr size_t TMA_GMEM_ALIGNMENT{16};    // required alignment of global-memory addresses (bytes)
constexpr size_t TMA_SHMEM_ALIGNMENT{128};  // required alignment of shared-memory addresses (bytes)
834
835
836
837
838
839
840
841
842

/*! \brief Whether the address held in `ptr` is a multiple of `alignment`.
 *
 * `alignment` must be non-zero; zero is not handled (it would be a modulo
 * by zero).
 */
inline bool is_aligned_ptr(const void *ptr, size_t alignment) {
  const uintptr_t address = reinterpret_cast<uintptr_t>(ptr);
  const uintptr_t misalignment = address % alignment;
  return misalignment == 0;
}

/*! \brief Whether the data pointer of `t` (`t.data.dptr`) is aligned to `alignment`. */
inline bool is_aligned_tensor_data(const Tensor &t, size_t alignment) {
  const void *const data_ptr = static_cast<const void *>(t.data.dptr);
  return is_aligned_ptr(data_ptr, alignment);
}

Przemek Tredak's avatar
Przemek Tredak committed
843
size_t typeToSize(const DType type);
844
845
846
847
848
size_t typeToNumBits(const DType type);

size_t get_buffer_size_bytes(const size_t N, const DType buffer_dtype);
size_t get_buffer_size_bytes(const size_t dim_first, const size_t dim_last,
                             const DType buffer_dtype);
Przemek Tredak's avatar
Przemek Tredak committed
849

850
// Validate tensor `t` used as a no-op indicator; `name` presumably labels it
// in error messages (checks are implemented in the definition).
void CheckNoopTensor(const Tensor &t,
                     const std::string &name);
851
852
853
// Validation helpers for operator inputs and outputs. `name` presumably labels
// the tensor in error messages; for outputs, `allow_empty` (default false)
// controls whether an empty tensor is accepted — see the definitions.
void CheckInputTensor(const Tensor &t,
                      const std::string &name);
void CheckOutputTensor(const Tensor &t,
                       const std::string &name,
                       bool allow_empty = false);

854
855
856
857
858
859
860
/*!
 * \brief Refresh the FP8 scale-inverse stored in a tensor.
 *
 * Sets the tensor's scale-inverse (the dequantization factor) to the
 * reciprocal of its FP8 scale (the quantization factor).
 *
 * \param[in,out] t       Tensor whose scale-inverse is rewritten.
 * \param[in]     stream  CUDA stream the update is presumably enqueued on.
 */
void update_tensor_scale_inv(Tensor *t, cudaStream_t stream);

861
// Declares a local transformer_engine::nvtx::NVTXWrapper named
// `_<api_name>_nvtx_wrapper`, constructed with the stringized API name.
// (Presumably an RAII NVTX range covering the rest of the enclosing scope —
// confirm in nvtx.h.)
#define NVTE_API_CALL(api_name) \
  transformer_engine::nvtx::NVTXWrapper _##api_name##_nvtx_wrapper(#api_name);
863

864
865
// Driver-context check for `stream` (behavior is defined elsewhere; from here
// only the signature is visible).
void checkCuDriverContext(CUstream stream);

yuguo's avatar
yuguo committed
866
#ifndef __HIP_PLATFORM_AMD__
867
868
869
// Translate a TE DType into the driver's CUtensorMapDataType enumerator
// (mapping defined elsewhere).
CUtensorMapDataType get_CUtensorMapDataType(DType dtype);

// Set up parameters to create TMA descriptor.
870
871
872
873
874
void create_2D_tensor_map(
    CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY,
    const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX,
    const uint32_t stride_elems, const uint32_t offset_elems, const size_t type_num_bits,
    const CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
yuguo's avatar
yuguo committed
875
#endif
876
877
878

// Presumably reports whether the current device has compute capability 10.0
// (SM100); see the definition.
bool is_supported_by_CC_100();

879
880
881
// Convert an outer_size x inner_size array of NVTETensor handles into nested
// vectors of internal Tensor pointers.
std::vector<std::vector<Tensor *>> convert_tensor_array(
    NVTETensor **nvte_tensors,
    size_t outer_size,
    size_t inner_size);

882
883
// Map a public NVTETensor handle to the internal Tensor. The `Check` variant
// presumably validates the handle first; see the definitions.
Tensor *convertNVTETensor(NVTETensor tensor);
Tensor *convertNVTETensorCheck(NVTETensor tensor);
Przemek Tredak's avatar
Przemek Tredak committed
884
885
886
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_COMMON_H_