test_common.h 20.9 KB
Newer Older
Przemek Tredak's avatar
Przemek Tredak committed
1
/*************************************************************************
2
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
3
4
5
6
7
8
9
 *
 * See LICENSE for license information.
 ************************************************************************/

#pragma once

#include <memory>
Tim Moon's avatar
Tim Moon committed
10
#include <vector>
11
12
#include <array>
#include <cmath>
#include <cstring>
#include <random>
13
#ifndef __HIP_PLATFORM_AMD__
14
#include <cudaTypedefs.h>
15
#endif
16
#define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)
Tim Moon's avatar
Tim Moon committed
17

wenjh's avatar
wenjh committed
18
19
20
#ifdef __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#endif
yuguo's avatar
yuguo committed
21
#include <cuda_runtime_api.h>
Przemek Tredak's avatar
Przemek Tredak committed
22
#include <cuda_bf16.h>
yuguo's avatar
yuguo committed
23
#include <cuda_fp16.h>
Przemek Tredak's avatar
Przemek Tredak committed
24
#include <cuda_fp8.h>
25
26
27
#if FP4_TYPE_SUPPORTED
#include <cuda_fp4.h>
#endif
Przemek Tredak's avatar
Przemek Tredak committed
28
#include <cuda_runtime_api.h>
Tim Moon's avatar
Tim Moon committed
29
30
31

#include <transformer_engine/transformer_engine.h>
#include "util/logging.h"
Przemek Tredak's avatar
Przemek Tredak committed
32
33
34
35

namespace test {
using namespace transformer_engine;

36
37
38
39
40
41
42
43
44
45
46
47
inline int blockwise_fp8_block_len() {
  const char *env = std::getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN");
  if (env == nullptr || env[0] == '\0') {
    return 128;
  }
  int value;
  std::istringstream iss(env);
  iss >> value;
  NVTE_CHECK(iss, "Invalid environment variable value");
  return value;
}

Przemek Tredak's avatar
Przemek Tredak committed
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
// Maps a width in bytes to the unsigned integer type of exactly that width.
// Only widths 1, 2, 4 and 8 are provided; any other width yields a struct
// without a `Type` member, so using it is a compile-time error.
template <size_t i> struct BytesToType {};

template <> struct BytesToType<1> { typedef uint8_t  Type; };
template <> struct BytesToType<2> { typedef uint16_t Type; };
template <> struct BytesToType<4> { typedef uint32_t Type; };
template <> struct BytesToType<8> { typedef uint64_t Type; };

// Short element-type names used throughout the tests.
// Integer types.
typedef uint8_t byte;
typedef int16_t int16;
typedef int32_t int32;
typedef int64_t int64;

// Floating-point types.
typedef float         fp32;
typedef half          fp16;
typedef nv_bfloat16   bf16;
typedef __nv_fp8_e4m3 fp8e4m3;
typedef __nv_fp8_e5m2 fp8e5m2;
// E8M0 values are handled as raw bytes; this is the same C++ type as `byte`.
typedef uint8_t       fp8e8m0;
typedef int8_t        int8;

#if FP4_TYPE_SUPPORTED
typedef __nv_fp4_e2m1   fp4e2m1;
typedef __nv_fp4x2_e2m1 fp4e2m1x2;
typedef __nv_fp4x4_e2m1 fp4e2m1x4;
#endif

// Storage width of one element of T, in bits.
// By default this is 8 * sizeof(T); sub-byte formats override it below.
template <typename T>
struct BitsNumber {
  static constexpr size_t num_bits = 8 * sizeof(T);
};

#if FP4_TYPE_SUPPORTED
// One FP4 (E2M1) element occupies 4 bits.
template <>
struct BitsNumber<fp4e2m1> {
  static constexpr size_t num_bits = 4;
};
#endif
template <typename T>
104
105
struct TypeInfo {
#if FP4_TYPE_SUPPORTED
wenjh's avatar
wenjh committed
106
    using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8, fp8e8m0, fp4e2m1>;
107
#else
wenjh's avatar
wenjh committed
108
    using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8, fp8e8m0>;
109
#endif
Przemek Tredak's avatar
Przemek Tredak committed
110
111
112
113
114

    template <typename U, DType current>
    struct Helper {
        constexpr static DType getType() {
            constexpr int i = static_cast<int>(current);
maxiao3's avatar
maxiao3 committed
115
116
117
	    if constexpr (i >= std::tuple_size_v<types>) {
                return DType::kNumTypes;
	    } else if (std::is_same<U, typename std::tuple_element<i, types>::type>::value) {
Przemek Tredak's avatar
Przemek Tredak committed
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
                return current;
            } else {
                return Helper<U, static_cast<DType>(i + 1)>::getType();
            }
        }
    };

    template <typename U>
    struct Helper<U, DType::kNumTypes> {
        constexpr static DType getType() {
            return DType::kNumTypes;
        }
    };

    template <typename U>
    constexpr static DType getType() {
        return Helper<U, DType::kByte>::getType();
    }

    constexpr static DType dtype = getType<T>();
138
    constexpr static size_t size = BitsNumber<T>::num_bits;;
Przemek Tredak's avatar
Przemek Tredak committed
139
140
141
142
};

class Tensor {
 public:
143
144
145
146
147
148
149
150
151
152
153
154
  Tensor(const std::string& name,
         const NVTEShape &shape, const DType type,
         const bool rowwise = true,
         const bool columnwise = false,
         const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING);

  Tensor(const std::string& name,
         const std::vector<size_t> &shape,
         const DType type,
         const bool rowwise = true,
         const bool columnwise = false,
         const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) :
155
    Tensor(name, nvte_make_shape(shape.data(), shape.size()), type, rowwise, columnwise, mode) {}
Przemek Tredak's avatar
Przemek Tredak committed
156
157
158
159
160
161
162
163
164
165

  Tensor() {}

  Tensor& operator=(const Tensor &other) = delete;
  Tensor(const Tensor &other) = delete;

  Tensor(Tensor &&other) = default;
  Tensor& operator=(Tensor &&other) = default;

  ~Tensor() {
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
    void *data_ptr = tensor_.dptr();
    void *scale_inv = tensor_.scale_inv();
    void *columnwise_data_ptr = tensor_.get_columnwise_data().data_ptr;
    void *columnwise_scale_inv = tensor_.get_columnwise_scale_inv().data_ptr;
    if (columnwise_data_ptr == data_ptr) {
      columnwise_data_ptr = nullptr;
    }
    if (columnwise_scale_inv == scale_inv) {
      columnwise_scale_inv = nullptr;
    }
    if (data_ptr != nullptr) {
      cudaFree(data_ptr);
    }
    if (scale_inv != nullptr) {
      cudaFree(scale_inv);
    }
182
    if (columnwise_data_ptr != nullptr) {
183
184
      cudaFree(columnwise_data_ptr);
    }
185
    if (columnwise_scale_inv != nullptr) {
186
      cudaFree(columnwise_scale_inv);
Przemek Tredak's avatar
Przemek Tredak committed
187
188
    }
  }
189

190
  NVTETensor data() const noexcept { return tensor_.data(); }
Przemek Tredak's avatar
Przemek Tredak committed
191

192
  NVTEShape rowwise_shape() const noexcept { return tensor_.get_rowwise_data().shape; }
193

194
  NVTEShape columnwise_shape() const noexcept { return tensor_.get_columnwise_data().shape; }
195
196
197
198
199
200
201
202
203
204
205
206
207

  NVTEShape rowwise_scale_inv_shape() const {
    NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!");
    return tensor_.get_rowwise_scale_inv().shape;
  }

  NVTEShape columnwise_scale_inv_shape() const {
    NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!");
    return tensor_.get_columnwise_scale_inv().shape;
  }

  NVTEScalingMode scaling_mode() const noexcept {
    return tensor_.scaling_mode();
Przemek Tredak's avatar
Przemek Tredak committed
208
209
210
211
212
213
  }

  DType dtype() const noexcept {
    return tensor_.dtype();
  }

214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
  void *rowwise_dptr() const {
    NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!");
    return tensor_.get_rowwise_data().data_ptr;
  }

  void *columnwise_dptr() const {
    NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!");
    return tensor_.get_columnwise_data().data_ptr;
  }

  template <typename T>
  T *rowwise_cpu_dptr() const {
    NVTE_CHECK(TypeInfo<T>::dtype == tensor_.dtype(), "Invalid type!");
    NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!");
    return reinterpret_cast<T *>(cpu_data_rowwise_.get());
Przemek Tredak's avatar
Przemek Tredak committed
229
230
231
  }

  template <typename T>
232
  T *columnwise_cpu_dptr() const {
Przemek Tredak's avatar
Przemek Tredak committed
233
    NVTE_CHECK(TypeInfo<T>::dtype == tensor_.dtype(), "Invalid type!");
234
235
    NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!");
    return reinterpret_cast<T *>(cpu_data_columnwise_.get());
Przemek Tredak's avatar
Przemek Tredak committed
236
237
  }

238
239
240
241
242
243
244
245
246
247
248
  float amax() const {
    if(amax_cpu_data_) {
      to_cpu();
      return *amax_cpu_data_;
    } else {
      return 0;
    }
  }

  float scale() const {
    if(scale_cpu_data_) {
249
250
251
      NVTE_CHECK((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)
                 || (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING),
                 "Invalid scaling_mode!");
252
253
254
255
256
257
258
      to_cpu();
      return *scale_cpu_data_;
    } else {
      return 1;
    }
  }

259
260
261
262
  template <typename T>
  T *rowwise_cpu_scale_inv_ptr(){
    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){
      NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
263
264
    } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) {
      NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
265
266
    } else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) {
      NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat8E4M3, "Invalid type!");
267
268
269
270
271
272
273
274
275
276
277
    } else {
      NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!");
    }
    to_cpu();
    return reinterpret_cast<T*>(rowwise_scale_inv_cpu_data_.get());
  }

  template <typename T>
  T *columnwise_cpu_scale_inv_ptr(){
    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){
      NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
278
279
    } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) {
      NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
280
281
    } else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) {
      NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat8E4M3, "Invalid type!");
282
283
284
285
286
287
288
289
290
291
292
    } else {
      NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!");
    }
    to_cpu();
    return reinterpret_cast<T*>(columnwise_scale_inv_cpu_data_.get());
  }

  float rowwise_scale_inv(){
    if(rowwise_scale_inv_cpu_data_) {
      float scale_inv = rowwise_cpu_scale_inv_ptr<float>()[0];
      return scale_inv;
293
294
295
296
297
    } else {
      return 1;
    }
  }

298
299
300
301
302
303
304
305
  bool rowwise() const {
    return rowwise_;
  }

  bool columnwise() const {
    return columnwise_;
  }

306
307
308
309
  void set_tensor_amax_nullptr(){
    tensor_.set_amax(nullptr, DType::kFloat32, tensor_.defaultShape);
  }

Przemek Tredak's avatar
Przemek Tredak committed
310
311
  void to_cpu() const;
  void from_cpu() const;
312
313
314
  void set_scale(float scale);
  void set_scale_inv(float scale_inv);
  void shareFP8Meta(const Tensor &other);
Przemek Tredak's avatar
Przemek Tredak committed
315

316
317
  std::mt19937& gen() { return gen_; }

Przemek Tredak's avatar
Przemek Tredak committed
318
319
 private:
  TensorWrapper tensor_;
320
321
  std::unique_ptr<unsigned char[]> cpu_data_rowwise_;
  std::unique_ptr<unsigned char[]> cpu_data_columnwise_;
322
323
  std::shared_ptr<float> amax_cpu_data_;
  std::shared_ptr<float> scale_cpu_data_;
324
325
326
327
328
329
330
331
332
333
334
335
336
  std::unique_ptr<unsigned char[]> rowwise_scale_inv_cpu_data_;
  std::unique_ptr<unsigned char[]> columnwise_scale_inv_cpu_data_;
  bool rowwise_;
  bool columnwise_;
  std::string name_;
  std::mt19937 gen_;
};

// IEEE-754 binary32 layout: exponent bias and number of stored mantissa bits.
constexpr uint32_t FP32_EXPONENT_BIAS = 127;
constexpr uint32_t FP32_MANTISSA_BITS = 23;

// Alignment required for scale tensors, given as [Y, X]:
//   rowwise scales:    [128, 4]
//   columnwise scales: [4, 128]
constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
constexpr size_t scale_tensor_alignment_X_rowwise = 4;

constexpr size_t scale_tensor_alignment_Y_colwise = 4;
constexpr size_t scale_tensor_alignment_X_colwise = 128;

// Ceiling division: the smallest q such that q * M >= N. M must be non-zero.
// Computed as quotient plus a carry for the remainder so that it cannot
// overflow, unlike (N + M - 1) / M, which wraps when N is close to SIZE_MAX.
inline size_t divide_round_up(const size_t N, const size_t M) {
    return N / M + (N % M != 0 ? 1 : 0);
}

// Rounds N up to the nearest multiple of M (M must be non-zero).
inline size_t round_up_to_nearest_multiple(const size_t N, const size_t M) {
    return divide_round_up(N, M) * M;
}

// Fallback numeric traits for types without a dedicated specialization.
// Every limit is a placeholder value of 1.
template <typename T>
struct Numeric_Traits {
    static constexpr double minSubnorm = 1.0,
                            maxSubnorm = 1.0,
                            minNorm    = 1.0,
                            maxNorm    = 1.0,
                            artifInf   = 1.0;
    static constexpr int maxBiasedExponent = 1;
};

// FP8 E4M3: subnormals are k * 2^-9, normals start at 2^-6, max normal is 448.
template <>
struct Numeric_Traits<fp8e4m3> {
    static constexpr double minSubnorm = 1.0 / 512.0;    // 2^-9
    static constexpr double maxSubnorm = 0.875 / 64.0;   // 0.875 * 2^-6
    static constexpr double minNorm    = 1.0 / 64.0;     // 2^-6
    static constexpr double maxNorm    = 448.0;
    static constexpr double artifInf   = 10.0 * maxNorm; // artificial Infinity
    static constexpr int maxUnbiasedExponentAsFP32 = 8;
    static constexpr int maxBiasedExponentAsFP32   = maxUnbiasedExponentAsFP32 + FP32_EXPONENT_BIAS;
    static constexpr int maxExpNorm                = 1 << maxUnbiasedExponentAsFP32;
};

// FP8 E5M2: subnormals are k * 2^-16, normals start at 2^-14, max normal is 57344.
template <>
struct Numeric_Traits<fp8e5m2> {
    static constexpr double minSubnorm = 1.0 / 65536.0;  // 2^-16
    static constexpr double maxSubnorm = 0.75 / 16384.0; // 0.75 * 2^-14
    static constexpr double minNorm    = 1.0 / 16384.0;  // 2^-14
    static constexpr double maxNorm    = 57344.0;
    static constexpr double artifInf   = 10.0 * maxNorm; // artificial Infinity
    static constexpr int maxUnbiasedExponentAsFP32 = 15;
    static constexpr int maxBiasedExponentAsFP32   = maxUnbiasedExponentAsFP32 + FP32_EXPONENT_BIAS;
    static constexpr int maxExpNorm                = 1 << maxUnbiasedExponentAsFP32;
};

// FP32 limits taken from std::numeric_limits<float>.
template <>
struct Numeric_Traits<fp32> {
    static constexpr double minSubnorm = std::numeric_limits<fp32>::denorm_min();   // 2^-149
    // Largest subnormal: smallest normal minus one subnormal step.
    static constexpr double maxSubnorm = std::numeric_limits<fp32>::min()
                                         - std::numeric_limits<fp32>::denorm_min();
    static constexpr double minNorm    = std::numeric_limits<fp32>::min();          // 2^-126
    static constexpr double maxNorm    = std::numeric_limits<fp32>::max();          // (1 - 2^-24) * 2^128
    static constexpr double artifInf   = std::numeric_limits<fp32>::infinity();
    // Follows the same "unbiased + 127" convention as the FP8 entries: 128 + 127 = 255.
    static constexpr int maxUnbiasedExponentAsFP32 = 128;
    static constexpr int maxBiasedExponentAsFP32   = 255;
};

// Limits of the quantized format T, derived from Numeric_Traits<T> and
// returned as fp32 (or int for exponents).
template <typename T>
struct Quantized_Limits {
    // Range boundaries {0, min normal, max normal, artificial infinity};
    // presumably indexed together with InputsFillCase — confirm at call sites.
    static constexpr double ranges[] = {
        0.0,
        Numeric_Traits<T>::minNorm,
        Numeric_Traits<T>::maxNorm,
        Numeric_Traits<T>::artifInf,
    };

    // Largest finite value and its reciprocal.
    static constexpr fp32 max() {
        return static_cast<fp32>(Numeric_Traits<T>::maxNorm);
    }
    static constexpr fp32 max_reciprocal() {
        return static_cast<fp32>(1.0 / max());
    }

    // 2^(max unbiased exponent) and its reciprocal.
    static constexpr fp32 emax() {
        return static_cast<fp32>(Numeric_Traits<T>::maxExpNorm);
    }
    static constexpr fp32 emax_reciprocal() {
        return static_cast<fp32>(1.0 / emax());
    }

    // Maximum normal exponent expressed in FP32 terms (biased / unbiased).
    static constexpr int max_norm_biased_exponent() {
        return Numeric_Traits<T>::maxBiasedExponentAsFP32;
    }
    static constexpr int max_norm_unbiased_exponent() {
        return Numeric_Traits<T>::maxUnbiasedExponentAsFP32;
    }
};

// Ways of filling test inputs.
// The first three cases target the magnitude ranges of the E4M3 and E5M2
// formats (subnormal, normal, beyond max normal), assuming round-to-nearest-even
// as in the OFP8 specification.
enum InputsFillCase {
    zero_to_minNorm    = 0,  // [0, min_normal)
    minNorm_to_maxNorm = 1,  // [min_normal, max_normal)
    maxNorm_to_inf     = 2,  // [max_normal, inf)
    zeros              = 3,  // {0}
    uniform            = 4,  // std::uniform_real_distribution<> dis(-2.0, 1.0)
};
// Converts a float to an E8M0 biased exponent, rounding up to the next power of
// two when `val` is not an exact power of two.
// Special values: NaN -> 0xFF, +/-inf -> 0xFE, zero -> 0x00. The sign bit is
// dropped, so a negative input encodes like its magnitude.
inline fp8e8m0 float_to_e8m0(float val) {
  // TODO: nan/inf needs to be set for any value
  // of nan/inf in input not just amax.
  if (std::isnan(val)) {
    return 0xFF;
  }
  if (std::isinf(val)) {
    return 0xFE;
  }
  if (val == 0.0f) {
    return 0x00;
  }
  // Read the bit pattern with memcpy; dereferencing a reinterpret_cast'ed
  // pointer would break strict aliasing.
  uint32_t val_u32;
  static_assert(sizeof(val_u32) == sizeof(val), "float must be 32 bits");
  std::memcpy(&val_u32, &val, sizeof(val_u32));
  // Exponent field; narrowing to 8 bits discards the sign bit shifted into bit 8.
  fp8e8m0 exponent = static_cast<fp8e8m0>(val_u32 >> FP32_MANTISSA_BITS);
  uint32_t mantissa = val_u32 & 0x7FFFFF;
  // Round up exponent and deal with satfinite.
  if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) {
    ++exponent;
  }
  return exponent;
}

// Reciprocal of the power of two encoded by an E8M0 biased exponent:
// returns 2^(127 - biased_exp). biased_exp == 0 is special-cased to 1.0f.
inline float exp2f_rcp(fp8e8m0 biased_exp) {
  if (biased_exp == 0) {
    return 1.0f;
  }
  // std::ldexp is exact for powers of two. That includes the subnormal results
  // for biased_exp = 254 (2^-127) and 255 (2^-128), which a hand-built exponent
  // field cannot represent. Do the subtraction in int: FP32_EXPONENT_BIAS is
  // unsigned.
  const int rcp_exponent = static_cast<int>(FP32_EXPONENT_BIAS) - static_cast<int>(biased_exp);
  return std::ldexp(1.0f, rcp_exponent);
}

// Scalar float reference activations and their derivatives.
inline float identity(const float x) { return x; }

// tanh-approximated GELU; 0.03567741 is 0.79788456 * 0.044715 folded in.
inline float gelu(const float x) {
    const float inner = x * (0.79788456f + 0.03567741f * x * x);
    return x * (0.5f + 0.5f * tanhf(inner));
}

// Derivative of the tanh-approximated GELU; 0.1070322243 is 3 * 0.03567741.
inline float dgelu(const float x) {
    const float tanh_out = tanhf(0.79788456f * x * (1 + 0.044715f * x * x));
    const float sech2 = 1 - tanh_out * tanh_out;
    return 0.5f * x * (sech2 * (0.79788456f + 0.1070322243f * x * x))
           + 0.5f * (1 + tanh_out);
}

inline float sigmoid(const float x) { return 1 / (1 + expf(-x)); }

inline float dsigmoid(const float x) {
    const float s = sigmoid(x);
    return s * (1 - s);
}

// Quick GELU: x * sigmoid(1.702 * x).
inline float qgelu(const float x) { return x * sigmoid(1.702f * x); }

inline float dqgelu(const float x) {
    const float z = 1.702f * x;
    return z * dsigmoid(z) + sigmoid(z);
}

inline float relu(const float x)  { return fmaxf(0, x); }
inline float drelu(const float x) { return x > 0 ? 1 : 0; }

inline float silu(const float x)  { return x * sigmoid(x); }
inline float dsilu(const float x) { return x * dsigmoid(x) + sigmoid(x); }

// Squared ReLU and its derivative.
inline float srelu(const float x)  { return x > 0 ? x * x : 0; }
inline float dsrelu(const float x) { return fmaxf(0, 2 * x); }
size_t typeToNumBits(DType type);
Przemek Tredak's avatar
Przemek Tredak committed
472
size_t product(const NVTEShape &shape);
473
size_t product(const std::vector<size_t> &shape);
474
size_t bytes(const NVTEShape& shape, const DType type);
475
476
477

size_t first_dimension(const std::vector<size_t> &shape);
size_t last_dimension(const std::vector<size_t> &shape);
Przemek Tredak's avatar
Przemek Tredak committed
478
479
480
481

bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2);

void compareResults(const std::string &name, const Tensor &test, const void *ref,
482
483
                    bool rowwise, double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true,
                    const size_t tolerable_mismatches_limit = 0);
484
485
void compareResults(const std::string &name, const float test, const float ref,
                    double atol = 1e-5, double rtol = 1e-8);
486
487
void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref,
                    size_t N, float mismatch_rate_tol = 0.);
488
489
490
491
492
493
494
495
template <typename T>
void compare_scaling_factors(const std::string &name, const T *test, const T *ref,
                             const size_t row_blocks, const size_t col_blocks, const size_t stride,
                             size_t& mismatches_num,
                             const size_t scale_diff_abs_tolerance = 0,
                             const double abs_tolerable_mismatches_limit = 0,
                             const double rel_tolerable_mismatches_limit = 0);

496
497
498

std::array<size_t, 4> get_scale_tensor_dims(const size_t rows, const size_t cols,
                                            const size_t block_size_rows, const size_t block_size_cols);
Przemek Tredak's avatar
Przemek Tredak committed
499
500
501

std::pair<double, double> getTolerances(const DType type);

502
void fillUniform(Tensor *t);
503
504
505
506

template <typename InputEncoding>
void fillCase(Tensor *t, const InputsFillCase fill_case);

507
void setRandomScale(Tensor *t);
508
void setRandomScaleInv(Tensor *t);
Przemek Tredak's avatar
Przemek Tredak committed
509
510
511
512

constexpr int THREADS_PER_WARP = 32;

const std::string &typeName(DType type);
513
const std::string& caseName(InputsFillCase type);
Przemek Tredak's avatar
Przemek Tredak committed
514
515
516

extern std::vector<DType> all_fp_types;

517
bool isFp8Type(DType type);
518
bool isFp4Type(DType type);
519

520
int32_t getDeviceComputeCapability();
521
constexpr int32_t hopperComputeCapability = 90;
522
523
constexpr int32_t blackwellComputeCapability = 100;

Przemek Tredak's avatar
Przemek Tredak committed
524
525
}  // namespace test

// Expands to the `case DType::kFloat4E2M1:` arm of a type switch (binding `type`
// to fp4e2m1) when FP4 is supported, and to nothing otherwise.
#if FP4_TYPE_SUPPORTED
#define SWITCH_FP4_TYPE_HANDLE(type, ...) \
  case DType::kFloat4E2M1: { using type = fp4e2m1; { __VA_ARGS__ } } break;
#else
#define SWITCH_FP4_TYPE_HANDLE(type, ...)
#endif
// Runtime dispatch on `dtype`: in each case `type` names the matching C++
// element type and the trailing arguments are expanded as the case body.
// Any other dtype prints its numeric value and raises NVTE_ERROR.
#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \
    switch (dtype) { \
        using namespace transformer_engine; \
        case DType::kByte:       { using type = byte;    { __VA_ARGS__ } } break; \
        case DType::kInt32:      { using type = int32;   { __VA_ARGS__ } } break; \
        case DType::kInt64:      { using type = int64;   { __VA_ARGS__ } } break; \
        case DType::kFloat32:    { using type = float;   { __VA_ARGS__ } } break; \
        case DType::kFloat16:    { using type = fp16;    { __VA_ARGS__ } } break; \
        case DType::kBFloat16:   { using type = bf16;    { __VA_ARGS__ } } break; \
        case DType::kFloat8E4M3: { using type = fp8e4m3; { __VA_ARGS__ } } break; \
        case DType::kFloat8E5M2: { using type = fp8e5m2; { __VA_ARGS__ } } break; \
        case DType::kFloat8E8M0: { using type = fp8e8m0; { __VA_ARGS__ } } break; \
        SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \
        default: \
            printf("dtype: %d\n", static_cast<int>(dtype)); \
            NVTE_ERROR("Invalid type."); \
    }
// Runtime dispatch restricted to the two FP8 formats (E4M3, E5M2); any other
// dtype raises NVTE_ERROR.
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \
    switch (dtype) { \
        using namespace transformer_engine; \
        case DType::kFloat8E4M3: { using type = fp8e4m3; { __VA_ARGS__ } } break; \
        case DType::kFloat8E5M2: { using type = fp8e5m2; { __VA_ARGS__ } } break; \
        default: \
            NVTE_ERROR("Invalid type."); \
    }
// Runtime dispatch restricted to FP4 (E2M1). Uses SWITCH_FP4_TYPE_HANDLE, the
// FP4 case helper defined above (the previously referenced SWITCH_FP4_HANDLE
// does not exist). When FP4_TYPE_SUPPORTED is 0 that helper expands to nothing,
// so every dtype reaches the NVTE_ERROR branch.
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4_ONLY(dtype, type, ...) \
    switch (dtype) { \
        using namespace transformer_engine; \
        SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \
        default: \
            NVTE_ERROR("Invalid type."); \
    }

// Runtime dispatch over the non-quantized floating types: float32, float16 and
// (despite the macro name) bfloat16. Any other dtype raises NVTE_ERROR.
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \
    switch (dtype) { \
        using namespace transformer_engine; \
        case DType::kFloat32:  { using type = float; { __VA_ARGS__ } } break; \
        case DType::kFloat16:  { using type = fp16;  { __VA_ARGS__ } } break; \
        case DType::kBFloat16: { using type = bf16;  { __VA_ARGS__ } } break; \
        default: \
            NVTE_ERROR("Invalid type."); \
    }