/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_COMMON_H_
#define TRANSFORMER_ENGINE_COMMON_COMMON_H_

#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime_api.h>
#include <transformer_engine/transformer_engine.h>

#include <cstdint>
#include <functional>
#include <limits>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include "./nvtx.h"
#include "./util/cuda_driver.h"
#include "./util/logging.h"

namespace transformer_engine {

/*! \brief Human-readable name of a data type (defined out of line). */
std::string to_string(const DType type);

/*! \brief Human-readable name of a scaling mode (defined out of line). */
std::string to_string(const NVTEScalingMode &mode);

/*! \brief Whether the mode is treated as per-tensor scaling.
 *
 * Only NVTE_DELAYED_TENSOR_SCALING is classified as per-tensor here.
 */
inline bool is_tensor_scaling(const NVTEScalingMode &mode) {
  return mode == NVTE_DELAYED_TENSOR_SCALING;
}

/*! \brief Whether the mode is block scaling, i.e. any mode that is not per-tensor scaling. */
inline bool is_block_scaling(const NVTEScalingMode &mode) { return !is_tensor_scaling(mode); }

/*! \brief Whether the mode is delayed per-tensor scaling. */
inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) {
  return mode == NVTE_DELAYED_TENSOR_SCALING;
}

/*! \brief Whether the mode is MXFP8 1D block scaling. */
inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; }

/*! \brief Product of the entries shape[begin, end).
 *
 * Returns 1 for an empty range. An out-of-range or inverted [begin, end)
 * is reported through NVTE_CHECK.
 */
inline size_t product(const std::vector<size_t> &shape, const size_t begin, const size_t end) {
  NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ",
             end, " in a vector with ", shape.size(), " entries");
  size_t ret = 1;
  for (size_t i = begin; i < end; ++i) {
    ret *= shape[i];
  }
  return ret;
}

/*! \brief Product of all entries of shape (1 for an empty shape). */
inline size_t product(const std::vector<size_t> &shape) {
  size_t ret = 1;
  for (const auto &elem : shape) {
    ret *= elem;
  }
  return ret;
}

/*! \brief Plain description of a buffer: data pointer, shape and dtype.
 *
 * This struct does not allocate or free dptr.
 */
struct SimpleTensor {
  void *dptr;
  std::vector<size_t> shape;
  DType dtype;

  SimpleTensor(void *dptr, const std::vector<size_t> &shape, DType dtype)
      : dptr(dptr), shape(shape), dtype(dtype) {}

  SimpleTensor(const NVTEBasicTensor &tensor)  // NOLINT
      : dptr(tensor.data_ptr),
        shape(tensor.shape.data, tensor.shape.data + tensor.shape.ndim),
        dtype(static_cast<DType>(tensor.dtype)) {}

  SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {}

  // The returned NVTEShape is built from this->shape.data(), so (assuming
  // NVTEShape stores that pointer rather than copying) the result is only
  // valid while this object is alive and its shape is not modified.
  operator NVTEBasicTensor() const {
    const NVTEShape shape = {this->shape.data(), this->shape.size()};
    return {dptr, static_cast<NVTEDType>(dtype), shape};
  }

  /*! \brief Number of elements (1 for a 0-dim shape).
   *
   * Returned as size_t: the previous int return type silently truncated
   * element counts above INT_MAX.
   */
  size_t numel() const { return product(shape); }
};

/*! \brief Tensor with optional row-wise and column-wise data plus scaling metadata.
 *
 * Either `data` (row-wise) or `columnwise_data` may be absent; the logical
 * shape reported by shape() depends on which is present and on scaling_mode.
 */
struct Tensor {
  SimpleTensor data;
  SimpleTensor columnwise_data;
  SimpleTensor amax;
  SimpleTensor scale;
  SimpleTensor scale_inv;
  SimpleTensor columnwise_scale_inv;

 private:
  // Used as an allocation for nvte_tensor_shape
  // if the shape has to be inferred from columnwise data.
  mutable std::vector<size_t> rowwise_shape_cache;

 public:
  NVTEScalingMode scaling_mode;

  Tensor()
      : data(),
        columnwise_data(),
        amax(nullptr, {1}, DType::kFloat32),
        scale(nullptr, {1}, DType::kFloat32),
        scale_inv(nullptr, {1}, DType::kFloat32),
        columnwise_scale_inv(nullptr, {1}, DType::kFloat32),
        scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {}

  /*! \brief Number of elements of the logical shape returned by shape().
   *
   * Returned as size_t: the previous int return type silently truncated
   * element counts above INT_MAX.
   */
  size_t numel() const { return product(shape()); }

  bool has_data() const noexcept { return data.dptr != nullptr; }

  // Check for size (not just pointer) for 0-dim or no token cases.
  bool has_columnwise_data() const noexcept {
    return columnwise_data.dptr != nullptr || columnwise_data.shape.size() != 0;
  }

  /*! \brief Dtype of the row-wise data if present, else of the column-wise data. */
  DType dtype() const {
    if (has_data()) return data.dtype;
    if (has_columnwise_data()) return columnwise_data.dtype;
    // Fallback, used e.g. in workspace
    return data.dtype;
  }

  /*! \brief Logical (row-wise) shape of the tensor.
   *
   * When only column-wise data is present:
   *  - delayed tensor scaling and 1D/2D block scaling: the first dim of
   *    columnwise_data is moved to the end;
   *  - MXFP8 1D scaling: columnwise_data.shape is returned unchanged.
   * Otherwise data.shape is returned. Unknown scaling modes raise NVTE_ERROR.
   */
  std::vector<size_t> shape() const {
    /* Note: We sometimes experience spurious compiler errors
     * (-Wstringop-overflow) from this function. It appears that GCC
     * has some bugs with std::vector (see
     * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569).
     */
    switch (scaling_mode) {
      case NVTE_DELAYED_TENSOR_SCALING:
        if (!has_data() && has_columnwise_data()) {
          std::vector<size_t> ret;
          if (!columnwise_data.shape.empty()) {
            for (size_t i = 1; i < columnwise_data.shape.size(); i++) {
              ret.push_back(columnwise_data.shape[i]);
            }
            ret.push_back(columnwise_data.shape.front());
          }
          return ret;
        } else {
          return data.shape;
        }
        break;
      case NVTE_MXFP8_1D_SCALING:
        if (!has_data() && has_columnwise_data()) {
          return columnwise_data.shape;
        } else {
          return data.shape;
        }
        break;
      case NVTE_BLOCK_SCALING_1D:
      case NVTE_BLOCK_SCALING_2D: {
        if (!has_data() && has_columnwise_data()) {
          std::vector<size_t> shape;
          size_t ndim = columnwise_data.shape.size();
          shape.reserve(ndim);
          for (size_t i = 0; i + 1 < ndim; ++i) {
            shape.push_back(columnwise_data.shape[i + 1]);
          }
          if (ndim > 0) {
            shape.push_back(columnwise_data.shape[0]);
          }
          return shape;
        } else {
          // NOTE: We may have removed the data pointer from
          // data by setting usage. In that case, we return
          // the non-null shape. It is our best guess at the most
          // recent shape.
          return data.shape;
        }
        break;
      }
      default:
        NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode), "\"");
        return {};
    }
  }

  /*! \brief Reference to a cached copy of shape().
   *
   * The cache is only reassigned when the shape changes, so references (and
   * data pointers) from earlier calls stay valid while the shape is unchanged.
   */
  const std::vector<size_t> &rowwise_shape_ref() const {
    auto shape_queried = shape();
    // This method is primarily designed for nvte_shape.
    // An unfortunate consequence of unconditionally assigning
    // values to rowwise_shape_cache without a check is that
    // repeated calls to rowwise_shape_ref are likely to
    // invalidate the data pointers from previous calls.
    // If the shape has changed, then invalidating is necessary
    // in at least some cases, but we want to keep the data
    // valid otherwise.
    if (rowwise_shape_cache != shape_queried) {
      rowwise_shape_cache = std::move(shape_queried);
    }
    return rowwise_shape_cache;
  }

  /*! Matrix height after tensor is flattened to 2D
   *
   * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
   * as a (D1*D2*...*D(n-1), Dn) matrix.
   */
  size_t flat_first_dim() const {
    const auto &full_shape = shape();
    size_t ret = 1;
    if (!full_shape.empty()) {
      for (size_t i = 0; i < full_shape.size() - 1; i++) {
        ret *= full_shape[i];
      }
    }
    return ret;
  }

  /*! Matrix width after tensor is flattened to 2D
   *
   * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
   * as a (D1*D2*...*D(n-1), Dn) matrix.
   */
  size_t flat_last_dim() const {
    const auto &full_shape = shape();
    if (full_shape.empty()) {
      return 1;
    } else {
      return full_shape.back();
    }
  }
};

/*! \brief Options controlling quantization; attr_sizes lists the byte size of each field in order. */
struct QuantizationConfig {
  bool force_pow_2_scales = false;
  float amax_epsilon = 0.0f;
  NVTETensor noop_tensor = nullptr;

  static constexpr size_t attr_sizes[] = {
      sizeof(bool),       // force_pow_2_scales
      sizeof(float),      // amax_epsilon
      sizeof(NVTETensor)  // noop_tensor
  };
};

/*! \brief Ceiling division: smallest q with q * y >= x (for positive y). */
template <typename T>
constexpr T DIVUP(const T &x, const T &y) {
  return (((x) + ((y)-1)) / (y));
}

using byte = uint8_t;
using int32 = int32_t;
using int64 = int64_t;
using fp32 = float;
using fp16 = half;
using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
#if CUDA_VERSION >= 12080
using fp8e8m0 = __nv_fp8_e8m0;
#endif
using e8m0_t = uint8_t;

namespace detail {

template <typename T>
constexpr inline const char *type_name() noexcept;
#define TRANSFORMER_ENGINE_TYPE_NAME(T)                  \
  template <>                                            \
  inline constexpr const char *type_name<T>() noexcept { \
    return #T;                                           \
  }
TRANSFORMER_ENGINE_TYPE_NAME(uint8_t)
TRANSFORMER_ENGINE_TYPE_NAME(int32_t)
TRANSFORMER_ENGINE_TYPE_NAME(int64_t)
TRANSFORMER_ENGINE_TYPE_NAME(float)
TRANSFORMER_ENGINE_TYPE_NAME(half)
TRANSFORMER_ENGINE_TYPE_NAME(nv_bfloat16)
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e4m3)
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2)
#if CUDA_VERSION >= 12080
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0)
#endif
#undef TRANSFORMER_ENGINE_TYPE_NAME

// Largest finite value of each element type, expressed as float.
template <typename T>
struct TypeExtrema;

template <>
struct TypeExtrema<fp8e4m3> {
  static constexpr float max = 448.0f;
};

template <>
struct TypeExtrema<fp8e5m2> {
  static constexpr float max = 57344.0f;
};

template <>
struct TypeExtrema<bf16> {
  // Hex float format of 1.(7 bits of 1) * 2 ^ 127
  static constexpr float max = 0x1.FEp127;
};

template <>
struct TypeExtrema<fp16> {
  // Hex float format of 1.(10 bits of 1) * 2 ^ 15
  static constexpr float max = 0x1.FFCp15;
};

template <typename T>
struct TypeExtrema {
  static constexpr float max = std::numeric_limits<T>::max();
};

}  // namespace detail

/*! \brief Compile-time information about an element type T.
 *
 * dtype is the first DType whose position in `types` holds T (kNumTypes if
 * none matches); the lookup walks DType values in order starting at kByte.
 */
template <typename T>
struct TypeInfo {
  using types = std::tuple<byte, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2
#if CUDA_VERSION >= 12080
                           ,
                           fp8e8m0
#endif
                           >;

  template <typename U, DType current>
  struct Helper {
    constexpr static DType getType() {
      constexpr int i = static_cast<int>(current);
      if (std::is_same<U, typename std::tuple_element<i, types>::type>::value) {
        return current;
      } else {
        return Helper<U, static_cast<DType>(i + 1)>::getType();
      }
    }
  };

  template <typename U>
  struct Helper<U, DType::kNumTypes> {
    constexpr static DType getType() { return DType::kNumTypes; }
  };

  template <typename U>
  constexpr static DType getType() {
    return Helper<U, DType::kByte>::getType();
  }

  constexpr static DType dtype = getType<T>();
  constexpr static size_t size = sizeof(T);
  constexpr static float max_finite_value = detail::TypeExtrema<T>::max;
  constexpr static const char *name = detail::type_name<T>();
};

#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \
  switch (dtype) {                                           \
    using namespace transformer_engine;                      \
    case DType::kByte: {                                     \
      using type = unsigned char;                            \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kInt32: {                                    \
      using type = int32_t;                                  \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kInt64: {                                    \
      using type = int64_t;                                  \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kFloat32: {                                  \
      using type = float;                                    \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kFloat16: {                                  \
      using type = fp16;                                     \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kBFloat16: {                                 \
      using type = bf16;                                     \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kFloat8E4M3: {                               \
      using type = fp8e4m3;                                  \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kFloat8E5M2: {                               \
      using type = fp8e5m2;                                  \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    case DType::kFloat8E8M0: {                               \
      using type = byte;                                     \
      { __VA_ARGS__ }                                        \
    } break;                                                 \
    default:                                                 \
      NVTE_ERROR("Invalid type.");                           \
  }

#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) \
  switch (dtype) {                                              \
    using namespace transformer_engine;                         \
    case DType::kFloat32: {                                     \
      using type = float;                                       \
      { __VA_ARGS__ }                                           \
    } break;                                                    \
    case DType::kFloat16: {                                     \
      using type = fp16;                                        \
      { __VA_ARGS__ }                                           \
    } break;                                                    \
    case DType::kBFloat16: {                                    \
      using type = bf16;                                        \
      { __VA_ARGS__ }                                           \
    } break;                                                    \
    case DType::kFloat8E5M2: {                                  \
      using type = fp8e5m2;                                     \
      { __VA_ARGS__ }                                           \
    } break;                                                    \
    case DType::kFloat8E4M3: {                                  \
      using type = fp8e4m3;                                     \
      { __VA_ARGS__ }                                           \
    } break;                                                    \
    default:                                                    \
      NVTE_ERROR("Invalid type.");                              \
  }

#define TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, type, ...) \
  switch (dtype) {                                                   \
    using namespace transformer_engine;                              \
    case DType::kFloat32: {                                          \
      using type = float;                                            \
      { __VA_ARGS__ }                                                \
    } break;                                                         \
    case DType::kFloat16: {                                          \
      using type = fp16;                                             \
      { __VA_ARGS__ }                                                \
    } break;                                                         \
    case DType::kBFloat16: {                                         \
      using type = bf16;                                             \
      { __VA_ARGS__ }                                                \
    } break;                                                         \
    default:                                                         \
      NVTE_ERROR("Invalid type.");                                   \
  }

#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \
  switch (dtype) {                                               \
    using namespace transformer_engine;                          \
    case DType::kFloat8E5M2: {                                   \
      using type = fp8e5m2;                                      \
      { __VA_ARGS__ }                                            \
    } break;                                                     \
    case DType::kFloat8E4M3: {                                   \
      using type = fp8e4m3;                                      \
      { __VA_ARGS__ }                                            \
    } break;                                                     \
    default:                                                     \
      NVTE_ERROR("Invalid type.");                               \
  }

#define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \
  switch (dtype) {                                             \
    using namespace transformer_engine;                        \
    case DType::kFloat32: {                                    \
      using type = float;                                      \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kFloat16: {                                    \
      using type = fp16;                                       \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kBFloat16: {                                   \
      using type = bf16;                                       \
      { __VA_ARGS__ }                                          \
    } break;                                                   \
    case DType::kFloat8E5M2:                                   \
    case DType::kFloat8E4M3: {                                 \
      NVTE_ERROR("FP8 type not instantiated for input.");      \
    } break;                                                   \
    default:                                                   \
      NVTE_ERROR("Invalid type.");                             \
  }

#define TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(dtype, type, ...) \
  switch (dtype) {                                             \
    using namespace transformer_engine;                        \
    case DType::kFloat16: {                                    \
      using type = fp16;                                       \
      __VA_ARGS__;                                             \
      break;                                                   \
    }                                                          \
    case DType::kBFloat16: {                                   \
      using type = bf16;                                       \
      __VA_ARGS__;                                             \
      break;                                                   \
    }                                                          \
    default:                                                   \
      NVTE_ERROR("Invalid type for 16 bit.");                  \
  }

#define TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(SCALE_DIM, DIM, ...) \
  switch (SCALE_DIM) {                                              \
    case 1: {                                                       \
      constexpr size_t DIM = 1;                                     \
      { __VA_ARGS__ }                                               \
    } break;                                                        \
    case 32: {                                                      \
      constexpr size_t DIM = 32;                                    \
      { __VA_ARGS__ }                                               \
    } break;                                                        \
    default: {                                                      \
      NVTE_ERROR("Invalid size of the MX scaling factor.");         \
    }                                                               \
  }

#define TRANSFORMER_ENGINE_SWITCH_CONDITION(CONDITION, FLAG, ...) \
  if (CONDITION) {                                                \
    constexpr bool FLAG = true;                                   \
    { __VA_ARGS__ }                                               \
  } else {                                                        \
    constexpr bool FLAG = false;                                  \
    { __VA_ARGS__ }                                               \
  }

////////////////////////////////////////////////////////////////////////////////////////////////////

/*! \brief Smallest n >= 0 with 2^n >= value (returns 0 for value <= 1). */
inline int log2_ceil(int value) {
  int log2_value = 0;
  // Shift in 64 bits: with a 32-bit int, (1 << 31) overflows, so any
  // value > 2^30 previously hit undefined behaviour / looped forever.
  while ((int64_t{1} << log2_value) < value) ++log2_value;
  return log2_value;
}

/*! \brief Round x up to the next multiple of B. */
template <size_t B>
inline size_t alignTo(size_t x) {
  static_assert(B > 0, "alignTo: alignment B must be positive");
  size_t r = x % B;
  if (r == 0) return x;

  return x + B - r;
}

template <typename T>
struct is_fp8 : std::false_type {};

template <>
struct is_fp8<fp8e4m3> : std::true_type {};

template <>
struct is_fp8<fp8e5m2> : std::true_type {};

// [128,4] rowwise and [4,128] colwise alignment requirements for the tensor with scaling factors
constexpr size_t scale_tensor_alignment_X_rowwise = 4;
constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
constexpr size_t scale_tensor_alignment_X_colwise = 128;
constexpr size_t scale_tensor_alignment_Y_colwise = 4;

// Alignment requirements for the Tensor Memory Accelerator (TMA)
constexpr int TMA_gmem_alignment = 16;  // global memory address alignment

/*! \brief Whether ptr's address is a multiple of alignment (bytes). */
inline bool is_aligned_ptr(const void *ptr, size_t alignment) {
  return reinterpret_cast<uintptr_t>(ptr) % alignment == 0;
}

/*! \brief Whether the row-wise data pointer of t is aligned to alignment bytes. */
inline bool is_aligned_tensor_data(const Tensor &t, size_t alignment) {
  return is_aligned_ptr(static_cast<const void *>(t.data.dptr), alignment);
}

size_t typeToSize(const DType type);

void CheckNoopTensor(const Tensor &t, const std::string &name);
void CheckInputTensor(const Tensor &t, const std::string &name);
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false);

bool is_fp8_dtype(const DType t);

/*! \brief Update a tensor's FP8 scale-inverse
 *
 * The FP8 scale-inverse (dequantization scaling factor) is updated
 * with the reciprocal of the FP8 scale (quantization scaling factor).
 */
void update_tensor_scale_inv(Tensor *t, cudaStream_t stream);

#define NVTE_API_CALL(api_name) \
  transformer_engine::nvtx::NVTXWrapper _##api_name##_nvtx_wrapper(#api_name);

void checkCuDriverContext(CUstream stream);

CUtensorMapDataType get_CUtensorMapDataType(DType dtype);

// Set up parameters to create TMA descriptor.
void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
                          const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
                          const uint32_t shmemX, const uint32_t stride_elems,
                          const uint32_t offset_elems, const size_t type_size);

bool is_supported_by_CC_100();

}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_COMMON_H_