/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_COMMON_H_
#define TRANSFORMER_ENGINE_COMMON_COMMON_H_
#include "util/system.h"
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#else
#define CUDA_VERSION 0
#endif

#define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#if FP4_TYPE_SUPPORTED
#include <cuda_fp4.h>
#endif

#include <cuda_runtime_api.h>
#include <transformer_engine/transformer_engine.h>

#include <cstdint>
#include <functional>
#include <limits>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include "./nvtx.h"
#include "./util/cuda_driver.h"
#include "./util/logging.h"

namespace transformer_engine {

std::string to_string(const DType type);
std::string to_string(const NVTEScalingMode &mode);

inline int blockwise_fp8_block_len() {
  return ::transformer_engine::getenv<int>("NVTE_BLOCKWISE_FP8_BLOCK_LEN", 128);
}

inline bool is_tensor_scaling(const NVTEScalingMode &mode) {
  return mode == NVTE_DELAYED_TENSOR_SCALING;
}

inline bool is_block_scaling(const NVTEScalingMode &mode) { return !is_tensor_scaling(mode); }

inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) {
  return mode == NVTE_DELAYED_TENSOR_SCALING;
}

inline bool is_nvfp4_scaling(const NVTEScalingMode &mode) { return mode == NVTE_NVFP4_1D_SCALING; }

inline bool is_mxfp8_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; }

inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; }

inline bool is_nvfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_NVFP4_1D_SCALING; }

inline size_t product(const std::vector<size_t> &shape, const size_t begin, const size_t end) {
  NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ",
             end, " in a vector with ", shape.size(), " entries");
  size_t ret = 1;
  for (size_t i = begin; i < end; ++i) {
    ret *= shape[i];
  }
  return ret;
}

inline size_t product(const std::vector<size_t> &shape) {
  size_t ret = 1;
  for (const auto &elem : shape) {
    ret *= elem;
  }
  return ret;
}
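
// Illustrative example (a sketch, not part of the API): for shape = {2, 3, 4},
// product(shape) == 24 and product(shape, 1, 3) == 12, the product of the
// trailing dimensions 3 and 4; out-of-range bounds trip the NVTE_CHECK above.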

size_t get_buffer_size_bytes(const size_t N, const DType buffer_dtype);
size_t get_buffer_size_bytes(const size_t dim_first, const size_t dim_last,
                             const DType buffer_dtype);

struct SimpleTensor {
  void *dptr;
  std::vector<size_t> shape;
  DType dtype;

  SimpleTensor(void *dptr, std::vector<size_t> shape, DType dtype)
      : dptr{dptr}, shape{std::move(shape)}, dtype{dtype} {}

  SimpleTensor(const NVTEBasicTensor &tensor)  // NOLINT
      : dptr(tensor.data_ptr),
        shape(tensor.shape.data, tensor.shape.data + tensor.shape.ndim),
        dtype(static_cast<DType>(tensor.dtype)) {}

  SimpleTensor() : SimpleTensor(nullptr, std::vector<size_t>{0}, DType::kFloat32) {}

  operator NVTEBasicTensor() const {
    return {dptr, static_cast<NVTEDType>(dtype),
            nvte_make_shape(this->shape.data(), this->shape.size())};
  }

  /*! Number of tensor elements. */
  size_t numel() const { return product(shape); }

  /*! Whether the tensor is initialized.
   *
   *  Tensors with non-trivial shapes are considered initialized. This
   *  means that there is no guarantee that the data pointer can be
   *  safely accessed.
   */
  bool has_data() const { return !(dptr == nullptr && shape.size() == 1 && shape[0] == 0); }

  /*! Buffer size in bytes. */
  size_t buffer_size_bytes() const { return get_buffer_size_bytes(numel(), dtype); }

  /*! Reset to uninitialized tensor. */
  void clear() {
    dptr = nullptr;
    shape.resize(1);
    shape[0] = 0;
    dtype = DType::kFloat32;
  }
};
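
// Minimal usage sketch (illustrative only; `dev_buf` is a hypothetical device
// pointer, e.g. from cudaMalloc):
//   SimpleTensor t{dev_buf, {128, 64}, DType::kBFloat16};
//   t.numel();               // 8192
//   t.buffer_size_bytes();   // 8192 elements * 2 bytes = 16384
//   t.clear();               // back to uninitialized: !t.has_data()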

struct Tensor {
 public:
  SimpleTensor data;
  SimpleTensor columnwise_data;
  SimpleTensor amax;
  SimpleTensor columnwise_amax;
  SimpleTensor scale;
  SimpleTensor scale_inv;
  SimpleTensor columnwise_scale_inv;

  NVTEScalingMode scaling_mode;
  NVTETensor nvte_tensor;
  /*! \brief Whether scaling factors are in format expected by GEMM
   *
   *  Only meaningful for MXFP8 and NVFP4.
   */
  bool with_gemm_swizzled_scales = false;

  /*! Map from NVTETensorParam to parameter sizes */
  static constexpr size_t attr_sizes[] = {
      sizeof(NVTEBasicTensor),  // kNVTERowwiseData
      sizeof(NVTEBasicTensor),  // kNVTEColumnwiseData
      sizeof(NVTEBasicTensor),  // kNVTEScale
      sizeof(NVTEBasicTensor),  // kNVTEAmax
      sizeof(NVTEBasicTensor),  // kNVTERowwiseScaleInv
      sizeof(NVTEBasicTensor),  // kNVTEColumnwiseScaleInv
      sizeof(NVTEBasicTensor),  // kNVTEColumnwiseAmax
      sizeof(uint8_t)           // kNVTEWithGEMMSwizzledScales
  };

  Tensor() : scaling_mode{NVTE_DELAYED_TENSOR_SCALING}, nvte_tensor{0} {}

  /*! Reset tensor data. */
  void clear() {
    data.clear();
    columnwise_data.clear();
    amax.clear();
    columnwise_amax.clear();
    scale.clear();
    scale_inv.clear();
    columnwise_scale_inv.clear();
    scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
    with_gemm_swizzled_scales = false;
  }

  explicit operator NVTETensor() const noexcept { return nvte_tensor; }

  /*! Number of tensor elements. */
  size_t numel() const {
    if (!has_data() && has_columnwise_data()) {
      return product(columnwise_data.shape);
    }
    return product(data.shape);
  }

  /*! Whether the tensor data buffer is initialized.
   *
   *  Buffers with non-trivial shapes are considered initialized. This
   *  means that there is no guarantee that the data pointer can be
   *  safely accessed.
   */
  bool has_data() const { return data.has_data(); }

  /*! Whether the tensor column-wise data buffer is initialized.
   *
   *  Buffers with non-trivial shapes are considered initialized. This
   *  means that there is no guarantee that the data pointer can be
   *  safely accessed.
   */
  bool has_columnwise_data() const { return columnwise_data.has_data(); }

  /*! Datatype of tensor elements. */
  DType dtype() const {
    if (!has_data() && has_columnwise_data()) {
      return columnwise_data.dtype;
    }
    return data.dtype;
  }

  /*! Number of tensor dimensions. */
  size_t dim() const {
    if (!has_data() && has_columnwise_data()) {
      return columnwise_data.shape.size();
    }
    return data.shape.size();
  }

  /*! Tensor dimensions.
   *
   *  This is the logical tensor shape. The underlying data may have a
   *  different shape, e.g. the column-wise data for some tensor
   *  formats are transposed.
   */
  std::vector<size_t> shape() const {
    // Each tensor format interprets its data differently
    switch (scaling_mode) {
      case NVTE_DELAYED_TENSOR_SCALING:
      case NVTE_BLOCK_SCALING_1D:
      case NVTE_BLOCK_SCALING_2D:
      case NVTE_NVFP4_1D_SCALING: {
        // Row-wise data shape matches tensor logical shape,
        // column-wise data shape is transpose of logical shape
        if (!has_data() && has_columnwise_data()) {
          std::vector<size_t> ret;
          if (!columnwise_data.shape.empty()) {
            ret.reserve(columnwise_data.shape.size());
            for (size_t i = 1; i < columnwise_data.shape.size(); i++) {
              ret.push_back(columnwise_data.shape[i]);
            }
            ret.push_back(columnwise_data.shape.front());
          }
          return ret;
        }
        return data.shape;
      }
      case NVTE_MXFP8_1D_SCALING: {
        // Row-wise and column-wise data shapes both match tensor
        // logical shape
        if (!has_data() && has_columnwise_data()) {
          return columnwise_data.shape;
        }
        return data.shape;
      }
      default:
        NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode), "\"");
    }
  }

  /*! Matrix height after tensor is flattened to 2D
   *
   * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
   * as a (D1*D2*...*D(n-1), Dn) matrix.
   */
  size_t flat_first_dim() const {
    const auto &full_shape = shape();
    size_t ret = 1;
    if (!full_shape.empty()) {
      for (size_t i = 0; i < full_shape.size() - 1; i++) {
        ret *= full_shape[i];
      }
    }
    return ret;
  }

  /*! Matrix width after tensor is flattened to 2D
   *
   * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
   * as a (D1*D2*...*D(n-1), Dn) matrix.
   */
  size_t flat_last_dim() const {
    const auto &full_shape = shape();
    if (full_shape.empty()) {
      return 1;
    } else {
      return full_shape.back();
    }
  }
};
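
// Illustrative example (a sketch): a Tensor with logical shape {8, 16, 32} is
// treated by the flattening helpers above as an (8 * 16) x 32 matrix, so
// flat_first_dim() == 128 and flat_last_dim() == 32.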

struct GroupedTensor {
 public:
  /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
  /*
  Grouped tensor is a collection of tensors with different shapes but the same dtype and scaling mode

  Shape Representation:
  - logical_shape: 2D shape representing the conceptual layout, i.e. the shape when member tensors are flattened to 2D and stacked together (REQUIRED)
    + When all_same_shape(): [num_tensors * M, N] where each tensor is (M, N)
    + When varying_first_dim(): [~sum_of_first_dims, N] where N is common
    + When varying_last_dim(): [M, ~sum_of_last_dims] where M is common
    + When varying_both_dims(): [1, total_elements] (fully flattened)

  - first_dims and last_dims are OPTIONAL (empty if dimension is uniform)
    + Empty first_dims: all tensors have the same first dimension
    + Empty last_dims: all tensors have the same last dimension
    + Both empty: all tensors have identical shapes
    + Both set: each tensor has unique shape (first_dims[i], last_dims[i])

  Data Layout:
  - ALL data fields are stored as 1D flattened arrays (data, columnwise_data, scale_inv, etc.)
  - logical_shape provides the conceptual 2D interpretation
  - All data is stored on device in contiguous layout
  */

  SimpleTensor data;
  SimpleTensor columnwise_data;
  SimpleTensor scale_inv;
  SimpleTensor columnwise_scale_inv;
  SimpleTensor amax;
  SimpleTensor columnwise_amax;
  SimpleTensor scale;  // for FP8-DS only

  NVTEScalingMode scaling_mode;
  size_t num_tensors;

  // Shape information (OPTIONAL - empty if dimension is uniform across all tensors)
  // first_dims[i] = first dimension of tensor i (empty if all tensors have same first dim)
  // last_dims[i] = last dimension of tensor i (empty if all tensors have same last dim)
  SimpleTensor first_dims;  // Device pointer to int64_t array of length num_tensors (or empty)
  SimpleTensor last_dims;   // Device pointer to int64_t array of length num_tensors (or empty)

  // Offsets for indexing into contiguous 1D layout (OPTIONAL - not needed if all_same_shape())
  // tensor_offsets[i] = element offset to start of tensor i (cumulative sum of numel for tensors 0..i-1)
  // Usage: tensor_i_ptr = (char*)data.dptr + tensor_offsets[i] * element_size
  // If empty and all_same_shape(): offset[i] = i * M * N (where M, N are common dimensions)
  SimpleTensor tensor_offsets;  // Device pointer to int64_t array of length num_tensors (or empty)

  // Logical shape: conceptual 2D shape of the grouped data (REQUIRED)
  // Represents how the 1D flattened data should be interpreted as 2D
  // Always 2D with positive dimensions
  NVTEShape logical_shape;

  NVTEGroupedTensor nvte_tensor;

  /*! \brief Whether scaling factors are in format expected by GEMM
   *
   *  Only meaningful for MXFP8 and NVFP4.
   */
  bool with_gemm_swizzled_scales = false;

  GroupedTensor(NVTEScalingMode scaling_mode, size_t num_tensors)
      : data(),
        columnwise_data(),
        scale_inv(),
        columnwise_scale_inv(),
        amax(),
        columnwise_amax(),
        scale(),
        scaling_mode(scaling_mode),
        num_tensors(num_tensors),
        first_dims(nullptr, std::vector<size_t>{0}, DType::kInt64),
        last_dims(nullptr, std::vector<size_t>{0}, DType::kInt64),
        tensor_offsets(nullptr, std::vector<size_t>{0}, DType::kInt64),
        logical_shape(nvte_make_shape(nullptr, 1)),
        nvte_tensor(0) {}

  explicit operator NVTEGroupedTensor() const noexcept { return nvte_tensor; }

  bool has_data() const noexcept { return data.has_data(); }
  bool has_columnwise_data() const noexcept { return columnwise_data.has_data(); }

  bool all_same_first_dim() const noexcept { return !first_dims.has_data(); }
  bool all_same_last_dim() const noexcept { return !last_dims.has_data(); }
  bool all_same_shape() const noexcept { return !first_dims.has_data() && !last_dims.has_data(); }
  bool varying_both_dims() const noexcept { return first_dims.has_data() && last_dims.has_data(); }

  size_t get_common_first_dim() const {
    NVTE_CHECK(all_same_first_dim(), "First dim varies across tensors");
    NVTE_CHECK(logical_shape.ndim == 2, "Logical shape must be 2D");
    if (all_same_shape()) {
      // When both dims are uniform: logical_shape = [num_tensors * M, N]
      return logical_shape.data[0] / num_tensors;
    } else {
      // When varying last dims but not first dim: logical_shape = [M, sum_of_last_dims]
      return logical_shape.data[0];
    }
  }
  size_t get_common_last_dim() const {
    NVTE_CHECK(all_same_last_dim(), "Last dim varies across tensors");
    NVTE_CHECK(logical_shape.ndim == 2, "Logical shape must be 2D");
    // For both uniform and varying first dim cases: logical_shape[1] is the common last dim
    return logical_shape.data[1];
  }

  DType dtype() const {
    if (!has_data() && has_columnwise_data()) {
      return columnwise_data.dtype;
    }
    return data.dtype;
  }

  void clear() {
    data.clear();
    columnwise_data.clear();
    scale_inv.clear();
    columnwise_scale_inv.clear();
    amax.clear();
    columnwise_amax.clear();
    scale.clear();
    first_dims.clear();
    last_dims.clear();
    tensor_offsets.clear();
    logical_shape = nvte_make_shape(nullptr, 1);
    num_tensors = 0;
    scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
    nvte_tensor = 0;
    with_gemm_swizzled_scales = false;
  }
};
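
// Illustrative layout sketch (hypothetical values): a GroupedTensor holding
// three tensors of shapes (2, 8), (3, 8), (5, 8) has a common last dim, so
//   logical_shape  = [2 + 3 + 5, 8] = [10, 8]
//   first_dims     = {2, 3, 5}    (device array; last_dims stays empty)
//   tensor_offsets = {0, 16, 40}  (cumulative numel: 2*8, then (2+3)*8)
// and tensor i begins at (char *)data.dptr + tensor_offsets[i] * element_size.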

struct QuantizationConfig {
  bool force_pow_2_scales = false;
  float amax_epsilon = 0.0f;
  NVTETensor noop_tensor = nullptr;
  NVTETensor rng_state = nullptr;
  bool nvfp4_2d_quantization = false;
  bool stochastic_rounding = false;
  bool use_fast_math = false;

  static constexpr size_t attr_sizes[] = {
      sizeof(uint8_t),                       // force_pow_2_scales
      sizeof(float),                         // amax_epsilon
      sizeof(NVTETensor),                    // noop_tensor
      sizeof(Float8BlockScaleTensorFormat),  // (deprecated)
      sizeof(NVTETensor),                    // rng_seed and offset
      sizeof(uint8_t),                       // nvfp4_2d_quantization
      sizeof(uint8_t),                       // stochastic_rounding
      sizeof(uint8_t)                        // use_fast_math
  };
};

cudaDataType_t get_cuda_dtype(const transformer_engine::DType t);

template <typename T>
constexpr T DIVUP(const T &x, const T &y) {
  return (x + y - 1) / y;
}

template <typename T1, typename T2>
constexpr __device__ __host__ __forceinline__ uint64_t DIVUP_TO_MULTIPLE(const T1 &N, const T2 &M) {
  static_assert(std::is_integral<T1>::value && std::is_integral<T2>::value,
                "Integral type required.");
  return DIVUP(static_cast<uint64_t>(N), static_cast<uint64_t>(M)) * M;
}
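
// Illustrative compile-time sanity checks for the two helpers above; they are
// evaluated during compilation and generate no code.
static_assert(DIVUP(10, 4) == 3, "DIVUP rounds the quotient up");
static_assert(DIVUP_TO_MULTIPLE(10, 4) == 12, "rounds up to the next multiple of M");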

using byte = uint8_t;
using int16 = int16_t;
using int32 = int32_t;
using int64 = int64_t;
using fp32 = float;
using fp16 = half;
using int8 = int8_t;
using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
#if CUDA_VERSION >= 12080
using fp8e8m0 = __nv_fp8_e8m0;
#endif
#if FP4_TYPE_SUPPORTED
using fp4e2m1 = __nv_fp4_e2m1;
using fp4e2m1x2 = __nv_fp4x2_e2m1;
using fp4e2m1x4 = __nv_fp4x4_e2m1;
#endif
using e8m0_t = uint8_t;

namespace detail {

template <typename T>
constexpr inline const char *type_name() noexcept;
#define TRANSFORMER_ENGINE_TYPE_NAME(T)                  \
  template <>                                            \
  inline constexpr const char *type_name<T>() noexcept { \
    return #T;                                           \
  }
TRANSFORMER_ENGINE_TYPE_NAME(uint8_t)
TRANSFORMER_ENGINE_TYPE_NAME(int16_t)
TRANSFORMER_ENGINE_TYPE_NAME(int32_t)
TRANSFORMER_ENGINE_TYPE_NAME(int64_t)
TRANSFORMER_ENGINE_TYPE_NAME(float)
TRANSFORMER_ENGINE_TYPE_NAME(half)
TRANSFORMER_ENGINE_TYPE_NAME(nv_bfloat16)
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e4m3)
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2)
TRANSFORMER_ENGINE_TYPE_NAME(int8_t)
#if CUDA_VERSION >= 12080
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0)
#endif
#if FP4_TYPE_SUPPORTED
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp4_e2m1)
#endif
#undef TRANSFORMER_ENGINE_TYPE_NAME

template <typename T>
struct TypeExtrema;

#if FP4_TYPE_SUPPORTED
template <>
struct TypeExtrema<fp4e2m1> {
  static constexpr float max = 6.0f;
  static constexpr float max_inverse = 1.0 / max;
};
#endif

template <>
struct TypeExtrema<fp8e4m3> {
  static constexpr float max = 448.0f;
  static constexpr float max_inverse = 1.0 / max;
};

template <>
struct TypeExtrema<int8> {
  static constexpr float max = 127.0f;
};

template <>
struct TypeExtrema<fp8e5m2> {
  static constexpr float max = 57344.0f;
  static constexpr float max_inverse = 1.0 / max;
};

template <>
struct TypeExtrema<bf16> {
  // Hex float format of 1.(7 bits of 1) * 2 ^ 127
  static constexpr float max = 0x1.FEp127;
};

template <>
struct TypeExtrema<fp16> {
  // Hex float format of 1.(10 bits of 1) * 2 ^ 15
  static constexpr float max = 0x1.FFCp15;
};

template <typename T>
struct TypeExtrema {
  static constexpr float max = std::numeric_limits<T>::max();
};

}  // namespace detail

template <typename T>
struct BitsNumber;

#if FP4_TYPE_SUPPORTED
template <>
struct BitsNumber<fp4e2m1> {
  static constexpr size_t num_bits = 4;
};
#endif

template <typename T>
struct BitsNumber {
  static constexpr size_t num_bits = 8 * sizeof(T);
};
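
// Illustrative compile-time checks: FP4 is the only sub-byte type, which is
// why storage sizes are tracked in bits rather than computed with sizeof.
#if FP4_TYPE_SUPPORTED
static_assert(BitsNumber<fp4e2m1>::num_bits == 4, "FP4 packs two values per byte");
#endif
static_assert(BitsNumber<fp8e4m3>::num_bits == 8, "FP8 occupies a full byte");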

template <typename T>
struct TypeInfo {
#if FP4_TYPE_SUPPORTED
  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8
#if CUDA_VERSION >= 12080
                           ,
                           fp8e8m0
#endif
                           ,
                           fp4e2m1
                           >;
#else
  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8
#if CUDA_VERSION >= 12080
                           ,
                           fp8e8m0
#endif
                           >;
#endif

  template <typename U, DType current>
  struct Helper {
    constexpr static DType getType() {
      constexpr int i = static_cast<int>(current);
      if constexpr (i >= std::tuple_size_v<types>) {
        return DType::kNumTypes;
      } else if (std::is_same<U, typename std::tuple_element<i, types>::type>::value) {
        return current;
      } else {
        return Helper<U, static_cast<DType>(i + 1)>::getType();
      }
    }
  };

  template <typename U>
  struct Helper<U, DType::kNumTypes> {
    constexpr static DType getType() { return DType::kNumTypes; }
  };

  template <typename U>
  constexpr static DType getType() {
    return Helper<U, DType::kByte>::getType();
  }

  constexpr static DType dtype = getType<T>();
  constexpr static size_t size = BitsNumber<T>::num_bits;
  constexpr static float max_finite_value = detail::TypeExtrema<T>::max;
  constexpr static const char *name = detail::type_name<T>();
};
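
// Example (illustrative): TypeInfo recovers the DType enum and metadata for a
// C++ type by walking `types` in order, so TypeInfo<fp16>::dtype is
// DType::kFloat16, TypeInfo<fp16>::size is 16 (bits, not bytes), and
// TypeInfo<fp16>::name is "half"; a type not in the tuple maps to
// DType::kNumTypes.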

#if FP4_TYPE_SUPPORTED
#define SWITCH_FP4_TYPE_HANDLE(type, ...) \
  case DType::kFloat4E2M1: {              \
    using type = fp4e2m1;                 \
    { __VA_ARGS__ }                       \
  } break;
#else
#define SWITCH_FP4_TYPE_HANDLE(type, ...)  // do nothing
#endif

#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \
  switch (dtype) {                                           \
    using namespace transformer_engine;                      \
    case DType::kByte: {                                     \
      using type = unsigned char;                            \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kInt16: {                                    \
      using type = int16_t;                                  \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kInt32: {                                    \
      using type = int32_t;                                  \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kInt64: {                                    \
      using type = int64_t;                                  \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kFloat32: {                                  \
      using type = float;                                    \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kFloat16: {                                  \
      using type = fp16;                                     \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kBFloat16: {                                 \
      using type = bf16;                                     \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kFloat8E4M3: {                               \
      using type = fp8e4m3;                                  \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kFloat8E5M2: {                               \
      using type = fp8e5m2;                                  \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kFloat8E8M0: {                               \
      using type = byte;                                     \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
      SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__)              \
    default:                                                 \
      NVTE_ERROR("Invalid type.");                           \
  }
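
// Typical usage sketch (illustrative; `tensor`, `kernel`, `grid`, `block` and
// `stream` are hypothetical):
//   TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(tensor.dtype(), T,
//     kernel<T><<<grid, block, 0, stream>>>(
//         static_cast<const T *>(tensor.data.dptr));)
// Each case re-declares `T` as the concrete C++ type for that DType, so the
// same body is compiled once per supported type.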

#define TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(dtype, type, ...) \
  switch (dtype) {                                             \
    using namespace transformer_engine;                        \
    case DType::kFloat32: {                                    \
      using type = float;                                      \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kFloat16: {                                    \
      using type = fp16;                                       \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kBFloat16: {                                   \
      using type = bf16;                                       \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kFloat8E4M3: {                                 \
      using type = fp8e4m3;                                    \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kFloat8E5M2: {                                 \
      using type = fp8e5m2;                                    \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    default:                                                   \
      NVTE_ERROR("Invalid type.");                             \
  }

#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) \
  switch (dtype) {                                              \
    using namespace transformer_engine;                         \
    case DType::kFloat32: {                                     \
      using type = float;                                       \
      { __VA_ARGS__ }                                           \
    } break;                                                    \
    case DType::kFloat16: {                                     \
      using type = fp16;                                        \
      { __VA_ARGS__ }                                           \
    } break;                                                    \
    case DType::kBFloat16: {                                    \
      using type = bf16;                                        \
      { __VA_ARGS__ }                                           \
    } break;                                                    \
    case DType::kFloat8E5M2: {                                  \
      using type = fp8e5m2;                                     \
      { __VA_ARGS__ }                                           \
    } break;                                                    \
    case DType::kFloat8E4M3: {                                  \
      using type = fp8e4m3;                                     \
      { __VA_ARGS__ }                                           \
    } break;                                                    \
    default:                                                    \
      NVTE_ERROR("Invalid type.");                              \
  }

#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT_WITH_INT8(dtype, type, ...) \
  switch (dtype) {                                                        \
    using namespace transformer_engine;                                   \
    case DType::kFloat32: {                                               \
      using type = float;                                                 \
      { __VA_ARGS__ }                                                     \
    } break;                                                              \
    case DType::kFloat16: {                                               \
      using type = fp16;                                                  \
      { __VA_ARGS__ }                                                     \
    } break;                                                              \
    case DType::kBFloat16: {                                              \
      using type = bf16;                                                  \
      { __VA_ARGS__ }                                                     \
    } break;                                                              \
    case DType::kFloat8E5M2: {                                            \
      using type = fp8e5m2;                                               \
      { __VA_ARGS__ }                                                     \
    } break;                                                              \
    case DType::kFloat8E4M3: {                                            \
      using type = fp8e4m3;                                               \
      { __VA_ARGS__ }                                                     \
    } break;                                                              \
    case DType::kInt8: {                                                  \
      using type = int8;                                                  \
      { __VA_ARGS__ }                                                     \
    } break;                                                              \
    default:                                                              \
      NVTE_ERROR("Invalid type.");                                        \
  }

#define TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, type, ...) \
  switch (dtype) {                                                   \
    using namespace transformer_engine;                              \
    case DType::kFloat32: {                                          \
      using type = float;                                            \
      { __VA_ARGS__ }                                                \
    } break;                                                         \
    case DType::kFloat16: {                                          \
      using type = fp16;                                             \
      { __VA_ARGS__ }                                                \
    } break;                                                         \
    case DType::kBFloat16: {                                         \
      using type = bf16;                                             \
      { __VA_ARGS__ }                                                \
    } break;                                                         \
    default:                                                         \
      NVTE_ERROR("Invalid type.");                                   \
  }

// Takes a pack_size argument to select the packed FP4 storage type (currently
// unused: the 2-element packed storage type is always chosen)
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY(dtype, pack_size, type, ...) \
  switch (dtype) {                                                             \
    using namespace transformer_engine;                                        \
    case DType::kFloat4E2M1: {                                                 \
      using type = __nv_fp4x2_storage_t;                                       \
      { __VA_ARGS__ }                                                          \
    } break;                                                                   \
    default:                                                                   \
      NVTE_ERROR("Invalid type.");                                             \
  }

#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \
  switch (dtype) {                                               \
    using namespace transformer_engine;                          \
    case DType::kFloat8E5M2: {                                   \
      using type = fp8e5m2;                                      \
      { __VA_ARGS__ }                                            \
    } break;                                                     \
    case DType::kFloat8E4M3: {                                   \
      using type = fp8e4m3;                                      \
      { __VA_ARGS__ }                                            \
    } break;                                                     \
    default:                                                     \
      NVTE_ERROR("Invalid type.");                               \
  }

#define TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(dtype, type, ...)    \
  switch (dtype) {                                               \
    using namespace transformer_engine;                          \
    case DType::kFloat8E5M2: {                                   \
      using type = fp8e5m2;                                      \
      { __VA_ARGS__ }                                            \
    } break;                                                     \
    case DType::kFloat8E4M3: {                                   \
      using type = fp8e4m3;                                      \
      { __VA_ARGS__ }                                            \
    } break;                                                     \
    case DType::kInt8: {                                         \
      using type = int8;                                         \
      { __VA_ARGS__ }                                            \
    } break;                                                     \
    default:                                                     \
      NVTE_ERROR("Invalid type.");                               \
  }

#if FP4_TYPE_SUPPORTED
#define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \
  switch (dtype) {                                             \
    using namespace transformer_engine;                        \
    case DType::kFloat32: {                                    \
      using type = float;                                      \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kFloat16: {                                    \
      using type = fp16;                                       \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kBFloat16: {                                   \
      using type = bf16;                                       \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kInt8:                                         \
    case DType::kFloat8E5M2:                                   \
    case DType::kFloat8E4M3: {                                 \
      NVTE_ERROR("FP8 type not instantiated for input.");      \
    } break;                                                   \
    case DType::kFloat4E2M1: {                                 \
      NVTE_ERROR("FP4 type not instantiated for input.");      \
    } break;                                                   \
    default:                                                   \
      NVTE_ERROR("Invalid type.");                             \
  }
#else
#define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \
  switch (dtype) {                                             \
    using namespace transformer_engine;                        \
    case DType::kFloat32: {                                    \
      using type = float;                                      \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kFloat16: {                                    \
      using type = fp16;                                       \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kBFloat16: {                                   \
      using type = bf16;                                       \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kInt8:                                         \
    case DType::kFloat8E5M2:                                   \
    case DType::kFloat8E4M3: {                                 \
      NVTE_ERROR("FP8 type not instantiated for input.");      \
    } break;                                                   \
    default:                                                   \
      NVTE_ERROR("Invalid type.");                             \
  }
#endif

#define TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(dtype, type, ...) \
  switch (dtype) {                                             \
    using namespace transformer_engine;                        \
    case DType::kFloat16: {                                    \
      using type = fp16;                                       \
      __VA_ARGS__;                                             \
      break;                                                   \
    }                                                          \
    case DType::kBFloat16: {                                   \
      using type = bf16;                                       \
      __VA_ARGS__;                                             \
      break;                                                   \
    }                                                          \
    default:                                                   \
      NVTE_ERROR("Invalid type for 16 bit.");                  \
  }

#define TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(SCALE_DIM, DIM, ...) \
  switch (SCALE_DIM) {                                              \
    case 1: {                                                       \
      constexpr size_t DIM = 1;                                     \
      { __VA_ARGS__ }                                               \
    } break;                                                        \
    case 32: {                                                      \
      constexpr size_t DIM = 32;                                    \
      { __VA_ARGS__ }                                               \
    } break;                                                        \
    default: {                                                      \
      NVTE_ERROR("Invalid size of the MX scaling factor.");         \
    }                                                               \
  }

#define TRANSFORMER_ENGINE_SWITCH_CONDITION(CONDITION, FLAG, ...) \
  if (CONDITION) {                                                \
    constexpr bool FLAG = true;                                   \
    { __VA_ARGS__ }                                               \
  } else {                                                        \
    constexpr bool FLAG = false;                                  \
    { __VA_ARGS__ }                                               \
  }
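
// Illustrative use (hypothetical kernel launcher `launch`): promote a runtime
// flag to a constexpr bool so it can parameterize a template:
//   TRANSFORMER_ENGINE_SWITCH_CONDITION(noop_ptr != nullptr, kHasNoop,
//     launch<kHasNoop>(args);)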

////////////////////////////////////////////////////////////////////////////////////////////////////

inline int log2_ceil(int value) {
  int log2_value = 0;
  while ((1 << log2_value) < value) ++log2_value;
  return log2_value;
}

template <size_t B>
inline size_t alignTo(size_t x) {
  size_t r = x % B;
  if (r == 0) return x;

  return x + B - r;
}
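
// Illustrative examples (a sketch): log2_ceil(5) == 3, since 8 is the first
// power of two >= 5, and alignTo<16>(20) == 32, the next multiple of 16.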

template <typename T>
struct is_fp8 : std::false_type {};

template <>
struct is_fp8<fp8e4m3> : std::true_type {};

template <>
struct is_fp8<fp8e5m2> : std::true_type {};

template <typename T>
struct is_int8 : std::false_type {};

template <>
struct is_int8<int8> : std::true_type {};

template <typename T>
struct is_fp4 : std::false_type {};

#if FP4_TYPE_SUPPORTED
template <>
struct is_fp4<fp4e2m1> : std::true_type {};
#endif

// [128,4] rowwise and [4,128] colwise alignment requirements for the tensor with scaling factors
constexpr size_t scale_tensor_alignment_X_rowwise = 4;
constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
constexpr size_t scale_tensor_alignment_X_colwise = 128;
constexpr size_t scale_tensor_alignment_Y_colwise = 4;

// Alignment requirements for the Tensor Memory Accelerator (TMA)
constexpr size_t TMA_GMEM_ALIGNMENT = 16;    // global memory address alignment
constexpr size_t TMA_SHMEM_ALIGNMENT = 128;  // shared memory address alignment

inline bool is_aligned_ptr(const void *ptr, size_t alignment) {
  return reinterpret_cast<uintptr_t>(ptr) % alignment == 0;
}

inline bool is_aligned_tensor_data(const Tensor &t, size_t alignment) {
  return is_aligned_ptr(static_cast<const void *>(t.data.dptr), alignment);
}
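
// Example (illustrative): TMA-based kernels typically guard on
// is_aligned_tensor_data(t, TMA_GMEM_ALIGNMENT) before building a tensor map
// for `t`, and fall back to a generic code path otherwise.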

size_t typeToSize(const DType type);
size_t typeToNumBits(const DType type);

void CheckNoopTensor(const Tensor &t, const std::string &name);
void CheckInputTensor(const Tensor &t, const std::string &name);
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false);

/*! \brief Update a tensor's FP8 scale-inverse
 *
 * The FP8 scale-inverse (dequantization scaling factor) is updated
 * with the reciprocal of the FP8 scale (quantization scaling factor).
 */
void update_tensor_scale_inv(Tensor *t, cudaStream_t stream);

#define NVTE_API_CALL(api_name) \
  transformer_engine::nvtx::NVTXWrapper _##api_name##_nvtx_wrapper(#api_name);

void checkCuDriverContext(CUstream stream);

#ifndef __HIP_PLATFORM_AMD__
CUtensorMapDataType get_CUtensorMapDataType(DType dtype);

// Set up parameters to create TMA descriptor.
void create_2D_tensor_map(
    CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY,
    const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX,
    const uint32_t stride_elems, const uint32_t offset_elems, const size_t type_num_bits,
    const CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
#endif

bool is_supported_by_CC_100();

std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensors,
                                                        size_t outer_size, size_t inner_size);

Tensor *convertNVTETensor(const NVTETensor tensor);
Tensor *convertNVTETensorCheck(const NVTETensor tensor);

GroupedTensor *convertNVTEGroupedTensor(const NVTEGroupedTensor tensor);
GroupedTensor *convertNVTEGroupedTensorCheck(const NVTEGroupedTensor tensor);

// Helper functions for GroupedTensor validation
void CheckGroupedTensorShapeArrays(const GroupedTensor &t, const std::string &name);
void CheckInputGroupedTensor(const GroupedTensor &t, const std::string &name);
void CheckOutputGroupedTensor(const GroupedTensor &t, const std::string &name,
                              bool allow_empty = false);

}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_COMMON_H_