Unverified Commit 8a7ab3dd authored by jberchtold-nvidia, committed by GitHub

[JAX] NVFP4 support in TE/JAX (#2254)


Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent e99be1b6
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime.h>
#include <iostream>
#include "../extensions.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/hadamard_transform.h"
#include "transformer_engine/recipe.h"
#include "transformer_engine/transformer_engine.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
namespace jax {
Error_Type RHTAmaxCalculationFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type amax_buf,
Result_Type post_rht_amax_buf,
int64_t rht_matrix_random_sign_mask_t, bool produce_regular_amax,
int64_t flatten_axis) {
NVTE_CHECK(input_buf.untyped_data() != nullptr,
"Input must be provided for RHT Amax calculation");
NVTE_CHECK(convert_ffi_datatype_to_te_dtype(input_buf.element_type()) == DType::kBFloat16,
"Input must be of type bfloat16 for RHT Amax calculation");
NVTE_CHECK(flatten_axis > 0 && flatten_axis < static_cast<int64_t>(input_buf.dimensions().size()),
"Flatten axis is out of bounds");
TensorWrapper input_tensor(input_buf.untyped_data(),
std::vector<size_t>{product(input_buf.dimensions(), 0, flatten_axis),
product(input_buf.dimensions(), flatten_axis,
input_buf.dimensions().size())},
convert_ffi_datatype_to_te_dtype(input_buf.element_type()));
float *amax_out = nullptr;
if (produce_regular_amax) {
amax_out = reinterpret_cast<float *>(amax_buf->untyped_data());
NVTE_CHECK(amax_out != nullptr, "Amax output must be provided for RHT Amax calculation");
NVTE_CHECK(convert_ffi_datatype_to_te_dtype(amax_buf->element_type()) == DType::kFloat32,
"Amax output must be of type float32 for RHT Amax calculation");
NVTE_CHECK(amax_buf->dimensions().size() == 1 && amax_buf->dimensions()[0] == 1,
"Amax output must be a single float for RHT Amax calculation");
}
float *post_rht_amax_out = reinterpret_cast<float *>(post_rht_amax_buf->untyped_data());
NVTE_CHECK(post_rht_amax_out != nullptr,
"Post-RHT Amax output must be provided for RHT Amax calculation");
NVTE_CHECK(convert_ffi_datatype_to_te_dtype(post_rht_amax_buf->element_type()) == DType::kFloat32,
"Post-RHT Amax output must be of type float32 for RHT Amax calculation");
NVTE_CHECK(post_rht_amax_buf->dimensions().size() == 1 && post_rht_amax_buf->dimensions()[0] == 1,
"Post-RHT Amax output must be a single float for RHT Amax calculation");
TensorWrapper out_tensor{};
out_tensor.set_amax(amax_out, DType::kFloat32, std::vector<size_t>{1});
out_tensor.set_columnwise_amax(post_rht_amax_out, DType::kFloat32, std::vector<size_t>{1});
// Zeroing of amaxes is handled by TE common inside nvte_hadamard_transform_amax
nvte_hadamard_transform_amax(input_tensor.data(), out_tensor.data(),
0, // Regular amax for rowwise does not apply RHT so mask is 0
rht_matrix_random_sign_mask_t, stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(
RHTAmaxCalculationHandler, RHTAmaxCalculationFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // post_rht_amax
.Attr<int64_t>("rht_matrix_random_sign_mask_t") // rht_matrix_random_sign_mask_t
.Attr<bool>("produce_regular_amax") // produce_regular_amax
.Attr<int64_t>("flatten_axis"), // flatten_axis
FFI_CudaGraph_Traits);
Error_Type RHTAmaxCalculationInitializeFFI(cudaStream_t stream, Buffer_Type input_buf,
Result_Type amax_buf, Result_Type post_rht_amax_buf,
int64_t rht_matrix_random_sign_mask_t,
bool produce_regular_amax, int64_t flatten_axis) {
return wrapInStreamCapture(std::function(RHTAmaxCalculationFFI), stream, input_buf, amax_buf,
post_rht_amax_buf, rht_matrix_random_sign_mask_t, produce_regular_amax,
flatten_axis);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(
RHTAmaxCalculationInitializeHandler, RHTAmaxCalculationInitializeFFI,
FFI::Bind<FFI_Initialize>()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // post_rht_amax
.Attr<int64_t>("rht_matrix_random_sign_mask_t") // rht_matrix_random_sign_mask_t
.Attr<bool>("produce_regular_amax") // produce_regular_amax
.Attr<int64_t>("flatten_axis")); // flatten_axis
} // namespace jax
} // namespace transformer_engine
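For context, the handler pair above is exported to Python later in this PR through the Registrations() dictionary under the name te_rht_amax_ffi. Below is a hypothetical sketch of how such a target could be invoked through JAX's XLA FFI bindings; the buffer order and attribute names mirror the FFI::Bind() signature above, but the wrapper name and output shapes are illustrative, not the actual TE/JAX primitive code:

import jax
import jax.numpy as jnp

# Hypothetical wrapper around the "te_rht_amax_ffi" target; TE/JAX registers the
# handlers itself, so this only illustrates the calling convention.
def rht_amax(x: jax.Array, sign_mask: int, produce_regular_amax: bool = True,
             flatten_axis: int = 1):
    out_types = (
        jax.ShapeDtypeStruct((1,), jnp.float32),  # amax (used if produce_regular_amax)
        jax.ShapeDtypeStruct((1,), jnp.float32),  # post_rht_amax
    )
    return jax.ffi.ffi_call("te_rht_amax_ffi", out_types)(
        x,  # must be bfloat16, per the NVTE_CHECK above
        rht_matrix_random_sign_mask_t=sign_mask,
        produce_regular_amax=produce_regular_amax,
        flatten_axis=flatten_axis,
    )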
@@ -41,6 +41,9 @@ DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) {
     case xla::ffi::DataType::F8E8M0FNU:
       return DType::kFloat8E8M0;
       break;
+    case xla::ffi::DataType::F4E2M1FN:
+      return DType::kFloat4E2M1;
+      break;
     default:
       auto type_num = static_cast<XLA_FFI_DataType>(type);
       NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d",
......
@@ -102,6 +102,8 @@ inline static size_t te_dtype_bytes(const DType& type) {
       return 1;
     case DType::kFloat8E8M0:
       return 1;
+    case DType::kFloat4E2M1:
+      return 1;
     default:
       NVTE_ERROR("Unsupported DType: ", static_cast<int>(type));
   }
......
@@ -51,7 +51,8 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
   // Set scaling factor for quantized tensors
   if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) {
-    NVTE_CHECK(typeToSize(input_dtype) == 1, "Quantized GEMM requires 8-bit operands.");
+    NVTE_CHECK(is_nvfp4_scaling(scaling_mode) || typeToSize(input_dtype) == 1,
+               "Quantized GEMM requires 4-bit or 8-bit operands.");
     NVTE_CHECK(scale_inv.element_count() > 0, "Missing inverse scaling factor for quantized GEMM.");
     std::vector<size_t> scale_shape = {1};
@@ -74,7 +75,8 @@
 Error_Type CollectiveGemmInitFFI(Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
                                  Buffer_Type rhs_scale_inv, Buffer_Type bias,
-                                 Buffer_Type gelu_input, Result_Type output, Result_Type bias_grad,
+                                 Buffer_Type gelu_input, Buffer_Type alpha, Buffer_Type beta,
+                                 Result_Type output, Result_Type bias_grad,
                                  Result_Type pre_gelu_out, Result_Type workspace,
                                  JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
                                  int64_t rhs_axis_boundary, bool lhs_transposed,
@@ -119,6 +121,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CollectiveGemmInitHandler, CollectiveGemmInitFFI,
                                   .Arg<Buffer_Type>()  // rhs_scale_inv
                                   .Arg<Buffer_Type>()  // bias
                                   .Arg<Buffer_Type>()  // gelu_input
+                                  .Arg<Buffer_Type>()  // alpha
+                                  .Arg<Buffer_Type>()  // beta
                                   .Ret<Buffer_Type>()  // output
                                   .Ret<Buffer_Type>()  // bias_grad
                                   .Ret<Buffer_Type>()  // pre_gelu_out
@@ -136,11 +140,11 @@
 Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
                    Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input,
-                   Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out,
-                   Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
-                   int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed,
-                   bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator,
-                   JAXX_Collective_Op collective_op) {
+                   Buffer_Type alpha, Buffer_Type beta, Result_Type output, Result_Type bias_grad,
+                   Result_Type pre_gelu_out, Result_Type workspace, JAXX_Scaling_Mode scaling_mode,
+                   int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed,
+                   bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad,
+                   bool use_split_accumulator, JAXX_Collective_Op collective_op) {
   // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when
   // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x)
   bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
@@ -192,10 +196,31 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv,
   workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr);
   std::vector<size_t> workspace_shape = {static_cast<size_t>(workspace->element_count()) - 256};
   auto workspace_ = TensorWrapper(workspace_ptr, workspace_shape, DType::kByte);

-  // Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order
   auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
+
+  float one = 1.;
+  float zero = 0.;
+  // alpha, beta
+  float *alpha_ptr = &one, *beta_ptr = &zero;
+  if (is_nvfp4_scaling(scaling_mode)) {
+    NVTE_CHECK(alpha.element_count() == 1 &&
+               convert_ffi_datatype_to_te_dtype(alpha.element_type()) == DType::kFloat32);
+    alpha_ptr = reinterpret_cast<float *>(alpha.untyped_data());
+    NVTE_CHECK(beta.element_count() == 1 &&
+               convert_ffi_datatype_to_te_dtype(beta.element_type()) == DType::kFloat32);
+    beta_ptr = reinterpret_cast<float *>(beta.untyped_data());
+  }
+
+  // Construct GEMM config
+  transformer_engine::MatmulConfigWrapper config;
+  config.set_use_split_accumulator(use_split_accumulator);
+  config.set_sm_count(num_math_sm);
+  if (fuse_bias) config.set_bias_tensor(bias_.data());
+  if (fuse_gelu) {
+    config.set_with_gelu_epilogue(true);
+    config.set_epilogue_aux_tensor(pre_gelu_.data());
+  }
+
   if (collective_op == JAXX_Collective_Op::NONE) {
     auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype);
     NVTE_CHECK(out_.numel() == output->element_count(),
@@ -205,9 +230,10 @@
     NVTE_CHECK(!fuse_bias || bias_size == out_shape[1], "bias_size=", bias_size,
                ", out_shape[1]=", out_shape[1]);
-    nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(),
-                     rhs_transposed, lhs_transposed, grad, workspace_.data(), false,
-                     use_split_accumulator, num_math_sm, stream);
+    // Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order
+    nvte_cublas_gemm_v2(rhs_transposed /*transa*/, lhs_transposed /*transb*/, alpha_ptr,
+                        rhs_.data() /*A*/, lhs_.data() /*B*/, beta_ptr, out_.data() /*C*/,
+                        out_.data() /*D*/, workspace_.data(), config, stream);
   } else {
     std::vector<size_t> buffer_shape{0, 0};
     DType buffer_dtype = out_dtype;
@@ -268,6 +294,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI,
                                   .Arg<Buffer_Type>()  // rhs_scale_inv
                                   .Arg<Buffer_Type>()  // bias
                                   .Arg<Buffer_Type>()  // gelu_input
+                                  .Arg<Buffer_Type>()  // alpha
+                                  .Arg<Buffer_Type>()  // beta
                                   .Ret<Buffer_Type>()  // output
                                   .Ret<Buffer_Type>()  // bias_grad
                                   .Ret<Buffer_Type>()  // pre_gelu_out
@@ -599,9 +627,9 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data,
     // point to swizzled scale_inv data (store on workspace, only used for GEMM).
     // Note: even if is_empty_gemm is true, sinv are still non-empty, need to move the pointers
     auto lhs_sinv_shape_i =
-        get_mxfp8_scale_shape(lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise);
+        get_block_scale_shape(scaling_mode, lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise);
     auto rhs_sinv_shape_i =
-        get_mxfp8_scale_shape(rhs_shape_i[0], rhs_shape_i[1], rhs_use_colwise);
+        get_block_scale_shape(scaling_mode, rhs_shape_i[0], rhs_shape_i[1], rhs_use_colwise);
     lhs_sinv_size_i = lhs_sinv_shape_i[0] * lhs_sinv_shape_i[1];
     rhs_sinv_size_i = rhs_sinv_shape_i[0] * rhs_sinv_shape_i[1];
     if (lhs_use_colwise) {
......
@@ -26,11 +26,21 @@ std::vector<size_t> Shape::to_vector() const {
   return shape;
 }

-std::vector<size_t> get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise) {
-  auto block_x = is_colwise ? MXFP8_BLOCK_SIZE.y : MXFP8_BLOCK_SIZE.x;
-  auto block_y = is_colwise ? MXFP8_BLOCK_SIZE.x : MXFP8_BLOCK_SIZE.y;
-  auto alignment_x = is_colwise ? MXFP8_ALIGNMENT.y : MXFP8_ALIGNMENT.x;
-  auto alignment_y = is_colwise ? MXFP8_ALIGNMENT.x : MXFP8_ALIGNMENT.y;
+std::vector<size_t> get_block_scale_shape(JAXX_Scaling_Mode scaling_mode, size_t M, size_t N,
+                                          bool is_colwise) {
+  auto block_size = BLOCK_SIZE(1, 1);
+  if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
+    block_size = MXFP8_BLOCK_SIZE;
+  } else if (scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING ||
+             scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING) {
+    block_size = NVFP4_BLOCK_SIZE;
+  } else {
+    NVTE_ERROR("Unsupported scaling_mode = ", static_cast<int>(scaling_mode));
+  }
+  auto block_x = is_colwise ? block_size.y : block_size.x;
+  auto block_y = is_colwise ? block_size.x : block_size.y;
+  auto alignment_x = is_colwise ? BLOCK_SCALE_ALIGNMENT.y : BLOCK_SCALE_ALIGNMENT.x;
+  auto alignment_y = is_colwise ? BLOCK_SCALE_ALIGNMENT.x : BLOCK_SCALE_ALIGNMENT.y;
   NVTE_CHECK(M % block_x == 0, "M must be divisble by %zu (got %zu)", block_x, M);
   NVTE_CHECK(N % block_y == 0, "N must be divisble by %zu (got %zu)", block_y, N);
......
@@ -45,6 +45,8 @@ enum class JAXX_Scaling_Mode : int64_t {
   DELAYED_TENSOR_SCALING = 1,
   MXFP8_1D_SCALING = 2,
   CURRENT_TENSOR_SCALING = 3,
+  NVFP4_1D_SCALING = 4,
+  NVFP4_2D_SCALING = 5,
 };

 inline bool is_tensor_scaling(const JAXX_Scaling_Mode &mode) {
@@ -56,6 +58,11 @@ inline bool is_block_scaling(const JAXX_Scaling_Mode &mode) {
   return (mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING);
 }

+inline bool is_nvfp4_scaling(const JAXX_Scaling_Mode &mode) {
+  return (mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING ||
+          mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING);
+}
+
 static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
   switch (mode) {
     case JAXX_Scaling_Mode::NO_SCALING:
@@ -70,22 +77,32 @@ static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
     case JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING:
       return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING;
       break;
+    case JAXX_Scaling_Mode::NVFP4_1D_SCALING:
+      return NVTEScalingMode::NVTE_NVFP4_1D_SCALING;
+      break;
+    case JAXX_Scaling_Mode::NVFP4_2D_SCALING:
+      // TE common uses the same enum value for 1D and 2D fp4 scaling and instead
+      // differentiates them via quant_config.nvfp4_2d_quantization
+      return NVTEScalingMode::NVTE_NVFP4_1D_SCALING;
+      break;
     default:
       NVTE_ERROR("Invalid Scaling Mode ", static_cast<int>(mode));
       break;
   }
 }

-constexpr struct BlockSize {
+struct BLOCK_SIZE {
   size_t x;
   size_t y;
-} MXFP8_BLOCK_SIZE{1, 32};
-constexpr struct Alignment {
-  size_t x;
-  size_t y;
-} MXFP8_ALIGNMENT{128, 4};
+  constexpr BLOCK_SIZE(int _x, int _y) : x(_x), y(_y) {}
+};
+
+constexpr BLOCK_SIZE MXFP8_BLOCK_SIZE{1, 32};
+constexpr BLOCK_SIZE NVFP4_BLOCK_SIZE{1, 16};
+constexpr BLOCK_SIZE BLOCK_SCALE_ALIGNMENT{128, 4};

-std::vector<size_t> get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise);
+std::vector<size_t> get_block_scale_shape(JAXX_Scaling_Mode scaling_mode, size_t M, size_t N,
+                                          bool is_colwise);

 template <typename T, typename... Rest>
 void hash_combine(int64_t &seed, const T &v, Rest... rest) {
......
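For intuition on the constants above: MXFP8 keeps one scale per 1x32 block and NVFP4 one per 1x16 block, so the unpadded rowwise scale shape is (M, N / block). A minimal Python sketch of that shape rule follows (hypothetical helper name; the real get_block_scale_shape additionally rounds each dimension up to BLOCK_SCALE_ALIGNMENT):

# Hypothetical mirror of get_block_scale_shape for the unpadded case only.
MXFP8_BLOCK = (1, 32)
NVFP4_BLOCK = (1, 16)

def unpadded_scale_shape(block, M, N, is_colwise):
    bx, by = (block[1], block[0]) if is_colwise else block  # swap blocks for colwise
    assert M % bx == 0 and N % by == 0, "dims must be divisible by the block size"
    return (M // bx, N // by)

assert unpadded_scale_shape(MXFP8_BLOCK, 128, 64, is_colwise=False) == (128, 2)
assert unpadded_scale_shape(NVFP4_BLOCK, 128, 64, is_colwise=False) == (128, 4)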
@@ -76,6 +76,11 @@ pybind11::dict Registrations() {
       pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
                      pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler));

+  // Amax
+  dict["te_rht_amax_ffi"] = pybind11::dict(
+      pybind11::arg("initialize") = EncapsulateFFI(RHTAmaxCalculationInitializeHandler),
+      pybind11::arg("execute") = EncapsulateFFI(RHTAmaxCalculationHandler));
+
   return dict;
 }
@@ -106,7 +111,9 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
       .value("kFloat16", DType::kFloat16)
       .value("kBFloat16", DType::kBFloat16)
       .value("kFloat8E4M3", DType::kFloat8E4M3)
-      .value("kFloat8E5M2", DType::kFloat8E5M2);
+      .value("kFloat8E5M2", DType::kFloat8E5M2)
+      .value("kFloat8E8M0", DType::kFloat8E8M0)
+      .value("kFloat4E2M1", DType::kFloat4E2M1);

   pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local())
       .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS)
@@ -165,6 +172,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
       .value("DELAYED_TENSOR_SCALING", JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
       .value("MXFP8_1D_SCALING", JAXX_Scaling_Mode::MXFP8_1D_SCALING)
       .value("CURRENT_TENSOR_SCALING", JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING)
+      .value("NVFP4_1D_SCALING", JAXX_Scaling_Mode::NVFP4_1D_SCALING)
+      .value("NVFP4_2D_SCALING", JAXX_Scaling_Mode::NVFP4_2D_SCALING)
       .export_values();

   pybind11::enum_<transformer_engine::jax::QuantizeLayout>(m, "QuantizeLayout",
......
@@ -5,8 +5,11 @@
  ************************************************************************/
 #include <cuda_runtime.h>
+#include <iostream>

 #include "../extensions.h"
 #include "transformer_engine/cast.h"
+#include "transformer_engine/hadamard_transform.h"
 #include "transformer_engine/recipe.h"
 #include "transformer_engine/transformer_engine.h"
 #include "xla/ffi/api/c_api.h"
@@ -15,7 +18,7 @@ namespace transformer_engine {
 namespace jax {

 pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
-                                               DType in_dtype, DType out_dtype,
+                                               DType in_dtype, DType out_dtype, DType scale_dtype,
                                                JAXX_Scaling_Mode scaling_mode,
                                                QuantizeLayout q_layout) {
   auto input_shape = std::vector<size_t>{batch_size, hidden_size};
@@ -30,16 +33,22 @@
   // this function. We pass a dummy pointer as a workaround.
   int temp = 0;

+  bool const is_nvfp4 = scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING ||
+                        scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING;
+
   auto input_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), input_shape, in_dtype);
   auto dbias_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), dbias_shape, in_dtype);
   auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
+  auto scale_shape = std::vector<size_t>{1};

   // Only the pointers will be checked for scale_inv, thus the shapes do not matter
   if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::ROWWISE) {
     output_tensor.set_rowwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_shape);
-    if (is_fp8_dtype(out_dtype)) {
-      output_tensor.set_rowwise_scale_inv(reinterpret_cast<void *>(&temp), DType::kFloat32,
-                                          std::vector<size_t>{1});
+    if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) {
+      if (is_nvfp4)
+        scale_shape = get_block_scale_shape(scaling_mode, batch_size, hidden_size, false);
+      output_tensor.set_rowwise_scale_inv(reinterpret_cast<void *>(&temp), scale_dtype,
+                                          scale_shape);
     }
   }
@@ -49,13 +58,16 @@
     output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);

     // Only the pointers will be checked for scale_inv, thus the shapes do not matter
-    if (is_fp8_dtype(out_dtype)) {
-      output_tensor.set_columnwise_scale_inv(reinterpret_cast<void *>(&temp), DType::kFloat32,
-                                             std::vector<size_t>{1});
+    if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) {
+      if (is_nvfp4)
+        scale_shape =
+            get_block_scale_shape(scaling_mode, hidden_size, batch_size, false);  //Transpose
+      output_tensor.set_columnwise_scale_inv(reinterpret_cast<void *>(&temp), scale_dtype,
+                                             scale_shape);
     }
   }

-  if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
+  if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || is_nvfp4) {
     output_tensor.set_amax(reinterpret_cast<void *>(&temp), DType::kFloat32,
                            std::vector<size_t>{1});
     output_tensor.set_scale(reinterpret_cast<void *>(&temp), DType::kFloat32,
@@ -72,17 +84,20 @@
 }

 Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
-                            Buffer_Type amax_buf, Result_Type output_buf,
-                            Result_Type output_trans_buf, Result_Type scale_inv_buf,
-                            Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
-                            Result_Type dbias_buf, Result_Type workspace_buf,
-                            JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum,
-                            bool is_dbias, int64_t flatten_axis) {
+                            Buffer_Type amax_buf, Buffer_Type sr_rng_state,
+                            Buffer_Type post_rht_amax_buf, Buffer_Type rht_matrix_buf,
+                            Result_Type output_buf, Result_Type output_trans_buf,
+                            Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
+                            Result_Type updated_amax_buf, Result_Type dbias_buf,
+                            Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
+                            int64_t quantize_layout_enum, bool is_dbias, int64_t flatten_axis,
+                            bool stochastic_rounding, bool use_rht) {
   auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
   auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
   auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());

-  NVTE_CHECK(is_fp8_dtype(out_dtype), "Output datatype must be FP8 for quantization.");
+  NVTE_CHECK(is_fp8_dtype(out_dtype) || is_fp4_dtype(out_dtype),
+             "Output datatype must be FP8 or FP4 for quantization.");

   auto *input = input_buf.untyped_data();
@@ -112,25 +127,27 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
   bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
                                  scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
+  bool const is_mxfp8 = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
+  bool const is_nvfp4 = scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING ||
+                        scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING;
+
+  NVTE_CHECK(!stochastic_rounding || is_nvfp4, "Stochastic rounding is only supported for NVFP4.");
+  NVTE_CHECK(!use_rht || is_nvfp4, "RHT is only supported for NVFP4 scaling");

   if (quantize_layout == QuantizeLayout::ROWWISE ||
       quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
     output_tensor.set_rowwise_data(output, out_dtype, output_shape);
-    if (is_fp8_dtype(out_dtype)) {
     if (is_tensor_scaling) {
       float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
-      float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
-      float *updated_amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
+      float *amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
       NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
-      NVTE_CHECK(amax == updated_amax && amax != nullptr,
-                 "amax must be provided for delayed tensor scaling");
+      NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
       output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
       output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
       output_tensor.set_rowwise_scale_inv(
           scale_inv_buf->untyped_data(),
-          convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
-          std::vector<size_t>{1});
+          convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
     } else {
       output_tensor.set_rowwise_scale_inv(
           scale_inv_buf->untyped_data(),
@@ -140,13 +157,76 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
               scale_inv_buf->dimensions().size())});
     }
-    }
   }

+  if (is_nvfp4) {
+    float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
+    NVTE_CHECK(amax != nullptr, "amax must be provided for NVFP4");
+    output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
+  }
+
+  QuantizationConfigWrapper quant_config{};
+  if (scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING) {
+    quant_config.set_nvfp4_2d_quantization(true);
+  }
+
+  // Stochastic rounding
+  quant_config.set_stochastic_rounding(stochastic_rounding);
+  TensorWrapper sr_rng_state_tensor(sr_rng_state.untyped_data(), std::vector<size_t>{2},
+                                    DType::kInt64);
+  if (stochastic_rounding) {
+    NVTE_CHECK(sr_rng_state.size_bytes() == 2 * sizeof(uint64_t),
+               "rng_state must be of type int64[2]");
+    NVTE_CHECK(sr_rng_state.untyped_data() != nullptr, "rng_state must be provided for SR");
+    quant_config.set_rng_state(sr_rng_state_tensor.data());
+  }

   if (quantize_layout == QuantizeLayout::COLWISE ||
       quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
-    auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
-                          ? output_trans_shape
-                          : output_shape;
+    if (is_nvfp4 && use_rht) {
+      if (quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
+        // Do regular rowwise quantization without RHT
+        nvte_quantize_v2(input_tensor.data(), output_tensor.data(), quant_config, stream);
+      }
+
+      TensorWrapper out_transpose(get_nvte_scaling_mode(scaling_mode));
+      // nvte_hadamard_transform_cast_fusion_columnwise expects the colwise data to be
+      // populated in the rowwise buffers on TensorWrapper
+      out_transpose.set_rowwise_data(output_trans, out_dtype, output_trans_shape);
+      auto const colwise_flatten_axis = output_trans_buf->dimensions().size() - flatten_axis;
+      out_transpose.set_rowwise_scale_inv(
+          colwise_scale_inv_buf->untyped_data(),
+          convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
+          std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0, colwise_flatten_axis),
+                              product(colwise_scale_inv_buf->dimensions(), colwise_flatten_axis,
                                      colwise_scale_inv_buf->dimensions().size())});
+
+      float *post_rht_amax = reinterpret_cast<float *>(post_rht_amax_buf.untyped_data());
+      NVTE_CHECK(post_rht_amax != nullptr, "Post-RHT colwise amax must be provided for NVFP4");
+      out_transpose.set_amax(post_rht_amax, DType::kFloat32, std::vector<size_t>{1});
+
+      bool const eligible_for_rht_cast_fusion =
+          input_tensor.dtype() == DType::kBFloat16 && m % 64 == 0 && n % 128 == 0;
+      NVTE_CHECK(eligible_for_rht_cast_fusion, "RHT cast fusion conditions not met");
+      NVTE_CHECK(
+          convert_ffi_datatype_to_te_dtype(rht_matrix_buf.element_type()) == DType::kBFloat16,
+          "RHT matrix must be bf16");
+      NVTE_CHECK(rht_matrix_buf.dimensions().size() == 2 && rht_matrix_buf.dimensions()[0] == 16 &&
+                     rht_matrix_buf.dimensions()[1] == 16,
+                 "RHT matrix must be 16x16");
+      TensorWrapper rht_matrix_tensor(rht_matrix_buf.untyped_data(), std::vector<size_t>{16, 16},
+                                      DType::kBFloat16);
+      nvte_hadamard_transform_cast_fusion_columnwise(input_tensor.data(), out_transpose.data(),
+                                                     rht_matrix_tensor.data(), quant_config,
+                                                     stream);
+      return ffi_with_cuda_error_check();
+    }
+
+    bool const is_colwise_transposed =
+        scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || is_nvfp4;
+    auto &tmp_shape = is_colwise_transposed ? output_trans_shape : output_shape;
     output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape);

     // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
     auto &tmp_buf = is_tensor_scaling ? scale_inv_buf : colwise_scale_inv_buf;
@@ -156,26 +236,30 @@
           tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
           std::vector<size_t>{1});
     } else {
+      auto colwise_flatten_axis = flatten_axis;
+      if (is_colwise_transposed) {
+        // convert flatten_axis from N layout to T layout
+        colwise_flatten_axis = tmp_buf->dimensions().size() - flatten_axis;
+      }
       output_tensor.set_columnwise_scale_inv(
           tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
           std::vector<size_t>{
-              product(tmp_buf->dimensions(), 0, flatten_axis),
-              product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
+              product(tmp_buf->dimensions(), 0, colwise_flatten_axis),
+              product(tmp_buf->dimensions(), colwise_flatten_axis, tmp_buf->dimensions().size())});
     }
   }

-  if (scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) {
-    output_tensor.set_amax(nullptr, DType::kFloat32, std::vector<size_t>{1});
-  }

   auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
   auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);

   if (is_dbias) {
+    NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::NVFP4_2D_SCALING,
+               "DBias quantization is not supported for NVFP4_2D_SCALING as fused dbias API cannot "
+               "take quant_config as input.");
     nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(),
                         workspace_tensor.data(), stream);
   } else {
-    nvte_quantize(input_tensor.data(), output_tensor.data(), stream);
+    nvte_quantize_v2(input_tensor.data(), output_tensor.data(), quant_config, stream);
   }

   return ffi_with_cuda_error_check();
 }
@@ -186,6 +270,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
                                   .Arg<Buffer_Type>()  // input
                                   .Arg<Buffer_Type>()  // scale
                                   .Arg<Buffer_Type>()  // amax
+                                  .Arg<Buffer_Type>()  // sr_rng_state
+                                  .Arg<Buffer_Type>()  // colwise amax
+                                  .Arg<Buffer_Type>()  // rht matrix
                                   .Ret<Buffer_Type>()  // output
                                   .Ret<Buffer_Type>()  // colwise output
                                   .Ret<Buffer_Type>()  // scale_inv
@@ -196,7 +283,9 @@
                                   .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                   .Attr<int64_t>("q_layout")
                                   .Attr<bool>("is_dbias")
-                                  .Attr<int64_t>("flatten_axis"),
+                                  .Attr<int64_t>("flatten_axis")
+                                  .Attr<bool>("stochastic_rounding")
+                                  .Attr<bool>("use_rht"),
                               FFI_CudaGraph_Traits);

 Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
@@ -346,7 +435,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs,
       sinv_size = 1;
     } else {
       const bool is_colwise = false;
-      auto sinv_shape_i = get_mxfp8_scale_shape(m_i, n, is_colwise);
+      auto sinv_shape_i = get_block_scale_shape(scaling_mode, m_i, n, is_colwise);
       out_i.set_rowwise_scale_inv(static_cast<void *>(sinv_ptr), sinv_dtype, sinv_shape_i);
       sinv_size = product(sinv_shape_i);
     }
@@ -365,7 +454,7 @@
       colwise_sinv_size = 1;
     } else {
       const bool is_colwise = true;
-      auto sinv_shape_i = get_mxfp8_scale_shape(m_i, n, is_colwise);
+      auto sinv_shape_i = get_block_scale_shape(scaling_mode, m_i, n, is_colwise);
       out_i.set_columnwise_scale_inv(static_cast<void *>(colwise_sinv_ptr), sinv_dtype,
                                      sinv_shape_i);
       colwise_sinv_size = product(sinv_shape_i);
......
@@ -16,7 +16,7 @@ import jax
 import jax.numpy as jnp

 from . import cpp_extensions as tex
-from .cpp_extensions.quantization import AmaxScope
+from .cpp_extensions.amax import AmaxScope
 from .quantize import (
     ScaledTensorFactory,
     ScalingMode,
......
@@ -15,7 +15,6 @@ from jax import lax
 from jax import random as jax_random
 from jax.ad_checkpoint import checkpoint_name

-from transformer_engine.common import recipe

 from ..dense import dense
@@ -35,10 +34,9 @@ from ..cpp_extensions import (
 from ..quantize import (
     QuantizerFactory,
     get_quantize_config,
-    QuantizeMeta,
     QuantizeMetaSet,
-    ScalingMode,
     TensorSource,
+    get_quantize_config_with_recipe,
 )

 PRNGKey = Any
@@ -353,40 +351,32 @@ class TransformerEngineBase(nn.Module):  # pylint: disable=too-few-public-methods
         Generate a set of FP8 meta for a GEMM.
         """

-        def generate_quantize_meta(quantizer_name: str):
-            collection_name = (
-                variable_collection
-                if variable_collection is not None
-                else get_quantize_config().COLLECTION_NAME
-            )
-            scale = self.variable(
-                collection_name,
-                f"{quantizer_name}{postfix}_scale",
-                jnp.ones,
-                (1,),
-                jnp.float32,
-            ).value
-            amax_history = self.variable(
-                collection_name,
-                f"{quantizer_name}{postfix}_amax_history",
-                jnp.zeros,
-                (get_quantize_config().AMAX_HISTORY_LEN,),
-                jnp.float32,
-            ).value
-            return QuantizeMeta(scale=scale, amax_history=amax_history)
-
-        if get_quantize_config().get_scaling_mode(
-            TensorSource.X
-        ) == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(fp8_recipe, recipe.DelayedScaling):
-            x_meta = generate_quantize_meta("x")
-            kernel_meta = generate_quantize_meta("kernel")
-            grad_meta = generate_quantize_meta("grad")
-            quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)
-            kwargs = {"quantize_meta_set": quantize_meta_set}
+        collection_name = (
+            variable_collection
+            if variable_collection is not None
+            else get_quantize_config().COLLECTION_NAME
+        )
+
+        if fp8_recipe is None:
+            quantize_config = get_quantize_config()
         else:
-            kwargs = {}
-        quantizer_set = QuantizerFactory.create_set(fp8_recipe=fp8_recipe, **kwargs)
+            quantize_config = get_quantize_config_with_recipe(fp8_recipe)
+
+        x_meta = quantize_config.get_quantize_flax_meta(
+            self, collection_name, postfix, TensorSource.X, "x"
+        )
+        kernel_meta = quantize_config.get_quantize_flax_meta(
+            self, collection_name, postfix, TensorSource.KERNEL, "kernel"
+        )
+        grad_meta = quantize_config.get_quantize_flax_meta(
+            self, collection_name, postfix, TensorSource.DGRAD, "grad"
+        )
+        quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)
+        quantizer_set = QuantizerFactory.create_set(
+            fp8_recipe=fp8_recipe, quantize_meta_set=quantize_meta_set
+        )

         return quantizer_set
......
@@ -16,7 +16,7 @@ import jax
 import jax.numpy as jnp

 from . import cpp_extensions as tex
-from .cpp_extensions.quantization import AmaxScope
+from .cpp_extensions.amax import AmaxScope
 from .quantize import (
     QuantizerSet,
......
@@ -21,7 +21,7 @@ import jax.numpy as jnp
 from jax.ad_checkpoint import checkpoint_name

 from . import cpp_extensions as tex
-from .cpp_extensions.quantization import AmaxScope
+from .cpp_extensions.amax import AmaxScope
 from .layernorm import canonicalize_norm_type
 from .quantize import (
     with_sharding_constraint_by_logical_axes,
......
@@ -14,5 +14,6 @@ from .quantizer import *
 from .dequantizer import *
 from .scaling_modes import *
 from .metadata import *
+from .hadamard import *
 from .helper import *
 from .device_utils import *
@@ -15,6 +15,8 @@ import jax
 import jax.numpy as jnp

 from .scaling_modes import ScalingMode
+from .hadamard import apply_rht, should_use_rht

 __all__ = ["ScalingModeToDequantizerMap"]
@@ -119,7 +121,7 @@ class BlockScaleDequantizer(Dequantizer):
             0 < flatten_axis < len(data_shape)
         ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
         scale_shape = scaling_mode.get_scale_shape(
-            data_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis
+            data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis
         )

         data = data.reshape(
@@ -161,10 +163,99 @@
         )

+class NVFP4Dequantizer(Dequantizer):
+    """NVFP4 Dequantizer Class.
+
+    This class provides static methods for dequantizing tensors that have been
+    quantized using NVFP4 scaling modes.
+    """
+
+    @staticmethod
+    def _dequantize_func(data, scale_inv, amax, dq_dtype, scaling_mode, is_colwise, flatten_axis):
+        """Dequantize a tensor using block scaling.
+
+        Args:
+            data: The quantized tensor data
+            scale_inv: The inverse scaling factors
+            amax: The maximum absolute value of the tensor
+            dq_dtype: The data type for dequantized values
+            scaling_mode: The scaling mode used for quantization
+            is_colwise: Whether the scaling is column-wise
+            flatten_axis: The axis along which the tensor could be flattened to 2D
+
+        Returns:
+            The dequantized tensor
+        """
+        DATA_DTYPE_MAX = jnp.finfo(data.dtype).max.astype(jnp.float32)
+        SCALE_DTYPE_MAX = jnp.finfo(scale_inv.dtype).max.astype(jnp.float32)
+        tensor_scale_inv = amax / (DATA_DTYPE_MAX * SCALE_DTYPE_MAX)
+        data = data.astype(jnp.float32)
+        scale_inv = scale_inv.astype(jnp.float32) * tensor_scale_inv
+
+        data_layout = "T" if is_colwise else "N"
+        data_shape = data.shape
+        flatten_axis = len(data_shape) + flatten_axis if flatten_axis < 0 else flatten_axis
+        assert (
+            0 < flatten_axis < len(data_shape)
+        ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
+
+        scale_shape = scaling_mode.get_scale_shape(
+            data_shape,
+            data_layout=data_layout,
+            is_colwise=is_colwise,
+            is_padded=False,
+            # expect the flatten_axis wrt the N layout
+            flatten_axis=flatten_axis if data_layout == "N" else len(data_shape) - flatten_axis,
+            broadcast_2d_scale_shape_to_1d=True,
+        )
+
+        data = data.reshape(
+            *data_shape[: flatten_axis - 1],
+            scale_shape[flatten_axis - 1],
+            int(data_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]),
+            *data_shape[flatten_axis:-1],
+            scale_shape[-1],
+            int(data_shape[-1] / scale_shape[-1]),
+        )
+        scale_inv = jnp.expand_dims(scale_inv, axis=(flatten_axis + 2 - 2, -1))
+        out = jnp.asarray(data * scale_inv, dq_dtype).reshape(data_shape)
+
+        # Apply inverse of RHT if needed
+        use_rht = should_use_rht(scaling_mode, is_colwise=is_colwise)
+        if use_rht:
+            out = apply_rht(out, inverse=True)
+        return out
+
+    @staticmethod
+    def dequantize(scaled_tensor):
+        """Dequantize a tensor using block scaling.
+
+        Args:
+            scaled_tensor: The quantized tensor to dequantize
+
+        Returns:
+            The dequantized tensor
+        """
+        return NVFP4Dequantizer._dequantize_func(
+            scaled_tensor.data,
+            scaled_tensor.scale_inv,
+            scaled_tensor.amax,
+            scaled_tensor.dq_dtype,
+            scaled_tensor.scaling_mode,
+            scaled_tensor.is_colwise,
+            scaled_tensor.flatten_axis,
+        )
+
+
 ScalingModeToDequantizerMap = {
     ScalingMode.DELAYED_TENSOR_SCALING: TensorScaleDequantizer,
     ScalingMode.CURRENT_TENSOR_SCALING: TensorScaleDequantizer,
     ScalingMode.MXFP8_1D_SCALING: BlockScaleDequantizer,
+    ScalingMode.NVFP4_1D_SCALING: NVFP4Dequantizer,
+    ScalingMode.NVFP4_2D_SCALING: NVFP4Dequantizer,
     ScalingMode.NO_SCALING: NoopDequantizer,
 }
@@ -210,13 +301,13 @@ def _grouped_dequantize(grouped_scaled_tensor):
             )
             padded_scale_shape_i = scaling_mode.get_scale_shape(
                 data_shape_i,
-                grouped_scaled_tensor.is_colwise,
+                is_colwise=grouped_scaled_tensor.is_colwise,
                 is_padded=True,
                 flatten_axis=flatten_axis,
             )
             unpadded_scale_shape_i = scaling_mode.get_scale_shape(
                 data_shape_i,
-                grouped_scaled_tensor.is_colwise,
+                is_colwise=grouped_scaled_tensor.is_colwise,
                 is_padded=False,
                 flatten_axis=flatten_axis,
             )
......
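The two-level scale in NVFP4Dequantizer._dequantize_func is worth spelling out: each 16-element block stores an FP8 (E4M3) scale_inv, and the whole tensor carries one float32 scale derived from amax. A small numeric sketch of that combination (assuming a JAX build that exposes the FP4/FP8 dtypes via ml_dtypes):

import jax.numpy as jnp

# E2M1 max is 6.0 and E4M3 max is 448.0, so for an example amax of 8.0:
amax = jnp.float32(8.0)
DATA_DTYPE_MAX = jnp.finfo(jnp.float4_e2m1fn).max.astype(jnp.float32)   # 6.0
SCALE_DTYPE_MAX = jnp.finfo(jnp.float8_e4m3fn).max.astype(jnp.float32)  # 448.0
tensor_scale_inv = amax / (DATA_DTYPE_MAX * SCALE_DTYPE_MAX)            # ~2.98e-3
# A dequantized element is fp4_value * block_scale_inv (E4M3) * tensor_scale_inv,
# which is exactly what _dequantize_func folds into scale_inv before the reshape.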
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Randomized Hadamard Transform (RHT) utilities for JAX."""
import jax.numpy as jnp
from .scaling_modes import ScalingMode
def should_use_rht(scaling_mode, is_colwise=None, q_layout=None) -> bool:
"""Determine if RHT (Randomized Hadamard Transform) should be used.
Args:
scaling_mode: The scaling mode of the tensor.
is_colwise: Whether the tensor is column-wise. Only one of is_colwise or q_layout should be provided.
q_layout: The quantization layout of the tensor. Only one of is_colwise or q_layout should be provided.
Returns:
bool: True if RHT should be used, False otherwise.
"""
# Delayed import to avoid circular dependencies
from .quantizer import QuantizeLayout
assert (is_colwise is None) != (
q_layout is None
), "Exactly one of is_colwise or q_layout must be provided."
if q_layout is not None:
is_colwise = q_layout in {QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE}
return scaling_mode == ScalingMode.NVFP4_1D_SCALING and is_colwise
def get_wgrad_sign_vector() -> list[int]:
"""Get a fixed sign vector for the RHT used in NVFP4 weight gradient quantization."""
return [1, 1, 1, -1, 1, -1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1]
def get_sign_from_vector(vector: list[int]) -> int:
"""Convert a sign vector to a bitmask integer."""
mask = 0
for i, v in enumerate(vector):
mask |= (v == -1) << i
return mask
def apply_rht(x: jnp.ndarray, inverse=False) -> jnp.ndarray:
"""Apply the Randomized Hadamard Transform (RHT) to the input tensor."""
h = get_rht_matrix()
block_size = 16
if inverse:
h = jnp.linalg.inv(h.astype(jnp.float32)).astype(jnp.bfloat16)
# TODO(jberchtold): These reshapes will break partitioning, fixme
return (x.reshape(-1, block_size) @ h).reshape(x.shape)
def get_rht_matrix() -> jnp.ndarray:
"""Get the Randomized Hadamard Transform (RHT) matrix used in NVFP4 weight gradient quantization.
Returns:
A (16, 16) bfloat16 matrix representing the RHT. This matrix is pre-multiplied by the random sign mask.
"""
import scipy
block_size = 16
h = jnp.array(scipy.linalg.hadamard(block_size))
# Apply the random sign mask
s = jnp.array(get_wgrad_sign_vector(), dtype=jnp.int32)
h = jnp.diag(s) @ h
return (h / jnp.sqrt(block_size)).astype(jnp.bfloat16)
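Since the RHT matrix above is a sign-flipped Hadamard matrix divided by sqrt(16), it is orthogonal, so applying the transform and its inverse should round-trip up to bfloat16 error. A quick self-contained check (illustrative only, not part of the PR):

import jax.numpy as jnp

x = jnp.arange(32, dtype=jnp.bfloat16).reshape(2, 16)
y = apply_rht(x)
x_back = apply_rht(y, inverse=True)
# bfloat16 keeps ~8 bits of mantissa, so allow a loose tolerance
assert jnp.allclose(x.astype(jnp.float32), x_back.astype(jnp.float32), atol=0.5)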
@@ -9,23 +9,29 @@ This module provides classes for managing quantization metadata, including
 scale factors and amax history for different tensor types.
 """
 from dataclasses import dataclass

-import jax.numpy as jnp

 __all__ = ["QuantizeMeta", "QuantizeMetaSet"]


-@dataclass
 class QuantizeMeta:
     """Metadata for quantization parameters.

-    Attributes:
+    For Delayed Scaling recipe:
         scale: The scaling factor for quantization
         amax_history: History of maximum absolute values
+
+    For NVFP4 recipe with Stochastic Rounding:
+        sr_rng_state: The state of the stochastic rounding RNG
     """

-    scale: jnp.ndarray
-    amax_history: jnp.ndarray
+    def __init__(self, **kwargs):
+        self._kwargs = kwargs
+
+    def get_kwargs_dictionary(self):
+        """Get the metadata as a dictionary."""
+        return self._kwargs


 @dataclass
......
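Under the kwargs-based QuantizeMeta above, the same class can carry either delayed-scaling state or NVFP4 stochastic-rounding state. A hypothetical usage sketch (field shapes are illustrative; the FFI earlier in this PR expects the SR RNG state as int64[2]):

import jax.numpy as jnp

meta_ds = QuantizeMeta(
    scale=jnp.ones((1,), jnp.float32),
    amax_history=jnp.zeros((1024,), jnp.float32),  # length = AMAX_HISTORY_LEN
)
meta_sr = QuantizeMeta(sr_rng_state=jnp.zeros((2,), jnp.uint32))  # placeholder RNG state
assert set(meta_ds.get_kwargs_dictionary()) == {"scale", "amax_history"}
assert set(meta_sr.get_kwargs_dictionary()) == {"sr_rng_state"}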
@@ -201,13 +201,32 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
         else:
             unpadded_scale_shape = self.scaling_mode.get_scale_shape(
                 self.data.shape,
+                data_layout=self.data_layout,
                 is_colwise=self.is_colwise,
                 is_padded=False,
-                flatten_axis=self.flatten_axis,
+                # expect the flatten_axis wrt the N layout
+                flatten_axis=(
+                    self.flatten_axis
+                    if self.data_layout == "N"
+                    else self.data.ndim - self.flatten_axis
+                ),
             )
-            assert self.scale_inv.shape == unpadded_scale_shape, (
-                "Unpadded inverse scale factor has wrong shape, expected"
-                f" {unpadded_scale_shape} but got {self.scale_inv.shape}."
+            unpadded_scale_shape_broadcast = self.scaling_mode.get_scale_shape(
+                self.data.shape,
+                data_layout=self.data_layout,
+                is_colwise=self.is_colwise,
+                is_padded=False,
+                # expect the flatten_axis wrt the N layout
+                flatten_axis=(
+                    self.flatten_axis
+                    if self.data_layout == "N"
+                    else self.data.ndim - self.flatten_axis
+                ),
+                broadcast_2d_scale_shape_to_1d=True,
+            )
+            assert self.scale_inv.shape in (unpadded_scale_shape, unpadded_scale_shape_broadcast), (
+                f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} or"
+                f" {unpadded_scale_shape_broadcast} but got {self.scale_inv.shape}."
             )

     def tree_flatten(self):
@@ -583,6 +602,7 @@ class ScaledTensorFactory:
         colwise_data,
         colwise_scale_inv,
         amax=None,
+        colwise_amax=None,
         scaling_mode=ScalingMode.NO_SCALING,
         dq_dtype=jnp.bfloat16,
         data_layout="NN",
@@ -612,6 +632,8 @@
         """
         if amax is None:
             amax = jnp.empty((1,), dtype=jnp.float32)
+        if colwise_amax is None:
+            colwise_amax = amax

         assert len(data_layout) == 2, f"Expect 2 layouts, got {data_layout}"
         rowwise_tensor = ScaledTensorFactory.create_1x(
@@ -630,10 +652,10 @@
         colwise_tensor = ScaledTensorFactory.create_1x(
             colwise_data,
             colwise_scale_inv,
-            amax,
+            colwise_amax,
             scaling_mode,
             dq_dtype,
-            is_colwise=True,
+            is_colwise=True,  # TODO(Phuong): set this correctly
             data_layout=data_layout[1],
             flatten_axis=flatten_axis,
             group_sizes=group_sizes,
@@ -649,6 +671,7 @@
         colwise_data: jnp.ndarray,
         colwise_scale_inv: jnp.ndarray,
         amax=None,
+        colwise_amax=None,
         scaling_mode: ScalingMode = ScalingMode.NO_SCALING,
         dq_dtype: jnp.dtype = jnp.bfloat16,
         data_layout: str = "NN",
@@ -684,6 +707,7 @@
             colwise_data,
             colwise_scale_inv,
             amax,
+            colwise_amax,
             scaling_mode,
             dq_dtype,
             data_layout=data_layout,
@@ -698,7 +722,7 @@
         return ScaledTensorFactory.create_1x(
             colwise_data,
             colwise_scale_inv,
-            amax,
+            colwise_amax if colwise_amax is not None else amax,
             scaling_mode,
             dq_dtype,
             is_colwise=is_colwise,
......