/************************************************************************* * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ #pragma once #include #include #include #include #ifndef __HIP_PLATFORM_AMD__ #include #endif #define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080) #ifdef __HIP_PLATFORM_AMD__ #include #endif #include #include #include #include #if FP4_TYPE_SUPPORTED #include #endif #include #include #include "util/logging.h" namespace test { using namespace transformer_engine; inline int blockwise_fp8_block_len() { const char *env = std::getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN"); if (env == nullptr || env[0] == '\0') { return 128; } int value; std::istringstream iss(env); iss >> value; NVTE_CHECK(iss, "Invalid environment variable value"); return value; } template struct BytesToType {}; template <> struct BytesToType<1> { using Type = uint8_t; }; template <> struct BytesToType<2> { using Type = uint16_t; }; template <> struct BytesToType<4> { using Type = uint32_t; }; template <> struct BytesToType<8> { using Type = uint64_t; }; using byte = uint8_t; using int16 = int16_t; using int32 = int32_t; using int64 = int64_t; using fp32 = float; using fp16 = half; using bf16 = nv_bfloat16; using fp8e4m3 = __nv_fp8_e4m3; using fp8e5m2 = __nv_fp8_e5m2; using fp8e8m0 = uint8_t; using int8 = int8_t; #if FP4_TYPE_SUPPORTED using fp4e2m1 = __nv_fp4_e2m1; using fp4e2m1x2 = __nv_fp4x2_e2m1; using fp4e2m1x4 = __nv_fp4x4_e2m1; #endif template struct BitsNumber; #if FP4_TYPE_SUPPORTED template <> struct BitsNumber { static constexpr size_t num_bits = 4; }; #endif template struct BitsNumber { static constexpr size_t num_bits = 8 * sizeof(T); }; template struct TypeInfo { #if FP4_TYPE_SUPPORTED using types = std::tuple; #else using types = std::tuple; #endif template struct Helper { constexpr static DType getType() { constexpr int i = static_cast(current); if constexpr (i >= std::tuple_size_v) { return DType::kNumTypes; } else if (std::is_same::type>::value) { return current; } else { return Helper(i + 1)>::getType(); } } }; template struct Helper { constexpr static DType getType() { return DType::kNumTypes; } }; template constexpr static DType getType() { return Helper::getType(); } constexpr static DType dtype = getType(); constexpr static size_t size = BitsNumber::num_bits;; }; class Tensor { public: Tensor(const std::string& name, const NVTEShape &shape, const DType type, const bool rowwise = true, const bool columnwise = false, const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING); Tensor(const std::string& name, const std::vector &shape, const DType type, const bool rowwise = true, const bool columnwise = false, const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) : Tensor(name, nvte_make_shape(shape.data(), shape.size()), type, rowwise, columnwise, mode) {} Tensor() {} Tensor& operator=(const Tensor &other) = delete; Tensor(const Tensor &other) = delete; Tensor(Tensor &&other) = default; Tensor& operator=(Tensor &&other) = default; ~Tensor() { void *data_ptr = tensor_.dptr(); void *scale_inv = tensor_.scale_inv(); void *columnwise_data_ptr = tensor_.get_columnwise_data().data_ptr; void *columnwise_scale_inv = tensor_.get_columnwise_scale_inv().data_ptr; if (columnwise_data_ptr == data_ptr) { columnwise_data_ptr = nullptr; } if (columnwise_scale_inv == scale_inv) { columnwise_scale_inv = nullptr; } if (data_ptr != nullptr) { cudaFree(data_ptr); } if (scale_inv != nullptr) { cudaFree(scale_inv); } if (columnwise_data_ptr != nullptr) { cudaFree(columnwise_data_ptr); } if (columnwise_scale_inv != nullptr) { cudaFree(columnwise_scale_inv); } } NVTETensor data() const noexcept { return tensor_.data(); } NVTEShape rowwise_shape() const noexcept { return tensor_.get_rowwise_data().shape; } NVTEShape columnwise_shape() const noexcept { return tensor_.get_columnwise_data().shape; } NVTEShape rowwise_scale_inv_shape() const { NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); return tensor_.get_rowwise_scale_inv().shape; } NVTEShape columnwise_scale_inv_shape() const { NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); return tensor_.get_columnwise_scale_inv().shape; } NVTEScalingMode scaling_mode() const noexcept { return tensor_.scaling_mode(); } DType dtype() const noexcept { return tensor_.dtype(); } void *rowwise_dptr() const { NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); return tensor_.get_rowwise_data().data_ptr; } void *columnwise_dptr() const { NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); return tensor_.get_columnwise_data().data_ptr; } template T *rowwise_cpu_dptr() const { NVTE_CHECK(TypeInfo::dtype == tensor_.dtype(), "Invalid type!"); NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); return reinterpret_cast(cpu_data_rowwise_.get()); } template T *columnwise_cpu_dptr() const { NVTE_CHECK(TypeInfo::dtype == tensor_.dtype(), "Invalid type!"); NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); return reinterpret_cast(cpu_data_columnwise_.get()); } float amax() const { if(amax_cpu_data_) { to_cpu(); return *amax_cpu_data_; } else { return 0; } } float scale() const { if(scale_cpu_data_) { NVTE_CHECK((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) || (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING), "Invalid scaling_mode!"); to_cpu(); return *scale_cpu_data_; } else { return 1; } } template T *rowwise_cpu_scale_inv_ptr(){ if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) { NVTE_CHECK(TypeInfo::dtype == DType::kFloat8E4M3, "Invalid type!"); } else { NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); } to_cpu(); return reinterpret_cast(rowwise_scale_inv_cpu_data_.get()); } template T *columnwise_cpu_scale_inv_ptr(){ if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) { NVTE_CHECK(TypeInfo::dtype == DType::kFloat8E4M3, "Invalid type!"); } else { NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); } to_cpu(); return reinterpret_cast(columnwise_scale_inv_cpu_data_.get()); } float rowwise_scale_inv(){ if(rowwise_scale_inv_cpu_data_) { float scale_inv = rowwise_cpu_scale_inv_ptr()[0]; return scale_inv; } else { return 1; } } bool rowwise() const { return rowwise_; } bool columnwise() const { return columnwise_; } void set_tensor_amax_nullptr(){ tensor_.set_amax(nullptr, DType::kFloat32, tensor_.defaultShape); } void to_cpu() const; void from_cpu() const; void set_scale(float scale); void set_scale_inv(float scale_inv); void shareFP8Meta(const Tensor &other); std::mt19937& gen() { return gen_; } private: TensorWrapper tensor_; std::unique_ptr cpu_data_rowwise_; std::unique_ptr cpu_data_columnwise_; std::shared_ptr amax_cpu_data_; std::shared_ptr scale_cpu_data_; std::unique_ptr rowwise_scale_inv_cpu_data_; std::unique_ptr columnwise_scale_inv_cpu_data_; bool rowwise_; bool columnwise_; std::string name_; std::mt19937 gen_; }; constexpr uint32_t FP32_EXPONENT_BIAS = 127; constexpr uint32_t FP32_MANTISSA_BITS = 23; // [128,4] rowwise and [4,128] colwise alignment requirement constexpr size_t scale_tensor_alignment_Y_rowwise = 128; constexpr size_t scale_tensor_alignment_X_rowwise = 4; constexpr size_t scale_tensor_alignment_Y_colwise = 4; constexpr size_t scale_tensor_alignment_X_colwise = 128; inline size_t divide_round_up(const size_t N, const size_t M) { return (N - 1 + M) / M; } inline size_t round_up_to_nearest_multiple(const size_t N, const size_t M) { return divide_round_up(N, M) * M; } template struct Numeric_Traits { static constexpr double minSubnorm = 1.0; static constexpr double maxSubnorm = 1.0; static constexpr double minNorm = 1.0; static constexpr double maxNorm = 1.0; static constexpr double artifInf = 1.0; static constexpr int maxBiasedExponent = 1; }; template <> struct Numeric_Traits { static constexpr double minSubnorm = 1.0 / static_cast(1 << 9); // std::pow(2.0, -9.0); static constexpr double maxSubnorm = 0.875 / static_cast(1 << 6); // std::pow(2.0, -6.0); static constexpr double minNorm = 1.0 / static_cast(1 << 6); // std::pow(2.0, -6.0); static constexpr double maxNorm = 448.0; static constexpr double artifInf = 10.0 * maxNorm; // artificial Infinity static constexpr int maxBiasedExponentAsFP32 = 8 + FP32_EXPONENT_BIAS; static constexpr int maxUnbiasedExponentAsFP32 = 8; static constexpr int maxExpNorm = 1 << maxUnbiasedExponentAsFP32; }; template <> struct Numeric_Traits { static constexpr double minSubnorm = 1.0 / static_cast(1 << 16); // std::pow(2.0, -16.0); static constexpr double maxSubnorm = 0.75 / static_cast(1 << 14); // std::pow(2.0, -14.0); static constexpr double minNorm = 1.0 / static_cast(1 << 14); // std::pow(2.0, -14.0); static constexpr double maxNorm = 57344.0; static constexpr double artifInf = 10.0 * maxNorm; // artificial Infinity static constexpr int maxBiasedExponentAsFP32 = 15 + FP32_EXPONENT_BIAS; static constexpr int maxUnbiasedExponentAsFP32 = 15; static constexpr int maxExpNorm = 1 << maxUnbiasedExponentAsFP32; }; template <> struct Numeric_Traits { static constexpr double minSubnorm = std::numeric_limits::denorm_min(); // std::pow(2.0, -149.0); static constexpr double maxSubnorm = std::numeric_limits::min() - std::numeric_limits::denorm_min(); // minNormalized - minDenormalized static constexpr double minNorm = std::numeric_limits::min(); // std::pow(2.0, -126.0); static constexpr double maxNorm = std::numeric_limits::max(); // (1 - pow(2, -24)) * pow(2, 128) static constexpr double artifInf = std::numeric_limits::infinity(); static constexpr int maxBiasedExponentAsFP32 = 255; static constexpr int maxUnbiasedExponentAsFP32 = 128; }; template struct Quantized_Limits { static constexpr double ranges[] = { 0.0, Numeric_Traits::minNorm, Numeric_Traits::maxNorm, Numeric_Traits::artifInf }; static constexpr inline fp32 max() { return static_cast(Numeric_Traits::maxNorm); } static constexpr inline fp32 max_reciprocal() { return static_cast(1.0 / max()); } static constexpr inline fp32 emax() { return static_cast(Numeric_Traits::maxExpNorm); } static constexpr inline fp32 emax_reciprocal() { return static_cast(1.0 / emax()); } static constexpr inline int max_norm_biased_exponent() { return Numeric_Traits::maxBiasedExponentAsFP32; } static constexpr inline int max_norm_unbiased_exponent() { return Numeric_Traits::maxUnbiasedExponentAsFP32; } }; // Input data filling cases // Considering normal and subnormal magnitudes of E4M3 and E5M2 formats // with nearest to even rounding per OFP8 specification enum InputsFillCase { zero_to_minNorm = 0, // [0, min_normal) minNorm_to_maxNorm = 1, // [min_normal, max_normal) maxNorm_to_inf = 2, // [max_normal, inf) zeros = 3, // {0} uniform = 4, // std::uniform_real_distribution<> dis(-2.0, 1.0) }; inline fp8e8m0 float_to_e8m0(float val) { // TODO: nan/inf needs to be set for any value // of nan/inf in input not just amax. if (std::isnan(val)) { return 0xFF; } if (std::isinf(val)) { return 0xFE; } if (val == 0.0f) { return 0x00; } uint32_t val_u32 = *reinterpret_cast(&val); fp8e8m0 exponent = (val_u32 >> FP32_MANTISSA_BITS); uint32_t mantissa = val_u32 & 0x7FFFFF; // Round up exponent and deal with satfinite. if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) { ++exponent; } return exponent; } inline float exp2f_rcp(fp8e8m0 biased_exp) { if (biased_exp == 0) { return 1.0f; } int32_t int_val = (254 - biased_exp) << FP32_MANTISSA_BITS; // 127 - (biased_exp - 127) float fp32_val = *reinterpret_cast(&int_val); return fp32_val; } inline float identity(const float x) { return x; } inline float gelu(const float x) { return x * (0.5f + 0.5f * tanhf(x * (0.79788456f + 0.03567741f * x * x))); } inline float dgelu(const float x) { const float tanh_out = tanhf(0.79788456f * x * (1 + 0.044715f * x * x)); return 0.5f * x * ((1 - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * x * x)) + 0.5f * (1 + tanh_out); } inline float sigmoid(const float x) { return 1 / (1 + expf(-x)); } inline float dsigmoid(const float x) { return sigmoid(x) * (1 - sigmoid(x)); } inline float qgelu(const float x) { return x * sigmoid(1.702f * x); } inline float dqgelu(const float x) { return 1.702f * x * dsigmoid(1.702f * x) + sigmoid(1.702f * x); } inline float relu(const float x) { return fmaxf(0, x); } inline float drelu(const float x) { return x > 0 ? 1 : 0; } inline float silu(const float x) { return x * sigmoid(x); } inline float dsilu(const float x) { return x * dsigmoid(x) + sigmoid(x); } inline float srelu(const float x) { return x > 0 ? x * x : 0; } inline float dsrelu(const float x) { return fmaxf(0, 2 * x); } size_t typeToNumBits(DType type); size_t product(const NVTEShape &shape); size_t product(const std::vector &shape); size_t bytes(const NVTEShape& shape, const DType type); size_t first_dimension(const std::vector &shape); size_t last_dimension(const std::vector &shape); bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2); void compareResults(const std::string &name, const Tensor &test, const void *ref, bool rowwise, double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true, const size_t tolerable_mismatches_limit = 0); void compareResults(const std::string &name, const float test, const float ref, double atol = 1e-5, double rtol = 1e-8); void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref, size_t N, float mismatch_rate_tol = 0.); template void compare_scaling_factors(const std::string &name, const T *test, const T *ref, const size_t row_blocks, const size_t col_blocks, const size_t stride, size_t& mismatches_num, const size_t scale_diff_abs_tolerance = 0, const double abs_tolerable_mismatches_limit = 0, const double rel_tolerable_mismatches_limit = 0); std::array get_scale_tensor_dims(const size_t rows, const size_t cols, const size_t block_size_rows, const size_t block_size_cols); std::pair getTolerances(const DType type); void fillUniform(Tensor *t); template void fillCase(Tensor *t, const InputsFillCase fill_case); void setRandomScale(Tensor *t); void setRandomScaleInv(Tensor *t); constexpr int THREADS_PER_WARP = 32; const std::string &typeName(DType type); const std::string& caseName(InputsFillCase type); extern std::vector all_fp_types; bool isFp8Type(DType type); bool isFp4Type(DType type); int32_t getDeviceComputeCapability(); constexpr int32_t hopperComputeCapability = 90; constexpr int32_t blackwellComputeCapability = 100; } // namespace test #if FP4_TYPE_SUPPORTED #define SWITCH_FP4_TYPE_HANDLE(type, ...) \ case DType::kFloat4E2M1: { \ using type = fp4e2m1; \ { __VA_ARGS__ } \ } break; #else #define SWITCH_FP4_TYPE_HANDLE(type, ...) // do nothing #endif #define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \ switch (dtype) { \ using namespace transformer_engine; \ case DType::kByte: \ { \ using type = byte; \ {__VA_ARGS__} \ } \ break; \ case DType::kInt32: \ { \ using type = int32; \ {__VA_ARGS__} \ } \ break; \ case DType::kInt64: \ { \ using type = int64; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat32: \ { \ using type = float; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat16: \ { \ using type = fp16; \ {__VA_ARGS__} \ } \ break; \ case DType::kBFloat16: \ { \ using type = bf16; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat8E4M3: \ { \ using type = fp8e4m3; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat8E5M2: \ { \ using type = fp8e5m2; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat8E8M0: \ { \ using type = fp8e8m0; \ {__VA_ARGS__} \ } \ break; \ SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \ default: \ printf("dtype: %d\n", static_cast(dtype)); \ NVTE_ERROR("Invalid type."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \ switch (dtype) { \ using namespace transformer_engine; \ case DType::kFloat8E4M3: \ { \ using type = fp8e4m3; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat8E5M2: \ { \ using type = fp8e5m2; \ {__VA_ARGS__} \ } \ break; \ default: \ NVTE_ERROR("Invalid type."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4_ONLY(dtype, type, ...) \ switch (dtype) { \ using namespace transformer_engine; \ SWITCH_FP4_HANDLE(type, __VA_ARGS__) \ default: \ NVTE_ERROR("Invalid type."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \ switch (dtype) { \ using namespace transformer_engine; \ case DType::kFloat32: \ { \ using type = float; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat16: \ { \ using type = fp16; \ {__VA_ARGS__} \ } \ break; \ case DType::kBFloat16: \ { \ using type = bf16; \ {__VA_ARGS__} \ } \ break; \ default: \ NVTE_ERROR("Invalid type."); \ }