Unverified Commit 8a7ab3dd authored by jberchtold-nvidia and committed by GitHub

[JAX] NVFP4 support in TE/JAX (#2254)


Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent e99be1b6
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime.h>
#include <iostream>
#include "../extensions.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/hadamard_transform.h"
#include "transformer_engine/recipe.h"
#include "transformer_engine/transformer_engine.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
namespace jax {
Error_Type RHTAmaxCalculationFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type amax_buf,
Result_Type post_rht_amax_buf,
int64_t rht_matrix_random_sign_mask_t, bool produce_regular_amax,
int64_t flatten_axis) {
NVTE_CHECK(input_buf.untyped_data() != nullptr,
"Input must be provided for RHT Amax calculation");
NVTE_CHECK(convert_ffi_datatype_to_te_dtype(input_buf.element_type()) == DType::kBFloat16,
"Input must be of type bfloat16 for RHT Amax calculation");
NVTE_CHECK(flatten_axis > 0 && flatten_axis < static_cast<int64_t>(input_buf.dimensions().size()),
"Flatten axis is out of bounds");
TensorWrapper input_tensor(input_buf.untyped_data(),
std::vector<size_t>{product(input_buf.dimensions(), 0, flatten_axis),
product(input_buf.dimensions(), flatten_axis,
input_buf.dimensions().size())},
convert_ffi_datatype_to_te_dtype(input_buf.element_type()));
float *amax_out = nullptr;
if (produce_regular_amax) {
amax_out = reinterpret_cast<float *>(amax_buf->untyped_data());
NVTE_CHECK(amax_out != nullptr, "Amax output must be provided for RHT Amax calculation");
NVTE_CHECK(convert_ffi_datatype_to_te_dtype(amax_buf->element_type()) == DType::kFloat32,
"Amax output must be of type float32 for RHT Amax calculation");
NVTE_CHECK(amax_buf->dimensions().size() == 1 && amax_buf->dimensions()[0] == 1,
"Amax output must be a single float for RHT Amax calculation");
}
float *post_rht_amax_out = reinterpret_cast<float *>(post_rht_amax_buf->untyped_data());
NVTE_CHECK(post_rht_amax_out != nullptr,
"Post-RHT Amax output must be provided for RHT Amax calculation");
NVTE_CHECK(convert_ffi_datatype_to_te_dtype(post_rht_amax_buf->element_type()) == DType::kFloat32,
"Post-RHT Amax output must be of type float32 for RHT Amax calculation");
NVTE_CHECK(post_rht_amax_buf->dimensions().size() == 1 && post_rht_amax_buf->dimensions()[0] == 1,
"Post-RHT Amax output must be a single float for RHT Amax calculation");
TensorWrapper out_tensor{};
out_tensor.set_amax(amax_out, DType::kFloat32, std::vector<size_t>{1});
out_tensor.set_columnwise_amax(post_rht_amax_out, DType::kFloat32, std::vector<size_t>{1});
// Zeroing of amaxes is handled by TE common inside nvte_hadamard_transform_amax
nvte_hadamard_transform_amax(input_tensor.data(), out_tensor.data(),
0, // Regular amax for rowwise does not apply RHT so mask is 0
rht_matrix_random_sign_mask_t, stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(
RHTAmaxCalculationHandler, RHTAmaxCalculationFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // post_rht_amax
.Attr<int64_t>("rht_matrix_random_sign_mask_t") // rht_matrix_random_sign_mask_t
.Attr<bool>("produce_regular_amax") // produce_regular_amax
.Attr<int64_t>("flatten_axis"), // flatten_axis
FFI_CudaGraph_Traits);
Error_Type RHTAmaxCalculationInitializeFFI(cudaStream_t stream, Buffer_Type input_buf,
Result_Type amax_buf, Result_Type post_rht_amax_buf,
int64_t rht_matrix_random_sign_mask_t,
bool produce_regular_amax, int64_t flatten_axis) {
return wrapInStreamCapture(std::function(RHTAmaxCalculationFFI), stream, input_buf, amax_buf,
post_rht_amax_buf, rht_matrix_random_sign_mask_t, produce_regular_amax,
flatten_axis);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(
RHTAmaxCalculationInitializeHandler, RHTAmaxCalculationInitializeFFI,
FFI::Bind<FFI_Initialize>()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // post_rht_amax
.Attr<int64_t>("rht_matrix_random_sign_mask_t") // rht_matrix_random_sign_mask_t
.Attr<bool>("produce_regular_amax") // produce_regular_amax
.Attr<int64_t>("flatten_axis")); // flatten_axis
} // namespace jax
} // namespace transformer_engine
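For orientation, a minimal JAX sketch (editorial; rht_amax_reference is a hypothetical name, not part of this diff) of the semantics the handlers above wrap: the rowwise amax is taken on the raw input, while the columnwise post-RHT amax is taken after multiplying each 16-element block by the sign-masked Hadamard matrix, as nvte_hadamard_transform_amax does in TE/common.

import jax.numpy as jnp

def rht_amax_reference(x, rht_matrix):
    """x: bf16 input, rht_matrix: (16, 16) sign-masked Hadamard matrix."""
    # Rowwise amax: no RHT is applied (the FFI passes mask 0 for this path).
    amax = jnp.max(jnp.abs(x)).astype(jnp.float32)
    # Columnwise ("post-RHT") amax: transform each 16-wide block first.
    y = x.reshape(-1, 16).astype(jnp.float32) @ rht_matrix.astype(jnp.float32)
    post_rht_amax = jnp.max(jnp.abs(y)).astype(jnp.float32)
    return amax, post_rht_amax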
......@@ -41,6 +41,9 @@ DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) {
case xla::ffi::DataType::F8E8M0FNU:
return DType::kFloat8E8M0;
break;
case xla::ffi::DataType::F4E2M1FN:
return DType::kFloat4E2M1;
break;
default:
auto type_num = static_cast<XLA_FFI_DataType>(type);
NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d",
......
......@@ -102,6 +102,8 @@ inline static size_t te_dtype_bytes(const DType& type) {
return 1;
case DType::kFloat8E8M0:
return 1;
case DType::kFloat4E2M1:
return 1;
default:
NVTE_ERROR("Unsupported DType: ", static_cast<int>(type));
}
......
......@@ -51,7 +51,8 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
// Set scaling factor for quantized tensors
if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) {
NVTE_CHECK(typeToSize(input_dtype) == 1, "Quantized GEMM requires 8-bit operands.");
NVTE_CHECK(is_nvfp4_scaling(scaling_mode) || typeToSize(input_dtype) == 1,
"Quantized GEMM requires 4-bit or 8-bit operands.");
NVTE_CHECK(scale_inv.element_count() > 0, "Missing inverse scaling factor for quantized GEMM.");
std::vector<size_t> scale_shape = {1};
......@@ -74,7 +75,8 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
Error_Type CollectiveGemmInitFFI(Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
Buffer_Type rhs_scale_inv, Buffer_Type bias,
Buffer_Type gelu_input, Result_Type output, Result_Type bias_grad,
Buffer_Type gelu_input, Buffer_Type alpha, Buffer_Type beta,
Result_Type output, Result_Type bias_grad,
Result_Type pre_gelu_out, Result_Type workspace,
JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
int64_t rhs_axis_boundary, bool lhs_transposed,
......@@ -119,6 +121,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CollectiveGemmInitHandler, CollectiveGemmInitFFI,
.Arg<Buffer_Type>() // rhs_scale_inv
.Arg<Buffer_Type>() // bias
.Arg<Buffer_Type>() // gelu_input
.Arg<Buffer_Type>() // alpha
.Arg<Buffer_Type>() // beta
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // bias_grad
.Ret<Buffer_Type>() // pre_gelu_out
......@@ -136,11 +140,11 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CollectiveGemmInitHandler, CollectiveGemmInitFFI,
Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input,
Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out,
Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed,
bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator,
JAXX_Collective_Op collective_op) {
Buffer_Type alpha, Buffer_Type beta, Result_Type output, Result_Type bias_grad,
Result_Type pre_gelu_out, Result_Type workspace, JAXX_Scaling_Mode scaling_mode,
int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed,
bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad,
bool use_split_accumulator, JAXX_Collective_Op collective_op) {
// NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when
// device supports non-TN layouts (compute capability >= 10.0, excluding 12.x)
bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
......@@ -192,10 +196,31 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr);
std::vector<size_t> workspace_shape = {static_cast<size_t>(workspace->element_count()) - 256};
auto workspace_ = TensorWrapper(workspace_ptr, workspace_shape, DType::kByte);
// Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order
auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
float one = 1.;
float zero = 0.;
// alpha/beta default to host constants (1.0, 0.0); for NVFP4 they are read from device buffers below
float *alpha_ptr = &one, *beta_ptr = &zero;
if (is_nvfp4_scaling(scaling_mode)) {
NVTE_CHECK(alpha.element_count() == 1 &&
convert_ffi_datatype_to_te_dtype(alpha.element_type()) == DType::kFloat32);
alpha_ptr = reinterpret_cast<float *>(alpha.untyped_data());
NVTE_CHECK(beta.element_count() == 1 &&
convert_ffi_datatype_to_te_dtype(beta.element_type()) == DType::kFloat32);
beta_ptr = reinterpret_cast<float *>(beta.untyped_data());
}
// Construct GEMM config
transformer_engine::MatmulConfigWrapper config;
config.set_use_split_accumulator(use_split_accumulator);
config.set_sm_count(num_math_sm);
if (fuse_bias) config.set_bias_tensor(bias_.data());
if (fuse_gelu) {
config.set_with_gelu_epilogue(true);
config.set_epilogue_aux_tensor(pre_gelu_.data());
}
if (collective_op == JAXX_Collective_Op::NONE) {
auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype);
NVTE_CHECK(out_.numel() == output->element_count(),
......@@ -205,9 +230,10 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
NVTE_CHECK(!fuse_bias || bias_size == out_shape[1], "bias_size=", bias_size,
", out_shape[1]=", out_shape[1]);
nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(),
rhs_transposed, lhs_transposed, grad, workspace_.data(), false,
use_split_accumulator, num_math_sm, stream);
// Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order
nvte_cublas_gemm_v2(rhs_transposed /*transa*/, lhs_transposed /*transb*/, alpha_ptr,
rhs_.data() /*A*/, lhs_.data() /*B*/, beta_ptr, out_.data() /*C*/,
out_.data() /*D*/, workspace_.data(), config, stream);
} else {
std::vector<size_t> buffer_shape{0, 0};
DType buffer_dtype = out_dtype;
......@@ -268,6 +294,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI,
.Arg<Buffer_Type>() // rhs_scale_inv
.Arg<Buffer_Type>() // bias
.Arg<Buffer_Type>() // gelu_input
.Arg<Buffer_Type>() // alpha
.Arg<Buffer_Type>() // beta
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // bias_grad
.Ret<Buffer_Type>() // pre_gelu_out
......@@ -599,9 +627,9 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
// point to swizzled scale_inv data (store on workspace, only used for GEMM).
// Note: even if is_empty_gemm is true, the sinv buffers are still non-empty, so the pointers must still be advanced
auto lhs_sinv_shape_i =
get_mxfp8_scale_shape(lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise);
get_block_scale_shape(scaling_mode, lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise);
auto rhs_sinv_shape_i =
get_mxfp8_scale_shape(rhs_shape_i[0], rhs_shape_i[1], rhs_use_colwise);
get_block_scale_shape(scaling_mode, rhs_shape_i[0], rhs_shape_i[1], rhs_use_colwise);
lhs_sinv_size_i = lhs_sinv_shape_i[0] * lhs_sinv_shape_i[1];
rhs_sinv_size_i = rhs_sinv_shape_i[0] * rhs_sinv_shape_i[1];
if (lhs_use_colwise) {
......
......@@ -26,11 +26,21 @@ std::vector<size_t> Shape::to_vector() const {
return shape;
}
std::vector<size_t> get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise) {
auto block_x = is_colwise ? MXFP8_BLOCK_SIZE.y : MXFP8_BLOCK_SIZE.x;
auto block_y = is_colwise ? MXFP8_BLOCK_SIZE.x : MXFP8_BLOCK_SIZE.y;
auto alignment_x = is_colwise ? MXFP8_ALIGNMENT.y : MXFP8_ALIGNMENT.x;
auto alignment_y = is_colwise ? MXFP8_ALIGNMENT.x : MXFP8_ALIGNMENT.y;
std::vector<size_t> get_block_scale_shape(JAXX_Scaling_Mode scaling_mode, size_t M, size_t N,
bool is_colwise) {
auto block_size = BLOCK_SIZE(1, 1);
if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
block_size = MXFP8_BLOCK_SIZE;
} else if (scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING ||
scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING) {
block_size = NVFP4_BLOCK_SIZE;
} else {
NVTE_ERROR("Unsupported scaling_mode = ", static_cast<int>(scaling_mode));
}
auto block_x = is_colwise ? block_size.y : block_size.x;
auto block_y = is_colwise ? block_size.x : block_size.y;
auto alignment_x = is_colwise ? BLOCK_SCALE_ALIGNMENT.y : BLOCK_SCALE_ALIGNMENT.x;
auto alignment_y = is_colwise ? BLOCK_SCALE_ALIGNMENT.x : BLOCK_SCALE_ALIGNMENT.y;
NVTE_CHECK(M % block_x == 0, "M must be divisible by %zu (got %zu)", block_x, M);
NVTE_CHECK(N % block_y == 0, "N must be divisible by %zu (got %zu)", block_y, N);
......
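For reference, a hedged sketch of the padded scale-inverse shapes this helper produces for a rowwise (M, N) operand; the helper name block_scale_shape is hypothetical, the block sizes are the ones defined in this diff, and the {128, 4} alignment is shared by both modes:

import math

def block_scale_shape(M, N, block, align=(128, 4)):
    # One scale per (block[0] x block[1]) tile, each dimension padded to its alignment.
    rows = math.ceil(M / block[0] / align[0]) * align[0]
    cols = math.ceil(N / block[1] / align[1]) * align[1]
    return (rows, cols)

block_scale_shape(256, 1024, block=(1, 32))  # MXFP8: (256, 32), one E8M0 scale per 32 values
block_scale_shape(256, 1024, block=(1, 16))  # NVFP4: (256, 64), one E4M3 scale per 16 values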
......@@ -45,6 +45,8 @@ enum class JAXX_Scaling_Mode : int64_t {
DELAYED_TENSOR_SCALING = 1,
MXFP8_1D_SCALING = 2,
CURRENT_TENSOR_SCALING = 3,
NVFP4_1D_SCALING = 4,
NVFP4_2D_SCALING = 5,
};
inline bool is_tensor_scaling(const JAXX_Scaling_Mode &mode) {
......@@ -56,6 +58,11 @@ inline bool is_block_scaling(const JAXX_Scaling_Mode &mode) {
return (mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING);
}
inline bool is_nvfp4_scaling(const JAXX_Scaling_Mode &mode) {
return (mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING ||
mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING);
}
static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
switch (mode) {
case JAXX_Scaling_Mode::NO_SCALING:
......@@ -70,22 +77,32 @@ static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
case JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING:
return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING;
break;
case JAXX_Scaling_Mode::NVFP4_1D_SCALING:
return NVTEScalingMode::NVTE_NVFP4_1D_SCALING;
break;
case JAXX_Scaling_Mode::NVFP4_2D_SCALING:
// TE common uses the same enum value for 1D and 2D fp4 scaling and instead differentiates them via quant_config.nvfp4_2d_quantization
return NVTEScalingMode::NVTE_NVFP4_1D_SCALING;
break;
default:
NVTE_ERROR("Invalid Scaling Mode ", static_cast<int>(mode));
break;
}
}
constexpr struct BlockSize {
struct BLOCK_SIZE {
size_t x;
size_t y;
} MXFP8_BLOCK_SIZE{1, 32};
constexpr struct Alignment {
size_t x;
size_t y;
} MXFP8_ALIGNMENT{128, 4};
constexpr BLOCK_SIZE(int _x, int _y) : x(_x), y(_y) {}
};
constexpr BLOCK_SIZE MXFP8_BLOCK_SIZE{1, 32};
constexpr BLOCK_SIZE NVFP4_BLOCK_SIZE{1, 16};
constexpr BLOCK_SIZE BLOCK_SCALE_ALIGNMENT{128, 4};
std::vector<size_t> get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise);
std::vector<size_t> get_block_scale_shape(JAXX_Scaling_Mode scaling_mode, size_t M, size_t N,
bool is_colwise);
template <typename T, typename... Rest>
void hash_combine(int64_t &seed, const T &v, Rest... rest) {
......
......@@ -76,6 +76,11 @@ pybind11::dict Registrations() {
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler));
// Amax
dict["te_rht_amax_ffi"] = pybind11::dict(
pybind11::arg("initialize") = EncapsulateFFI(RHTAmaxCalculationInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(RHTAmaxCalculationHandler));
return dict;
}
......@@ -106,7 +111,9 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("kFloat16", DType::kFloat16)
.value("kBFloat16", DType::kBFloat16)
.value("kFloat8E4M3", DType::kFloat8E4M3)
.value("kFloat8E5M2", DType::kFloat8E5M2);
.value("kFloat8E5M2", DType::kFloat8E5M2)
.value("kFloat8E8M0", DType::kFloat8E8M0)
.value("kFloat4E2M1", DType::kFloat4E2M1);
pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local())
.value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS)
......@@ -165,6 +172,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("DELAYED_TENSOR_SCALING", JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
.value("MXFP8_1D_SCALING", JAXX_Scaling_Mode::MXFP8_1D_SCALING)
.value("CURRENT_TENSOR_SCALING", JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING)
.value("NVFP4_1D_SCALING", JAXX_Scaling_Mode::NVFP4_1D_SCALING)
.value("NVFP4_2D_SCALING", JAXX_Scaling_Mode::NVFP4_2D_SCALING)
.export_values();
pybind11::enum_<transformer_engine::jax::QuantizeLayout>(m, "QuantizeLayout",
......
......@@ -5,8 +5,11 @@
************************************************************************/
#include <cuda_runtime.h>
#include <iostream>
#include "../extensions.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/hadamard_transform.h"
#include "transformer_engine/recipe.h"
#include "transformer_engine/transformer_engine.h"
#include "xla/ffi/api/c_api.h"
......@@ -15,7 +18,7 @@ namespace transformer_engine {
namespace jax {
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype,
DType in_dtype, DType out_dtype, DType scale_dtype,
JAXX_Scaling_Mode scaling_mode,
QuantizeLayout q_layout) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
......@@ -30,16 +33,22 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
// this function. We pass a dummy pointer as a workaround.
int temp = 0;
bool const is_nvfp4 = scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING ||
scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING;
auto input_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), input_shape, in_dtype);
auto dbias_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), dbias_shape, in_dtype);
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
auto scale_shape = std::vector<size_t>{1};
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::ROWWISE) {
output_tensor.set_rowwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_shape);
if (is_fp8_dtype(out_dtype)) {
output_tensor.set_rowwise_scale_inv(reinterpret_cast<void *>(&temp), DType::kFloat32,
std::vector<size_t>{1});
if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) {
if (is_nvfp4)
scale_shape = get_block_scale_shape(scaling_mode, batch_size, hidden_size, false);
output_tensor.set_rowwise_scale_inv(reinterpret_cast<void *>(&temp), scale_dtype,
scale_shape);
}
}
......@@ -49,13 +58,16 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (is_fp8_dtype(out_dtype)) {
output_tensor.set_columnwise_scale_inv(reinterpret_cast<void *>(&temp), DType::kFloat32,
std::vector<size_t>{1});
if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) {
if (is_nvfp4)
scale_shape =
get_block_scale_shape(scaling_mode, hidden_size, batch_size, false);  // Transpose
output_tensor.set_columnwise_scale_inv(reinterpret_cast<void *>(&temp), scale_dtype,
scale_shape);
}
}
if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || is_nvfp4) {
output_tensor.set_amax(reinterpret_cast<void *>(&temp), DType::kFloat32,
std::vector<size_t>{1});
output_tensor.set_scale(reinterpret_cast<void *>(&temp), DType::kFloat32,
......@@ -72,17 +84,20 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
}
Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Buffer_Type amax_buf, Result_Type output_buf,
Result_Type output_trans_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type dbias_buf, Result_Type workspace_buf,
JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum,
bool is_dbias, int64_t flatten_axis) {
Buffer_Type amax_buf, Buffer_Type sr_rng_state,
Buffer_Type post_rht_amax_buf, Buffer_Type rht_matrix_buf,
Result_Type output_buf, Result_Type output_trans_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type updated_amax_buf, Result_Type dbias_buf,
Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
int64_t quantize_layout_enum, bool is_dbias, int64_t flatten_axis,
bool stochastic_rounding, bool use_rht) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
NVTE_CHECK(is_fp8_dtype(out_dtype), "Output datatype must be FP8 for quantization.");
NVTE_CHECK(is_fp8_dtype(out_dtype) || is_fp4_dtype(out_dtype),
"Output datatype must be FP8 or FP4 for quantization.");
auto *input = input_buf.untyped_data();
......@@ -112,41 +127,106 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
bool const is_mxfp8 = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
bool const is_nvfp4 = scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING ||
scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING;
NVTE_CHECK(!stochastic_rounding || is_nvfp4, "Stochastic rounding is only supported for NVFP4.");
NVTE_CHECK(!use_rht || is_nvfp4, "RHT is only supported for NVFP4 scaling.");
if (quantize_layout == QuantizeLayout::ROWWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
output_tensor.set_rowwise_data(output, out_dtype, output_shape);
if (is_fp8_dtype(out_dtype)) {
if (is_tensor_scaling) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
float *updated_amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax == updated_amax && amax != nullptr,
"amax must be provided for delayed tensor scaling");
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{1});
} else {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
product(scale_inv_buf->dimensions(), flatten_axis,
scale_inv_buf->dimensions().size())});
}
if (is_tensor_scaling) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
} else {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
product(scale_inv_buf->dimensions(), flatten_axis,
scale_inv_buf->dimensions().size())});
}
}
if (is_nvfp4) {
float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
NVTE_CHECK(amax != nullptr, "amax must be provided for NVFP4");
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
}
QuantizationConfigWrapper quant_config{};
if (scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING) {
quant_config.set_nvfp4_2d_quantization(true);
}
// Stochastic rounding
quant_config.set_stochastic_rounding(stochastic_rounding);
TensorWrapper sr_rng_state_tensor(sr_rng_state.untyped_data(), std::vector<size_t>{2},
DType::kInt64);
if (stochastic_rounding) {
NVTE_CHECK(sr_rng_state.size_bytes() == 2 * sizeof(uint64_t),
"rng_state must be of type int64[2]");
NVTE_CHECK(sr_rng_state.untyped_data() != nullptr, "rng_state must be provided for SR");
quant_config.set_rng_state(sr_rng_state_tensor.data());
}
if (quantize_layout == QuantizeLayout::COLWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
? output_trans_shape
: output_shape;
if (is_nvfp4 && use_rht) {
if (quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
// Do regular rowwise quantization without RHT
nvte_quantize_v2(input_tensor.data(), output_tensor.data(), quant_config, stream);
}
TensorWrapper out_transpose(get_nvte_scaling_mode(scaling_mode));
// nvte_hadamard_transform_cast_fusion_columnwise expects the colwise data to be populated in the rowwise buffers of the TensorWrapper
out_transpose.set_rowwise_data(output_trans, out_dtype, output_trans_shape);
auto const colwise_flatten_axis = output_trans_buf->dimensions().size() - flatten_axis;
out_transpose.set_rowwise_scale_inv(
colwise_scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0, colwise_flatten_axis),
product(colwise_scale_inv_buf->dimensions(), colwise_flatten_axis,
colwise_scale_inv_buf->dimensions().size())});
float *post_rht_amax = reinterpret_cast<float *>(post_rht_amax_buf.untyped_data());
NVTE_CHECK(post_rht_amax != nullptr, "Post-RHT colwise amax must be provided for NVFP4");
out_transpose.set_amax(post_rht_amax, DType::kFloat32, std::vector<size_t>{1});
bool const eligible_for_rht_cast_fusion =
input_tensor.dtype() == DType::kBFloat16 && m % 64 == 0 && n % 128 == 0;
NVTE_CHECK(eligible_for_rht_cast_fusion, "RHT cast fusion conditions not met");
NVTE_CHECK(
convert_ffi_datatype_to_te_dtype(rht_matrix_buf.element_type()) == DType::kBFloat16,
"RHT matrix must be bf16");
NVTE_CHECK(rht_matrix_buf.dimensions().size() == 2 && rht_matrix_buf.dimensions()[0] == 16 &&
rht_matrix_buf.dimensions()[1] == 16,
"RHT matrix must be 16x16");
TensorWrapper rht_matrix_tensor(rht_matrix_buf.untyped_data(), std::vector<size_t>{16, 16},
DType::kBFloat16);
nvte_hadamard_transform_cast_fusion_columnwise(input_tensor.data(), out_transpose.data(),
rht_matrix_tensor.data(), quant_config,
stream);
return ffi_with_cuda_error_check();
}
bool const is_colwise_transposed =
scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || is_nvfp4;
auto &tmp_shape = is_colwise_transposed ? output_trans_shape : output_shape;
output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape);
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &tmp_buf = is_tensor_scaling ? scale_inv_buf : colwise_scale_inv_buf;
......@@ -156,26 +236,30 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{1});
} else {
auto colwise_flatten_axis = flatten_axis;
if (is_colwise_transposed) {
// convert flatten_axis from N layout to T layout
colwise_flatten_axis = tmp_buf->dimensions().size() - flatten_axis;
}
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{
product(tmp_buf->dimensions(), 0, flatten_axis),
product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
product(tmp_buf->dimensions(), 0, colwise_flatten_axis),
product(tmp_buf->dimensions(), colwise_flatten_axis, tmp_buf->dimensions().size())});
}
}
if (scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) {
output_tensor.set_amax(nullptr, DType::kFloat32, std::vector<size_t>{1});
}
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);
if (is_dbias) {
NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::NVFP4_2D_SCALING,
"DBias quantization is not supported for NVFP4_2D_SCALING as fused dbias API cannot "
"take quant_config as input.");
nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(),
workspace_tensor.data(), stream);
} else {
nvte_quantize(input_tensor.data(), output_tensor.data(), stream);
nvte_quantize_v2(input_tensor.data(), output_tensor.data(), quant_config, stream);
}
return ffi_with_cuda_error_check();
}
......@@ -186,6 +270,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // sr_rng_state
.Arg<Buffer_Type>()  // post_rht_amax
.Arg<Buffer_Type>() // rht matrix
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise output
.Ret<Buffer_Type>() // scale_inv
......@@ -196,7 +283,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("q_layout")
.Attr<bool>("is_dbias")
.Attr<int64_t>("flatten_axis"),
.Attr<int64_t>("flatten_axis")
.Attr<bool>("stochastic_rounding")
.Attr<bool>("use_rht"),
FFI_CudaGraph_Traits);
Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
......@@ -346,7 +435,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
sinv_size = 1;
} else {
const bool is_colwise = false;
auto sinv_shape_i = get_mxfp8_scale_shape(m_i, n, is_colwise);
auto sinv_shape_i = get_block_scale_shape(scaling_mode, m_i, n, is_colwise);
out_i.set_rowwise_scale_inv(static_cast<void *>(sinv_ptr), sinv_dtype, sinv_shape_i);
sinv_size = product(sinv_shape_i);
}
......@@ -365,7 +454,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
colwise_sinv_size = 1;
} else {
const bool is_colwise = true;
auto sinv_shape_i = get_mxfp8_scale_shape(m_i, n, is_colwise);
auto sinv_shape_i = get_block_scale_shape(scaling_mode, m_i, n, is_colwise);
out_i.set_columnwise_scale_inv(static_cast<void *>(colwise_sinv_ptr), sinv_dtype,
sinv_shape_i);
colwise_sinv_size = product(sinv_shape_i);
......
......@@ -16,7 +16,7 @@ import jax
import jax.numpy as jnp
from . import cpp_extensions as tex
from .cpp_extensions.quantization import AmaxScope
from .cpp_extensions.amax import AmaxScope
from .quantize import (
ScaledTensorFactory,
ScalingMode,
......
......@@ -15,7 +15,6 @@ from jax import lax
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name
from transformer_engine.common import recipe
from ..dense import dense
......@@ -35,10 +34,9 @@ from ..cpp_extensions import (
from ..quantize import (
QuantizerFactory,
get_quantize_config,
QuantizeMeta,
QuantizeMetaSet,
ScalingMode,
TensorSource,
get_quantize_config_with_recipe,
)
PRNGKey = Any
......@@ -353,40 +351,32 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
Generate a set of FP8 meta for a GEMM.
"""
def generate_quantize_meta(quantizer_name: str):
collection_name = (
variable_collection
if variable_collection is not None
else get_quantize_config().COLLECTION_NAME
)
scale = self.variable(
collection_name,
f"{quantizer_name}{postfix}_scale",
jnp.ones,
(1,),
jnp.float32,
).value
amax_history = self.variable(
collection_name,
f"{quantizer_name}{postfix}_amax_history",
jnp.zeros,
(get_quantize_config().AMAX_HISTORY_LEN,),
jnp.float32,
).value
return QuantizeMeta(scale=scale, amax_history=amax_history)
if get_quantize_config().get_scaling_mode(
TensorSource.X
) == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(fp8_recipe, recipe.DelayedScaling):
x_meta = generate_quantize_meta("x")
kernel_meta = generate_quantize_meta("kernel")
grad_meta = generate_quantize_meta("grad")
quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)
kwargs = {"quantize_meta_set": quantize_meta_set}
collection_name = (
variable_collection
if variable_collection is not None
else get_quantize_config().COLLECTION_NAME
)
if fp8_recipe is None:
quantize_config = get_quantize_config()
else:
kwargs = {}
quantize_config = get_quantize_config_with_recipe(fp8_recipe)
quantizer_set = QuantizerFactory.create_set(fp8_recipe=fp8_recipe, **kwargs)
x_meta = quantize_config.get_quantize_flax_meta(
self, collection_name, postfix, TensorSource.X, "x"
)
kernel_meta = quantize_config.get_quantize_flax_meta(
self, collection_name, postfix, TensorSource.KERNEL, "kernel"
)
grad_meta = quantize_config.get_quantize_flax_meta(
self, collection_name, postfix, TensorSource.DGRAD, "grad"
)
quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)
quantizer_set = QuantizerFactory.create_set(
fp8_recipe=fp8_recipe, quantize_meta_set=quantize_meta_set
)
return quantizer_set
......
......@@ -16,7 +16,7 @@ import jax
import jax.numpy as jnp
from . import cpp_extensions as tex
from .cpp_extensions.quantization import AmaxScope
from .cpp_extensions.amax import AmaxScope
from .quantize import (
QuantizerSet,
......
......@@ -21,7 +21,7 @@ import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name
from . import cpp_extensions as tex
from .cpp_extensions.quantization import AmaxScope
from .cpp_extensions.amax import AmaxScope
from .layernorm import canonicalize_norm_type
from .quantize import (
with_sharding_constraint_by_logical_axes,
......
......@@ -14,5 +14,6 @@ from .quantizer import *
from .dequantizer import *
from .scaling_modes import *
from .metadata import *
from .hadamard import *
from .helper import *
from .device_utils import *
......@@ -15,6 +15,8 @@ import jax
import jax.numpy as jnp
from .scaling_modes import ScalingMode
from .hadamard import apply_rht, should_use_rht
__all__ = ["ScalingModeToDequantizerMap"]
......@@ -119,7 +121,7 @@ class BlockScaleDequantizer(Dequantizer):
0 < flatten_axis < len(data_shape)
), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
scale_shape = scaling_mode.get_scale_shape(
data_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis
data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis
)
data = data.reshape(
......@@ -161,10 +163,99 @@ class BlockScaleDequantizer(Dequantizer):
)
class NVFP4Dequantizer(Dequantizer):
"""NVFP4 Dequantizer Class.
This class provides static methods for dequantizing tensors that have been
quantized using NVFP4 scaling modes.
"""
@staticmethod
def _dequantize_func(data, scale_inv, amax, dq_dtype, scaling_mode, is_colwise, flatten_axis):
"""Dequantize a tensor using block scaling.
Args:
data: The quantized tensor data
scale_inv: The inverse scaling factors
amax: The maximum absolute value of the tensor
dq_dtype: The data type for dequantized values
scaling_mode: The scaling mode used for quantization
is_colwise: Whether the scaling is column-wise
flatten_axis: The axis along which the tensor could be flattened to 2D
Returns:
The dequantized tensor
"""
DATA_DTYPE_MAX = jnp.finfo(data.dtype).max.astype(jnp.float32)
SCALE_DTYPE_MAX = jnp.finfo(scale_inv.dtype).max.astype(jnp.float32)
tensor_scale_inv = amax / (DATA_DTYPE_MAX * SCALE_DTYPE_MAX)
data = data.astype(jnp.float32)
scale_inv = scale_inv.astype(jnp.float32) * tensor_scale_inv
data_layout = "T" if is_colwise else "N"
data_shape = data.shape
flatten_axis = len(data_shape) + flatten_axis if flatten_axis < 0 else flatten_axis
assert (
0 < flatten_axis < len(data_shape)
), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
scale_shape = scaling_mode.get_scale_shape(
data_shape,
data_layout=data_layout,
is_colwise=is_colwise,
is_padded=False,
# expect the flatten_axis wrt the N layout
flatten_axis=flatten_axis if data_layout == "N" else len(data_shape) - flatten_axis,
broadcast_2d_scale_shape_to_1d=True,
)
data = data.reshape(
*data_shape[: flatten_axis - 1],
scale_shape[flatten_axis - 1],
int(data_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]),
*data_shape[flatten_axis:-1],
scale_shape[-1],
int(data_shape[-1] / scale_shape[-1]),
)
scale_inv = jnp.expand_dims(scale_inv, axis=(flatten_axis + 2 - 2, -1))
out = jnp.asarray(data * scale_inv, dq_dtype).reshape(data_shape)
# Apply inverse of RHT if needed
use_rht = should_use_rht(scaling_mode, is_colwise=is_colwise)
if use_rht:
out = apply_rht(out, inverse=True)
return out
@staticmethod
def dequantize(scaled_tensor):
"""Dequantize a tensor using block scaling.
Args:
scaled_tensor: The quantized tensor to dequantize
Returns:
The dequantized tensor
"""
return NVFP4Dequantizer._dequantize_func(
scaled_tensor.data,
scaled_tensor.scale_inv,
scaled_tensor.amax,
scaled_tensor.dq_dtype,
scaled_tensor.scaling_mode,
scaled_tensor.is_colwise,
scaled_tensor.flatten_axis,
)
ScalingModeToDequantizerMap = {
ScalingMode.DELAYED_TENSOR_SCALING: TensorScaleDequantizer,
ScalingMode.CURRENT_TENSOR_SCALING: TensorScaleDequantizer,
ScalingMode.MXFP8_1D_SCALING: BlockScaleDequantizer,
ScalingMode.NVFP4_1D_SCALING: NVFP4Dequantizer,
ScalingMode.NVFP4_2D_SCALING: NVFP4Dequantizer,
ScalingMode.NO_SCALING: NoopDequantizer,
}
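# Editorial usage sketch for the map above. NVFP4 dequantization applies two scale
# levels: with E2M1 max 6.0 and E4M3 max 448.0, an amax of 26.88 gives
#   tensor_scale_inv = 26.88 / (6.0 * 448.0) = 0.01,
# and each element dequantizes as data * block_scale_inv * tensor_scale_inv, e.g.:
#   dequantizer = ScalingModeToDequantizerMap[ScalingMode.NVFP4_1D_SCALING]
#   x_hp = dequantizer.dequantize(scaled_tensor)  # reads data, scale_inv, amax, dq_dtype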
......@@ -210,13 +301,13 @@ def _grouped_dequantize(grouped_scaled_tensor):
)
padded_scale_shape_i = scaling_mode.get_scale_shape(
data_shape_i,
grouped_scaled_tensor.is_colwise,
is_colwise=grouped_scaled_tensor.is_colwise,
is_padded=True,
flatten_axis=flatten_axis,
)
unpadded_scale_shape_i = scaling_mode.get_scale_shape(
data_shape_i,
grouped_scaled_tensor.is_colwise,
is_colwise=grouped_scaled_tensor.is_colwise,
is_padded=False,
flatten_axis=flatten_axis,
)
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Randomized Hadamard Transform (RHT) utilities for JAX."""
import jax.numpy as jnp
from .scaling_modes import ScalingMode
def should_use_rht(scaling_mode, is_colwise=None, q_layout=None) -> bool:
"""Determine if RHT (Randomized Hadamard Transform) should be used.
Args:
scaling_mode: The scaling mode of the tensor.
is_colwise: Whether the tensor is column-wise. Exactly one of is_colwise or q_layout must be provided.
q_layout: The quantization layout of the tensor. Exactly one of is_colwise or q_layout must be provided.
Returns:
bool: True if RHT should be used, False otherwise.
"""
# Delayed import to avoid circular dependencies
from .quantizer import QuantizeLayout
assert (is_colwise is None) != (
q_layout is None
), "Exactly one of is_colwise or q_layout must be provided."
if q_layout is not None:
is_colwise = q_layout in {QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE}
return scaling_mode == ScalingMode.NVFP4_1D_SCALING and is_colwise
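# Editorial usage sketch: RHT applies only on the columnwise (transposed/wgrad) path
# of NVFP4 1D scaling, never to 2D scaling or rowwise data:
#   should_use_rht(ScalingMode.NVFP4_1D_SCALING, is_colwise=True)   # True
#   should_use_rht(ScalingMode.NVFP4_1D_SCALING, is_colwise=False)  # False
#   should_use_rht(ScalingMode.NVFP4_2D_SCALING, is_colwise=True)   # False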
def get_wgrad_sign_vector() -> list[int]:
"""Get a fixed sign vector for the RHT used in NVFP4 weight gradient quantization."""
return [1, 1, 1, -1, 1, -1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1]
def get_sign_from_vector(vector: list[int]) -> int:
"""Convert a sign vector to a bitmask integer."""
mask = 0
for i, v in enumerate(vector):
mask |= (v == -1) << i
return mask
def apply_rht(x: jnp.ndarray, inverse=False) -> jnp.ndarray:
"""Apply the Randomized Hadamard Transform (RHT) to the input tensor."""
h = get_rht_matrix()
block_size = 16
if inverse:
h = jnp.linalg.inv(h.astype(jnp.float32)).astype(jnp.bfloat16)
# TODO(jberchtold): These reshapes will break partitioning, fixme
return (x.reshape(-1, block_size) @ h).reshape(x.shape)
def get_rht_matrix() -> jnp.ndarray:
"""Get the Randomized Hadamard Transform (RHT) matrix used in NVFP4 weight gradient quantization.
Returns:
A (16, 16) bfloat16 matrix representing the RHT. This matrix is pre-multiplied by the random sign mask.
"""
import scipy
block_size = 16
h = jnp.array(scipy.linalg.hadamard(block_size))
# Apply the random sign mask
s = jnp.array(get_wgrad_sign_vector(), dtype=jnp.int32)
h = jnp.diag(s) @ h
return (h / jnp.sqrt(block_size)).astype(jnp.bfloat16)
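# Editorial self-check sketch for the helpers above (relies on this module's imports):
#   mask = get_sign_from_vector(get_wgrad_sign_vector())  # bit i set where v[i] == -1; 0xD7E8 here
#   h = get_rht_matrix()
# Since H @ H.T == 16 * I for a 16x16 Hadamard matrix and diag(s) preserves that, the
# normalized, sign-masked matrix is orthogonal and the inverse RHT round-trips:
#   x = jnp.ones((4, 16), dtype=jnp.bfloat16)
#   assert jnp.allclose(apply_rht(apply_rht(x), inverse=True), x, atol=1e-2)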
......@@ -11,9 +11,12 @@ from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Tuple, Dict, Union, Sequence, Type
from functools import reduce
from typing import Optional, Tuple, Dict, Union, Sequence, Type, List
from functools import reduce, lru_cache
import operator
from importlib.metadata import version as get_pkg_version
import warnings
from packaging.version import Version as PkgVersion
import jax
import jax.numpy as jnp
......@@ -21,18 +24,27 @@ from flax.core.frozen_dict import FrozenDict
from transformer_engine_jax import DType, get_cublasLt_version, get_cuda_version
from transformer_engine.common import recipe
from transformer_engine.jax.sharding import global_shard_guard, MeshResource
from transformer_engine.jax.sharding import (
global_shard_guard,
MeshResource,
num_of_devices,
get_all_mesh_axes,
with_sharding_constraint,
)
from .metadata import QuantizeMeta
from .scaling_modes import ScalingMode
from .. import cpp_extensions as tex
from .device_utils import get_device_compute_capability
__all__ = [
"get_quantize_config",
"get_quantize_config_with_recipe",
"fp8_autocast",
"is_fp8_available",
"is_scaling_mode_supported",
"get_supported_scaling_modes",
"get_supported_quantization_recipes",
"update_collections",
"get_delayed_scaling",
"apply_padding_to_scale_inv",
"remove_padding_from_scale_inv",
"NVTE_FP8_COLLECTION_NAME",
......@@ -41,11 +53,23 @@ __all__ = [
_is_fp8_available = None
_reason_for_no_fp8 = ""
_is_scaling_mode_supported = None
_reason_for_no_scaling_mode = ""
Collection = Union[Dict, FrozenDict]
NVTE_FP8_COLLECTION_NAME = "fp8_metas"
@lru_cache(maxsize=None)
def _jax_version_meet_requirement(version: str):
"""
Helper function checking whether the installed JAX version meets the given requirement
"""
jax_version = PkgVersion(get_pkg_version("jax"))
jax_version_required = PkgVersion(version)
return jax_version >= jax_version_required
def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]:
"""Check if delayed scaling FP8 is supported on the given GPU architecture.
......@@ -55,8 +79,6 @@ def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]:
Returns:
A tuple of (bool, str) indicating support and any error message
"""
if gpu_arch >= 90: # hopper and above
return True, ""
if gpu_arch < 89: # pre-ada
return False, "Device compute capability 8.9 or higher required for FP8 execution."
if get_cublasLt_version() < 120103:
......@@ -75,20 +97,31 @@ def _check_block_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]:
Returns:
A tuple of (bool, str) indicating support and any error message
"""
if gpu_arch >= 100: # blackwell and above
return True, ""
if gpu_arch < 100: # pre-blackwell
return False, "Device compute capability 10.0 or higher required for MXFP8 execution."
if get_cublasLt_version() < 120800:
return False, "CublasLt version 12.8.0 or higher required for MXFP8 execution."
if get_cuda_version() < 12010:
if get_cuda_version() < 12080:
return False, "Cuda version 12.8 or higher required for MXFP8 execution."
if not tex.jax_version_meet_requirement("0.5.3"):
if not _jax_version_meet_requirement("0.5.3"):
return False, "Jax version 0.5.3 or higher required for MXFP8 execution."
return True, ""
def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]:
def _check_fp4_support(gpu_arch) -> Tuple[bool, str]:
"""Check if FP4 is supported for the given GPU architecture."""
if gpu_arch < 100: # pre-blackwell
return False, "Device compute capability 10.0 or higher required for NVFP4 execution."
if get_cublasLt_version() < 120800:
return False, "CublasLt version 12.8.0 or higher required for NVFP4 execution."
if get_cuda_version() < 12080:
return False, "Cuda version 12.8 or higher required for NVFP4 execution."
if not _jax_version_meet_requirement("0.5.3"):
return False, "Jax version 0.5.3 or higher required for NVFP4 execution."
return True, ""
def _check_scaling_support(scaling_mode: ScalingMode, gpu_id: int) -> Tuple[bool, str]:
"""Check if FP8 is supported for the given scaling mode and GPU.
Args:
......@@ -101,9 +134,35 @@ def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]:
gpu_arch = get_device_compute_capability(gpu_id)
if scaling_mode.is_tensor_scaling():
return _check_delayed_scaling_fp8_support(gpu_arch)
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
if scaling_mode.is_mxfp8_scaling:
return _check_block_scaling_fp8_support(gpu_arch)
return (False, "Unsupported scaling_mode!")
if scaling_mode.is_nvfp4_scaling:
return _check_fp4_support(gpu_arch)
return (True, "") # NO_SCALING is always supported
def is_scaling_mode_supported(
scaling_mode=ScalingMode.NO_SCALING,
gpu_id=None,
) -> Tuple[bool, str]:
"""Check if the given scaling mode is available for the given GPU."""
if gpu_id is not None:
return _check_scaling_support(scaling_mode, gpu_id)
global _is_scaling_mode_supported, _reason_for_no_scaling_mode
if _is_scaling_mode_supported is None:
_is_scaling_mode_supported = {}
_reason_for_no_scaling_mode = {}
if scaling_mode not in _is_scaling_mode_supported:
_is_scaling_mode_supported[scaling_mode] = True
_reason_for_no_scaling_mode[scaling_mode] = ""
for local_gpu_id in range(len(jax.local_devices())):
ret, msg = _check_scaling_support(scaling_mode, local_gpu_id)
if ret is False:
_is_scaling_mode_supported[scaling_mode] = ret
_reason_for_no_scaling_mode[scaling_mode] = msg
return ret, msg
return _is_scaling_mode_supported[scaling_mode], _reason_for_no_scaling_mode[scaling_mode]
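# Editorial usage sketch: with gpu_id=None every local device is checked and the
# result is cached per scaling mode:
#   ok, reason = is_scaling_mode_supported(scaling_mode=ScalingMode.NVFP4_1D_SCALING)
#   if not ok:
#       print(f"Skipping NVFP4: {reason}")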
def is_fp8_available(
......@@ -119,26 +178,36 @@ def is_fp8_available(
Returns:
A tuple of (bool, str) indicating availability and any error message
"""
if gpu_id is not None:
return _check_fp8_support(scaling_mode, gpu_id)
global _is_fp8_available, _reason_for_no_fp8
if _is_fp8_available is None:
_is_fp8_available = {}
_reason_for_no_fp8 = {}
if scaling_mode not in _is_fp8_available:
_is_fp8_available[scaling_mode] = True
_reason_for_no_fp8[scaling_mode] = ""
# JAX doesn't provide the local GPU id.
for local_gpu_id in range(len(jax.local_devices())):
ret, msg = _check_fp8_support(scaling_mode, local_gpu_id)
if ret is False:
_is_fp8_available[scaling_mode] = ret
_reason_for_no_fp8[scaling_mode] = msg
return ret, msg
return _is_fp8_available[scaling_mode], _reason_for_no_fp8[scaling_mode]
warnings.warn(
"is_fp8_available is deprecated. Use is_scaling_mode_supported instead.", DeprecationWarning
)
return is_scaling_mode_supported(scaling_mode=scaling_mode, gpu_id=gpu_id)
# TODO(Phuong): build the infrastructure to support NO_SCALING
def get_supported_scaling_modes() -> List[ScalingMode]:
"""Get all supported quantization scaling modes."""
return [
scaling_mode
for scaling_mode in ScalingMode
if is_scaling_mode_supported(scaling_mode=scaling_mode)[0]
and scaling_mode != ScalingMode.NO_SCALING
]
def get_supported_quantization_recipes() -> List[recipe.Recipe]:
"""Get all supported quantization recipes."""
# We don't support all the recipes TE/Common supports yet
# return [get_quantize_config_class(recipe)() for recipe in recipe.Recipe.__subclasses__()]
all_recipes = [
recipe.DelayedScaling(),
recipe.Float8CurrentScaling(),
recipe.MXFP8BlockScaling(),
recipe.NVFP4BlockScaling(),
]
return [r for r in all_recipes if get_quantize_config_class(r)().is_supported()[0]]
def _format2dtypes(format_: recipe.Format):
......@@ -156,6 +225,8 @@ def _format2dtypes(format_: recipe.Format):
return jnp.float8_e5m2, jnp.float8_e5m2
if format_ == recipe.Format.HYBRID:
return jnp.float8_e4m3fn, jnp.float8_e5m2
if format_ == recipe.Format.E2M1:
return jnp.float4_e2m1fn, jnp.float4_e2m1fn
return jnp.bfloat16, jnp.bfloat16
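# Editorial examples of the mapping above (the E2M1 branch is new in this diff):
#   _format2dtypes(recipe.Format.E2M1)    # (jnp.float4_e2m1fn, jnp.float4_e2m1fn)
#   _format2dtypes(recipe.Format.HYBRID)  # (jnp.float8_e4m3fn, jnp.float8_e5m2)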
......@@ -193,7 +264,6 @@ class BaseQuantizeConfig(ABC):
INITIALIZED: Whether the config has been initialized
MARGIN: Margin value for quantization
COLLECTION_NAME: Name of the collection for quantization metadata
FP8_FORMAT: FP8 format to use
FWD_DTYPE: Forward pass data type
BWD_DTYPE: Backward pass data type
FP8_2X_ACC_FPROP: Whether to use 2x accumulation for forward pass
......@@ -207,28 +277,26 @@ class BaseQuantizeConfig(ABC):
INITIALIZED = False
MARGIN: float = 0.0
COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME
FP8_FORMAT: recipe.Format = recipe.Format.HYBRID
FWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[0]
BWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[1]
FWD_DTYPE: DType = None
BWD_DTYPE: DType = None
FP8_2X_ACC_FPROP: bool = False
FP8_2X_ACC_DGRAD: bool = False
FP8_2X_ACC_WGRAD: bool = False
INFERENCE_MODE: bool = False
# DelayedScaling
# TODO(Phuong): move these two into DelayedScalingQuantizeConfig
AMAX_HISTORY_LEN: int = 1024
AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX
def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None:
"""Initialize the quantization configuration.
"""Initialize the quantization configuration from a given recipe.
Args:
fp8_recipe: The FP8 recipe to use for initialization
"""
self.INITIALIZED = True
self.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0
self.FP8_FORMAT = fp8_recipe.fp8_format
self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(self.FP8_FORMAT)
self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(fp8_recipe.fp8_format)
def is_fp8_enabled(self) -> bool:
"""Check if FP8 quantization is enabled.
......@@ -249,6 +317,27 @@ class BaseQuantizeConfig(ABC):
The scaling mode for the specified usage type.
"""
@abstractmethod
def get_quantize_flax_meta(
self,
module,
collection_name: str,
postfix: str,
tensor_source: TensorSource,
quantizer_name: str,
) -> QuantizeMeta:
"""Get the quantization metadata for a given Flax module.
Args:
module: The Flax module to get metadata for
collection_name: The name of the collection to store metadata in
postfix: Postfix to append to metadata names
tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD)
quantizer_name: The name of the quantizer within the module
Returns:
The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed.
"""
def is_supported(self) -> tuple[bool, str]:
"""Check if this QuantizeConfig class is supported on the available devices.
......@@ -261,7 +350,7 @@ class BaseQuantizeConfig(ABC):
kernel_scaling_mode = self.get_scaling_mode(TensorSource.KERNEL)
grad_scaling_mode = self.get_scaling_mode(TensorSource.DGRAD)
for scaling_mode in [x_scaling_mode, kernel_scaling_mode, grad_scaling_mode]:
is_supported, reason = is_fp8_available(scaling_mode=scaling_mode)
is_supported, reason = is_scaling_mode_supported(scaling_mode=scaling_mode)
if not is_supported:
return is_supported, reason
return True, None
......@@ -281,6 +370,27 @@ class NoOpQuantizeConfig(BaseQuantizeConfig):
"""Gets the scaling mode for a specific tensor's usage type."""
return ScalingMode.NO_SCALING
def get_quantize_flax_meta(
self,
module,
collection_name: str,
postfix: str,
tensor_source: TensorSource,
quantizer_name: str,
) -> QuantizeMeta:
"""Get the quantization metadata for a given Flax module.
Args:
module: The Flax module to get metadata for
collection_name: The name of the collection to store metadata in
postfix: Postfix to append to metadata names
tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD)
quantizer_name: The name of the quantizer within the module
Returns:
The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed.
"""
return QuantizeMeta()
class DelayedScalingQuantizeConfig(BaseQuantizeConfig):
"""Configuration class for delayed scaling FP8 recipe.
......@@ -299,6 +409,7 @@ class DelayedScalingQuantizeConfig(BaseQuantizeConfig):
AssertionError: If recipe parameters are not supported
"""
super().initialize_from_recipe(fp8_recipe)
self.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0
assert fp8_recipe.amax_compute_algo in [
"max",
......@@ -323,6 +434,41 @@ class DelayedScalingQuantizeConfig(BaseQuantizeConfig):
"""Gets the scaling mode for a specific tensor's usage type."""
return ScalingMode.DELAYED_TENSOR_SCALING
def get_quantize_flax_meta(
self,
module,
collection_name: str,
postfix: str,
tensor_source: TensorSource,
quantizer_name: str,
) -> QuantizeMeta:
"""Get the quantization metadata for a given Flax module.
Args:
module: The Flax module to get metadata for
collection_name: The name of the collection to store metadata in
postfix: Postfix to append to metadata names
tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD)
quantizer_name: The name of the quantizer within the module
Returns:
The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed.
"""
scale = module.variable(
collection_name,
f"{quantizer_name}{postfix}_scale",
jnp.ones,
(1,),
jnp.float32,
).value
amax_history = module.variable(
collection_name,
f"{quantizer_name}{postfix}_amax_history",
jnp.zeros,
(self.AMAX_HISTORY_LEN,),
jnp.float32,
).value
return QuantizeMeta(scale=scale, amax_history=amax_history)
class CurrentScalingQuantizeConfig(BaseQuantizeConfig):
"""Configuration class for current scaling FP8 recipe.
......@@ -344,6 +490,27 @@ class CurrentScalingQuantizeConfig(BaseQuantizeConfig):
"""Gets the scaling mode for a specific tensor's usage type."""
return ScalingMode.CURRENT_TENSOR_SCALING
def get_quantize_flax_meta(
self,
module,
collection_name: str,
postfix: str,
tensor_source: TensorSource,
quantizer_name: str,
) -> QuantizeMeta:
"""Get the quantization metadata for a given Flax module.
Args:
module: The Flax module to get metadata for
collection_name: The name of the collection to store metadata in
postfix: Postfix to append to metadata names
tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD)
quantizer_name: The name of the quantizer within the module
Returns:
The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed.
"""
return QuantizeMeta()
class BlockScalingQuantizeConfig(BaseQuantizeConfig):
"""Configuration class for block scaling FP8 recipe.
......@@ -365,6 +532,91 @@ class BlockScalingQuantizeConfig(BaseQuantizeConfig):
"""Gets the scaling mode for a specific tensor's usage type."""
return ScalingMode.MXFP8_1D_SCALING
def get_quantize_flax_meta(
self,
module,
collection_name: str,
postfix: str,
tensor_source: TensorSource,
quantizer_name: str,
) -> QuantizeMeta:
"""Get the quantization metadata for a given Flax module.
Args:
module: The Flax module to get metadata for
collection_name: The name of the collection to store metadata in
postfix: Postfix to append to metadata names
tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD)
quantizer_name: The name of the quantizer within the module
Returns:
The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed.
"""
return QuantizeMeta()
class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig):
"""Configuration class for NVFP4 scaling recipe.
This class provides specific initialization and finalization for NVFP4 scaling quantization mode.
"""
def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None:
"""Initialize block scaling FP8 configuration.
Args:
fp8_recipe: The FP8 recipe to use for initialization
"""
self.INITIALIZED = True
self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(fp8_recipe.fp4_format)
self.AMAX_HISTORY_LEN = 0
def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
"""Gets the scaling mode for a specific tensor's usage type."""
if tensor_source == TensorSource.KERNEL:
return ScalingMode.NVFP4_2D_SCALING
# for x and grad
return ScalingMode.NVFP4_1D_SCALING
def get_quantize_flax_meta(
self,
module,
collection_name: str,
postfix: str,
tensor_source: TensorSource,
quantizer_name: str,
) -> QuantizeMeta:
"""Get the quantization metadata for a given Flax module.
Args:
module: The Flax module to get metadata for
collection_name: The name of the collection to store metadata in
postfix: Postfix to append to metadata names
tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD)
quantizer_name: The name of the quantizer within the module
Returns:
The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed.
"""
if tensor_source != TensorSource.DGRAD:
# Only DGRAD uses stochastic rounding
return QuantizeMeta()
# TODO(jberchtold): This assumes SR is always enabled for NVFP4. Use flag from recipe to toggle it.
sr_jax_rng = module.make_rng("sr_rng")
# Get a unique key for this quantizer
sr_jax_rng = jax.jit(jax.random.fold_in)(
sr_jax_rng, hash(quantizer_name) % jnp.iinfo(jnp.int32).max
)
# Generate 4 random uint32 values per device from the JAX PRNG key
sr_jax_rng_state = jax.random.randint(
sr_jax_rng, (num_of_devices(), 4), 0, jnp.iinfo(jnp.int32).max, dtype=jnp.int32
).view(jnp.uint32)
sr_jax_rng_state = with_sharding_constraint(
sr_jax_rng_state, jax.sharding.PartitionSpec(get_all_mesh_axes(), None)
)
return QuantizeMeta(stochastic_rounding_rng_state=sr_jax_rng_state)
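# Editorial note: the (num_of_devices(), 4) uint32 state above is 2 uint64 words per
# device, matching the FFI-side check sr_rng_state.size_bytes() == 2 * sizeof(uint64_t).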
_QUANTIZE_CONFIG = NoOpQuantizeConfig()
......@@ -377,7 +629,7 @@ def get_quantize_config():
def get_quantize_config_class(
fp8_recipe: recipe.Recipe,
) -> Type[BaseQuantizeConfig]:
"""Get the quantization configuration based on the FP8 recipe.
"""Get the quantization configuration class based on the FP8 recipe.
Args:
fp8_recipe: The FP8 recipe to use for initialization
......@@ -390,9 +642,18 @@ def get_quantize_config_class(
return BlockScalingQuantizeConfig
if isinstance(fp8_recipe, recipe.Float8CurrentScaling):
return CurrentScalingQuantizeConfig
if isinstance(fp8_recipe, recipe.NVFP4BlockScaling):
return NVFP4ScalingQuantizeConfig
raise ValueError(f"Unsupported recipe type: {type(fp8_recipe)}")
def get_quantize_config_with_recipe(fp8_recipe: recipe.Recipe):
"""Get the quantization configuration object based on the FP8 recipe."""
config = get_quantize_config_class(fp8_recipe)()
config.initialize_from_recipe(fp8_recipe)
return config
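# Illustrative usage of the helper above, assuming recipe.NVFP4BlockScaling is
# default-constructible; the kernel tensor maps to 2D NVFP4 scaling per
# NVFP4ScalingQuantizeConfig.get_scaling_mode.
example_config = get_quantize_config_with_recipe(recipe.NVFP4BlockScaling())
assert example_config.get_scaling_mode(TensorSource.KERNEL) == ScalingMode.NVFP4_2D_SCALING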
@contextmanager
def fp8_autocast(
enabled: bool = False,
......@@ -457,31 +718,6 @@ def fp8_autocast(
_QUANTIZE_CONFIG = old_quantize_config
def get_delayed_scaling():
r"""
Obtain an instance of DelayedScaling which is set via fp8_autocast.
.. note::
We only store :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len`
, and :attr:`amax_compute_algo` via fp8_autocast. Other parameters in
recipe.DelayedScaling would be returned as the default values.
Returns
-------
delay_scaling : DelayedScaling
an instance of DelayedScaling which is set via fp8_autocast.
"""
amax_compute_algo = (
"max" if get_quantize_config().AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX else "most_recent"
)
return recipe.DelayedScaling(
margin=int(get_quantize_config().MARGIN),
fp8_format=get_quantize_config().FP8_FORMAT,
amax_history_len=get_quantize_config().AMAX_HISTORY_LEN,
amax_compute_algo=amax_compute_algo,
)
def update_collections(new: Collection, original: Collection) -> Collection:
r"""Update collections with new values while preserving original structure.
......
......@@ -9,23 +9,29 @@ This module provides classes for managing quantization metadata, including
scale factors and amax history for different tensor types.
"""
from dataclasses import dataclass
import jax.numpy as jnp
__all__ = ["QuantizeMeta", "QuantizeMetaSet"]
@dataclass
class QuantizeMeta:
"""Metadata for quantization parameters.
Attributes:
For Delayed Scaling recipe:
scale: The scaling factor for quantization
amax_history: History of maximum absolute values
For NVFP4 recipe with Stochastic Rounding:
sr_rng_state: The state of the stochastic rounding RNG
"""
scale: jnp.ndarray
amax_history: jnp.ndarray
def __init__(self, **kwargs):
self._kwargs = kwargs
def get_kwargs_dictionary(self):
"""Get the metadata as a dictionary."""
return self._kwargs
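# A minimal sketch of the kwargs-based metadata (illustrative values): delayed
# scaling passes scale/amax_history, while NVFP4 passes the SR RNG state.
meta_ds = QuantizeMeta(scale=jnp.ones((1,)), amax_history=jnp.zeros((1024,)))
meta_sr = QuantizeMeta(stochastic_rounding_rng_state=jnp.zeros((4,), dtype=jnp.uint32))
assert "scale" in meta_ds.get_kwargs_dictionary()
assert "stochastic_rounding_rng_state" in meta_sr.get_kwargs_dictionary()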
@dataclass
......
......@@ -19,6 +19,7 @@ from transformer_engine_jax import QuantizeLayout
from transformer_engine.common import recipe
from .scaling_modes import ScalingMode
from .hadamard import apply_rht, should_use_rht
from .tensor import (
ScaledTensor,
ScaledTensor1x,
......@@ -28,7 +29,7 @@ from .tensor import (
)
from .helper import (
get_quantize_config,
get_quantize_config_class,
get_quantize_config_with_recipe,
AmaxComputeAlgo,
TensorSource,
)
......@@ -66,6 +67,7 @@ def compute_scale_from_amax(
sf = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN)
sf = jnp.where(amax > 0.0, sf, scale)
sf = jnp.where(jnp.isfinite(amax), sf, scale)
assert sf.shape == (1,)
return sf
......@@ -155,7 +157,7 @@ class Quantizer(ABC):
"""
def quantize(
self, x, is_rowwise=False, is_colwise=False, dq_dtype=None, flatten_axis=-1, **kwargs
self, x, is_rowwise=None, is_colwise=None, dq_dtype=None, flatten_axis=-1, **kwargs
) -> ScaledTensor:
"""Quantize a tensor using the internal _quantize_func().
......@@ -170,6 +172,18 @@ class Quantizer(ABC):
A ScaledTensor1x or ScaledTensor2x containing the quantized data
"""
del kwargs
is_rowwise = (
is_rowwise
if is_rowwise is not None
else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x())
)
is_colwise = (
is_colwise
if is_colwise is not None
else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x())
)
if (is_rowwise and is_colwise) or self.is_2x2x():
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
colwise_tensor = self._quantize_func(
......@@ -380,6 +394,7 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
scale_inv = 1.0 / self.scale
amax = x.amax or jnp.max(jnp.abs(x.data)).reshape((1,))
# Note: this amax update is only invoked once because the "quantize" method
# inherited from CurrentScaleQuantizer calls _quantize_func a single time and
# transposes the result for colwise quantization, so update() is not called
# twice for 2x2x quantization.
self.update(amax)
return ScaledTensorFactory.create_1x(
data=clipped_scaled_x,
......@@ -494,7 +509,7 @@ class BlockScaleQuantizer(Quantizer):
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
x_shape = x.shape
scale_shape = self.scaling_mode.get_scale_shape(
x_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis
x_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis
)
scale_dtype = self.scaling_mode.get_scale_dtype()
x = x.reshape(
......@@ -563,6 +578,221 @@ class BlockScaleQuantizer(Quantizer):
return new_x.astype(dtype)
@register_pytree_node_class
@dataclass
class NVFP4Quantizer(Quantizer):
"""Quantizer implementation using current scaling.
This quantizer uses current scaling mode with float32 scales
Attributes:
scaling_mode: Set to NVFP4_1D_SCALING or NVFP4_2D_SCALING
q_layout: Quantization axis
data_layout: Data layout string (default: "NT")
stochastic_rounding_rng_state: RNG state for stochastic rounding, must be of shape (4,) and dtype uint32. If None, stochastic rounding is disabled.
"""
scaling_mode: ScalingMode = ScalingMode.NVFP4_1D_SCALING
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
data_layout: str = "NT"
stochastic_rounding_rng_state: Optional[jnp.ndarray] = None
def __post_init__(self):
assert (
self.q_dtype == jnp.float4_e2m1fn
), "NVFP4 quantization must use a q_dtype of float4_e2m1fn"
assert self.scaling_mode.is_nvfp4_scaling, "NVFP4Quantizer must use NVFP4 scaling modes"
def _apply_stochastic_rounding(self, x):
assert (
self.stochastic_rounding_rng_state is not None
), "Stochastic rounding RNG state is not initialized"
assert self.stochastic_rounding_rng_state.shape == (
4,
), "Stochastic rounding RNG state must be of shape (4,)"
assert (
self.stochastic_rounding_rng_state.dtype == jnp.uint32
), "Stochastic rounding RNG state must be of dtype uint32"
# The default JAX RNG key holds 2x 32-bit integers; use the first 2 uint32s
# as the initial key and fold in the other 2 uint32s
key_bits = jnp.array(
[
self.stochastic_rounding_rng_state[0],
self.stochastic_rounding_rng_state[1],
],
dtype=jnp.uint32,
)
key = jax.random.wrap_key_data(key_bits)
key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[2])
key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[3])
abs_x = jnp.abs(x)
sign_x = jnp.sign(x)
floor = (
(abs_x >= 0.5) * 0.5
+ (abs_x >= 1) * 0.5
+ (abs_x >= 2)
+ (abs_x >= 3)
+ (abs_x >= 4)
+ (abs_x >= 6) * 2
)
ceil = (
0.5
+ (abs_x > 0.5) * 0.5
+ (abs_x > 1) * 1
+ (abs_x > 2)
+ (abs_x > 3)
+ (abs_x > 4) * 2
)
frac = (abs_x - floor) / (ceil - floor)
rand = jax.random.uniform(key, abs_x.shape)
return sign_x * jnp.where(frac >= rand, ceil, floor)
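# A comment-only sketch of the unbiasedness property the rounding above relies
# on: for x between adjacent representable magnitudes f < c, the probability of
# rounding up is (x - f) / (c - f), so E[SR(x)] == x. For example (illustrative):
#   rand = jax.random.uniform(key, (100_000,))
#   samples = jnp.where((2.5 - 2.0) / (3.0 - 2.0) >= rand, 3.0, 2.0)
#   samples.mean()  # ~= 2.5, converging to x as the sample count grows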
def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x:
"""Quantize function helper for block scaling FP8.
Args:
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns:
A ScaledTensor1x containing the quantized data
"""
# TODO(Phuong): use quantize_func from JAX
if flatten_axis < 0:
flatten_axis = x.ndim + flatten_axis
assert (
0 <= flatten_axis < x.ndim
), f"Invalid flatten_axis: {flatten_axis} for tensor of shape {x.shape}"
should_apply_rht = self.scaling_mode == ScalingMode.NVFP4_1D_SCALING and is_colwise
global_amax = None
if isinstance(x, NoScaleTensor):
global_amax = (
x.amax if not should_apply_rht else None
) # RHT changes the amax so don't use precalculated amax for colwise 1D nvfp4 quantization with RHT
x = x.data
# Transpose if required
rowwise_flatten_axis = flatten_axis
data_layout = self.data_layout[0]
if is_colwise:
x = jnp.transpose(x, (*range(flatten_axis, x.ndim), *range(flatten_axis)))
data_layout = self.data_layout[1]
# convert flatten_axis from N layout to T layout
flatten_axis = x.ndim - flatten_axis
x_shape = x.shape
if should_use_rht(self.scaling_mode, is_colwise=is_colwise):
# We only apply RHT for 1D colwise nvfp4
x = apply_rht(x)
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
scale_shape = self.scaling_mode.get_scale_shape(
x_shape,
data_layout=data_layout,
is_colwise=is_colwise,
is_padded=False,
flatten_axis=rowwise_flatten_axis,
)
scale_dtype = self.scaling_mode.get_scale_dtype()
x = x.reshape(
*x_shape[: flatten_axis - 1],
scale_shape[flatten_axis - 1],
int(x_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]),
*x_shape[flatten_axis:-1],
scale_shape[-1],
int(x_shape[-1] / scale_shape[-1]),
)
# Dtype max constants
DATA_DTYPE_MAX = jnp.finfo(self.q_dtype).max.astype(jnp.float32)
SCALE_DTYPE_MAX = jnp.finfo(scale_dtype).max.astype(jnp.float32)
# Level 1: Current Tensor Scaling
global_amax = (
global_amax
if global_amax is not None
else jnp.max(jnp.abs(x)).reshape((1,)).astype(jnp.float32)
)
tensor_scale = DATA_DTYPE_MAX * SCALE_DTYPE_MAX / global_amax
tensor_scale = jnp.minimum(
tensor_scale, jnp.array(jnp.finfo(jnp.float32).max, dtype=jnp.float32)
)
tensor_scale = jnp.where(
tensor_scale == jnp.array(0.0, dtype=jnp.float32),
jnp.array(1.0, dtype=jnp.float32),
tensor_scale,
)
tensor_scale_inv = 1.0 / tensor_scale
# Level 2: Block Scaling
block_amax = jnp.max(jnp.abs(x), axis=(flatten_axis + 2 - 2, -1), keepdims=True).astype(
jnp.float32
)
block_scale_inv = jnp.divide(block_amax, DATA_DTYPE_MAX)
block_scale_inv = block_scale_inv * tensor_scale
block_scale_inv = jnp.minimum(
block_scale_inv, jnp.array(jnp.finfo(jnp.float32).max, dtype=jnp.float32)
)
block_scale_inv = jnp.clip(block_scale_inv, -SCALE_DTYPE_MAX, SCALE_DTYPE_MAX)
# We cast block_scale_inv to scale_dtype here to account for any rounding during the cast. This will ensure the quantized data incorporates the rounded scale value into its computation so dequantization is accurate.
block_scale_inv = block_scale_inv.astype(scale_dtype)
# Note: under JIT, XLA can elide the intermediate cast above, which yields slightly
# incorrect results during DQ and worse convergence to the original tensor over many
# samples of Q+SR->DQ. So we use reduce_precision to force the rounding to scale_dtype.
assert scale_dtype == jnp.float8_e4m3fn, "Only float8_e4m3fn is supported for scale_dtype"
block_scale_inv = jax.lax.reduce_precision(block_scale_inv, 4, 3)
block_scale = jnp.minimum(
jnp.divide(1.0, block_scale_inv.astype(jnp.float32) * tensor_scale_inv),
jnp.array(jnp.finfo(jnp.float32).max, dtype=jnp.float32),
)
# Apply scaling
scaled_x = x.astype(jnp.float32) * block_scale
if self.stochastic_rounding_rng_state is not None:
scaled_x = self._apply_stochastic_rounding(scaled_x)
clipped_x = jnp.clip(scaled_x, -DATA_DTYPE_MAX, DATA_DTYPE_MAX)
# Cast to the right dtype
quantized_data = clipped_x.reshape(x_shape).astype(self.q_dtype)
block_scale_inv = block_scale_inv.reshape(scale_shape).astype(scale_dtype)
# In the 2D scaling mode, the scale shape is 2D but it must be broadcast to 1D for the GEMM.
# TODO(Phuong): expose this broadcast_2d_scale_shape_to_1d option to the
# quantizer.quantize() API
broadcasted_1d_scale_shape = self.scaling_mode.get_scale_shape(
x_shape,
data_layout=data_layout,
is_colwise=is_colwise,
is_padded=False,
flatten_axis=rowwise_flatten_axis,
broadcast_2d_scale_shape_to_1d=True,
)
# Broadcast and tile x to match the target shape
def repeat_to_shape(x, target_shape):
x_shape = x.shape
reps = [int(t // s) for s, t in zip(x_shape, target_shape)]
return jnp.tile(x, reps)
block_scale_inv = repeat_to_shape(block_scale_inv, broadcasted_1d_scale_shape)
return ScaledTensorFactory.create_1x(
data=quantized_data,
data_layout=data_layout,
is_colwise=is_colwise,
scale_inv=block_scale_inv,
amax=global_amax,
scaling_mode=self.scaling_mode,
dq_dtype=dq_dtype,
flatten_axis=rowwise_flatten_axis,
)
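# A toy, self-contained walk-through of the two-level scaling above with
# illustrative values (a sketch, not part of the module): the per-tensor scale
# maps the global amax onto the combined FP4 x FP8 range, and the per-block
# scale_inv is stored in float8_e4m3fn.
toy_x = jnp.linspace(-1.0, 1.0, 32, dtype=jnp.float32).reshape(2, 16)
TOY_DATA_MAX = jnp.finfo(jnp.float4_e2m1fn).max.astype(jnp.float32)  # 6.0
TOY_SCALE_MAX = jnp.finfo(jnp.float8_e4m3fn).max.astype(jnp.float32)  # 448.0
toy_tensor_scale = TOY_DATA_MAX * TOY_SCALE_MAX / jnp.max(jnp.abs(toy_x))
toy_block_amax = jnp.max(jnp.abs(toy_x), axis=-1, keepdims=True)
toy_block_scale_inv = (toy_block_amax / TOY_DATA_MAX * toy_tensor_scale).astype(jnp.float8_e4m3fn)
# Quantize: x * tensor_scale / block_scale_inv lands within the FP4 range.
toy_q = (toy_x * toy_tensor_scale / toy_block_scale_inv.astype(jnp.float32)).astype(jnp.float4_e2m1fn)
# Dequantize: undo both levels; the result approximates toy_x up to FP4 rounding.
toy_dq = toy_q.astype(jnp.float32) * toy_block_scale_inv.astype(jnp.float32) / toy_tensor_scale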
@register_pytree_node_class
@dataclass
class QuantizerSet:
......@@ -801,6 +1031,8 @@ class QuantizerFactory:
ScalingMode.DELAYED_TENSOR_SCALING: DelayedScaleQuantizer,
ScalingMode.CURRENT_TENSOR_SCALING: CurrentScaleQuantizer,
ScalingMode.MXFP8_1D_SCALING: BlockScaleQuantizer,
ScalingMode.NVFP4_1D_SCALING: NVFP4Quantizer,
ScalingMode.NVFP4_2D_SCALING: NVFP4Quantizer,
}
@staticmethod
......@@ -826,7 +1058,6 @@ class QuantizerFactory:
Returns:
A single quantizer or tuple of quantizers
"""
# (Phuong): add this assert back when NVTE_NO_SCALING is fully implemented
assert isinstance(scaling_mode, ScalingMode), "Invalid scaling_mode type"
if n_groups:
if n_quantizers != 1:
......@@ -887,18 +1118,9 @@ class QuantizerFactory:
if "quantize_meta_set" in kwargs:
quantize_meta_set = kwargs.get("quantize_meta_set")
args_x = {
"scale": quantize_meta_set.x.scale,
"amax_history": quantize_meta_set.x.amax_history,
}
args_kernel = {
"scale": quantize_meta_set.kernel.scale,
"amax_history": quantize_meta_set.kernel.amax_history,
}
args_grad = {
"scale": quantize_meta_set.grad.scale,
"amax_history": quantize_meta_set.grad.amax_history,
}
args_x = quantize_meta_set.x.get_kwargs_dictionary()
args_kernel = quantize_meta_set.kernel.get_kwargs_dictionary()
args_grad = quantize_meta_set.grad.get_kwargs_dictionary()
else:
args_x = args_kernel = args_grad = {}
......@@ -919,6 +1141,7 @@ class QuantizerFactory:
bwd_dtype: jnp.dtype = None,
is_2x2x: bool = None,
n_groups: int = None,
# TODO(jberchtold): rename fp8_recipe to quantization_recipe
fp8_recipe: Optional[recipe.Recipe] = None,
**kwargs,
) -> tuple[Union[tuple[Quantizer], None]]:
......@@ -946,21 +1169,24 @@ class QuantizerFactory:
)
if fp8_recipe is not None:
quantize_config = get_quantize_config_class(fp8_recipe)()
quantize_config = get_quantize_config_with_recipe(fp8_recipe)
x_scaling_mode = quantize_config.get_scaling_mode(TensorSource.X)
kernel_scaling_mode = quantize_config.get_scaling_mode(TensorSource.KERNEL)
grad_scaling_mode = quantize_config.get_scaling_mode(TensorSource.DGRAD)
elif scaling_mode is not None:
x_scaling_mode = scaling_mode
kernel_scaling_mode = scaling_mode
grad_scaling_mode = scaling_mode
fwd_dtype = quantize_config.FWD_DTYPE
bwd_dtype = quantize_config.BWD_DTYPE
else:
x_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.X)
kernel_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.KERNEL)
grad_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.DGRAD)
if scaling_mode is not None:
x_scaling_mode = scaling_mode
kernel_scaling_mode = scaling_mode
grad_scaling_mode = scaling_mode
else:
x_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.X)
kernel_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.KERNEL)
grad_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.DGRAD)
fwd_dtype = fwd_dtype or get_quantize_config().FWD_DTYPE
bwd_dtype = bwd_dtype or get_quantize_config().BWD_DTYPE
fwd_dtype = fwd_dtype or get_quantize_config().FWD_DTYPE
bwd_dtype = bwd_dtype or get_quantize_config().BWD_DTYPE
if is_2x2x is None:
# TODO(Jeremy): check x, kernel, grad separately for 2x
if x_scaling_mode.is_1d_block_scaling():
......
......@@ -100,10 +100,19 @@ class ScalingModeMetadataImpl(ABC):
The data type used for scale tensors
"""
@abstractmethod
def get_data_layout(self) -> str:
"""Get the data layout for rowwise and colwise scaling.
Returns:
The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout.
"""
@abstractmethod
def get_scale_shape(
self,
data_shape: Tuple[int, ...],
data_layout: str = "N",
is_colwise: bool = False,
is_padded: bool = True,
flatten_axis: int = -1,
......@@ -112,6 +121,7 @@ class ScalingModeMetadataImpl(ABC):
Args:
data_shape: The shape of the tensor being quantized
data_layout: Layout of the data shape, either "N" (default) or "T" for transposed.
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
......@@ -156,13 +166,15 @@ class ScalingModeMetadataImpl(ABC):
input_shape,
unique_var,
flatten_axis,
broadcast_2d_scale_shape_to_1d,
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
flatten_axis: Axis along which data can be flattened to 2D for quantization
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.
Returns:
The Shardy rules for the scaling mode
......@@ -183,12 +195,22 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""
return jnp.float32
def get_data_layout(self) -> str:
"""Get the data layout for rowwise and colwise scaling.
Returns:
The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout.
"""
return "NN"
def get_scale_shape(
self,
data_shape: Tuple[int, ...],
data_layout: str = "N",
is_colwise: bool = False,
is_padded: bool = True,
flatten_axis: int = -1,
broadcast_2d_scale_shape_to_1d: bool = True,
) -> Tuple[int, ...]:
"""Get the shape for scale tensors. This always returns an empty shape because this mode applies no scaling.
......@@ -201,7 +223,14 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
Returns:
The shape for scale tensors - (0,)
"""
del data_shape, is_colwise, is_padded, flatten_axis
del (
data_shape,
data_layout,
is_colwise,
is_padded,
flatten_axis,
broadcast_2d_scale_shape_to_1d,
)
return (0,)
@lru_cache(maxsize=4)
......@@ -239,18 +268,20 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
input_shape,
unique_var,
flatten_axis,
broadcast_2d_scale_shape_to_1d,
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
flatten_axis: Axis along which data can be flattened to 2D for quantization
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.
Returns:
The Shardy rules for the scaling mode
"""
del flatten_axis
del flatten_axis, broadcast_2d_scale_shape_to_1d
input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape)))
scale_var = BATCHING + unique_var + "_scale_inv"
return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {})
......@@ -270,25 +301,37 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""
return jnp.float32
def get_data_layout(self) -> str:
"""Get the data layout for rowwise and colwise scaling.
Returns:
The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout.
"""
return "NT"
def get_scale_shape(
self,
data_shape: Tuple[int, ...],
data_layout: str = "N",
is_colwise: bool = False,
is_padded: bool = True,
flatten_axis: int = -1,
broadcast_2d_scale_shape_to_1d: bool = True,
) -> Tuple[int, ...]:
"""Get the shape for scale tensors in delayed scaling.
Args:
data_shape: The shape of the tensor being scaled
data_layout: Layout of the data shape, either "N" (default) or "T" for transposed.
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to True.
Returns:
The shape for scale tensors - (1,), or (0,) for zero-sized tensors
"""
del is_colwise
del data_layout, is_colwise, broadcast_2d_scale_shape_to_1d
if np.prod(data_shape) == 0:
return (0,)
return (1,)
......@@ -333,6 +376,7 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
input_shape,
unique_var,
flatten_axis,
broadcast_2d_scale_shape_to_1d,
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
......@@ -340,11 +384,12 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.
Returns:
The Shardy rules for the scaling mode
"""
del flatten_axis
del flatten_axis, broadcast_2d_scale_shape_to_1d
input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape)))
scale_var = BATCHING + unique_var + "_scale_inv"
return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {})
......@@ -368,14 +413,18 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
_block_alignment: Alignment requirements for blocks
"""
def __init__(self, block_dims: Tuple[int]):
def __init__(self, block_dims: Tuple[int], scale_dtype: jnp.dtype, data_layout: str):
"""Initialize block scaling mode implementation.
Args:
block_dims: Dimensions of the scaling blocks
scale_dtype: Data type of the scale tensor
data_layout: Layout for rowwise and colwise scaling, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout.
"""
self._block_dims = block_dims
self._scale_dtype = scale_dtype
self._block_alignment = (128, 4)
self._data_layout = data_layout
def get_scale_dtype(self) -> jnp.dtype:
"""Get the data type for scale tensors in block scaling.
......@@ -383,7 +432,15 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
Returns:
The data type used for scale tensors (float8_e8m0fnu)
"""
return jnp.float8_e8m0fnu
return self._scale_dtype
def get_data_layout(self) -> str:
"""Get the data layout for rowwise and colwise scaling.
Returns:
The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout.
"""
return self._data_layout
def _apply_scale_shape_correction(self, data_shape, n_scale_blocks, scale_block_dim):
"""Remove excess padding from the scale shape and return the shape with respect to the original data shape."""
......@@ -411,23 +468,51 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
def get_scale_shape(
self,
data_shape: Tuple[int, ...],
data_layout: str = "N",
is_colwise: bool = False,
is_padded: bool = True,
flatten_axis: int = -1,
broadcast_2d_scale_shape_to_1d: bool = False,
) -> Tuple[int, ...]:
"""Get the shape for scale tensors in block scaling.
Args:
data_shape: The shape of the tensor being quantized
data_layout: Layout of the data shape, either "N" (default) or "T" for transposed.
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False.
Returns:
The shape for scale tensors
"""
flatten_axis = (len(data_shape) + flatten_axis) % len(data_shape)
assert (
0 < flatten_axis < len(data_shape)
), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
block_alignment = self._block_alignment if is_padded else (1, 1)
if is_colwise:
assert data_layout == self._data_layout[1], (
f"Data layout must match colwise layout, received {data_layout} but expected"
f" {self._data_layout[1]}"
)
else:
assert data_layout == self._data_layout[0], (
f"Data layout must match rowwise layout, received {data_layout} but expected"
f" {self._data_layout[0]}"
)
if is_colwise and self._data_layout[1] == "T":
# TODO(Phuong): rework this hack so that we don't implicitly change is_colwise value
is_colwise = False # now rowwise in T is colwise in N
if flatten_axis < 0:
flatten_axis = len(data_shape) + flatten_axis
# flatten_axis is given wrt N layout, convert to T layout
flatten_axis = len(data_shape) - flatten_axis
if is_colwise:
block_y, block_x = self._block_dims
alignment_y, alignment_x = block_alignment
......@@ -435,12 +520,7 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
block_x, block_y = self._block_dims
alignment_x, alignment_y = block_alignment
if flatten_axis < 0:
flatten_axis = len(data_shape) + flatten_axis
assert (
0 < flatten_axis < len(data_shape)
), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
is_block_2d = block_x > 1 and block_y > 1
assert data_shape[flatten_axis - 1] % block_x == 0, (
f"Data shape {data_shape} should be divisible by block_x {block_x} in axis"
f" {flatten_axis - 1}"
......@@ -449,6 +529,9 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
data_shape[-1] % block_y == 0
), f"Data shape {data_shape} should be divisible by block_y {block_y} in axis -1"
if broadcast_2d_scale_shape_to_1d and is_block_2d:
block_x = 1
flattened_first_dim = reduce(operator.mul, data_shape[:flatten_axis], 1)
flattened_last_dim = reduce(operator.mul, data_shape[flatten_axis:], 1)
......@@ -575,6 +658,7 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
input_shape,
unique_var,
flatten_axis,
broadcast_2d_scale_shape_to_1d,
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
......@@ -582,30 +666,41 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.
Returns:
The Shardy rules for the scaling mode
"""
# TODO(Phuong): to rework the shardy rule to handle transposes after NVFP4 is upstreamed
input_rank = len(input_shape)
input_spec = [f"{unique_var}_{i}" for i in range(input_rank)]
flatten_axis = (flatten_axis + input_rank) % input_rank
# This implementation needs to be updated for different block dims.
assert self._block_dims == (1, 32)
assert (
self._block_dims[1] != 1
), f"Expect 1D rowwise or 2D block. Got _block_dims={self._block_dims}"
# For 2D block scaling, only support when with broadcast_2d_scale_shape_to_1d
if self._block_dims[0] != 1:
assert self._block_dims[0] == self._block_dims[1] and broadcast_2d_scale_shape_to_1d, (
f"Got broadcast_2d_scale_shape_to_1d={broadcast_2d_scale_shape_to_1d},"
f" _block_dims={self._block_dims}"
)
block_size_1d = self._block_dims[1]
# We have to use two different factors in the two CompoundFactors because of Shardy
# verifier requirements, even though they are the same.
blocksizes = {}
colwise_var = f"{unique_var}_None"
rowwise_var = f"{unique_var}_None"
if not input_shape[-1] == 32:
if not input_shape[-1] == block_size_1d:
rowwise_var = input_spec[-1] + "_compound"
input_spec[-1] = CompoundFactor(rowwise_var, "blocksize_x")
blocksizes["blocksize_x"] = 32
if not input_shape[flatten_axis - 1] == 32:
blocksizes["blocksize_x"] = block_size_1d
if not input_shape[flatten_axis - 1] == block_size_1d:
colwise_var = input_spec[flatten_axis - 1] + "_compound"
input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "blocksize_y")
blocksizes["blocksize_y"] = 32
blocksizes["blocksize_y"] = block_size_1d
# The rowwise and colwise scale tensors should be sharded the same way as the input.
# However, we need to adjust the dimensions where the block scaling factor applies.
......@@ -632,6 +727,8 @@ class ScalingMode(Enum):
- DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales
- MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales
- CURRENT_TENSOR_SCALING: Uses current scaling with FP8 data type and float32 scales
- NVFP4_1D_SCALING: Uses block-based scaling with FP4 data type and E4M3 scales
- NVFP4_2D_SCALING: Uses block-based scaling with FP4 data type and E4M3 scales
- NO_SCALING: No scaling applied
"""
......@@ -639,6 +736,8 @@ class ScalingMode(Enum):
DELAYED_TENSOR_SCALING = JAXX_Scaling_Mode.DELAYED_TENSOR_SCALING
MXFP8_1D_SCALING = JAXX_Scaling_Mode.MXFP8_1D_SCALING
CURRENT_TENSOR_SCALING = JAXX_Scaling_Mode.CURRENT_TENSOR_SCALING
NVFP4_1D_SCALING = JAXX_Scaling_Mode.NVFP4_1D_SCALING
NVFP4_2D_SCALING = JAXX_Scaling_Mode.NVFP4_2D_SCALING
def _get_impl(self) -> ScalingModeMetadataImpl:
"""Get the implementation for this scaling mode.
......@@ -662,40 +761,79 @@ class ScalingMode(Enum):
"""
return self._get_impl().get_scale_dtype()
def get_scale_shape_2x(self, data_shape, is_padded=True, flatten_axis=-1) -> Tuple[Tuple[int]]:
def get_scale_shape_2x(
self, data_shape, is_padded=True, flatten_axis=-1, broadcast_2d_scale_shape_to_1d=False
) -> Tuple[Tuple[int]]:
"""Get shapes for both row-wise and column-wise scaling.
Args:
data_shape: Shape of the data tensor
is_padded: Whether to use padded shapes
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False.
Returns:
Tuple of (rowwise_scale_shape, colwise_scale_shape)
"""
data_layout = self._get_impl().get_data_layout()
rowwise_layout = data_layout[0]
assert (
rowwise_layout == "N"
), f"For rowwise layout only 'N' is supported, received {rowwise_layout}"
colwise_layout = data_layout[1]
rowwise_scale_shape = self.get_scale_shape(
data_shape, is_colwise=False, is_padded=is_padded, flatten_axis=flatten_axis
data_shape,
data_layout=rowwise_layout,
is_colwise=False,
is_padded=is_padded,
flatten_axis=flatten_axis,
broadcast_2d_scale_shape_to_1d=broadcast_2d_scale_shape_to_1d,
)
colwise_data_shape = data_shape
if colwise_layout == "T":
colwise_data_shape = data_shape[flatten_axis:] + data_shape[:flatten_axis]
colwise_scale_shape = self.get_scale_shape(
data_shape, is_colwise=True, is_padded=is_padded, flatten_axis=flatten_axis
colwise_data_shape,
data_layout=colwise_layout,
is_colwise=True,
is_padded=is_padded,
flatten_axis=flatten_axis,
broadcast_2d_scale_shape_to_1d=broadcast_2d_scale_shape_to_1d,
)
return (rowwise_scale_shape, colwise_scale_shape)
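# Illustrative expectation (comment-only, since we are inside the class body):
# for NVFP4 1D scaling on a (128, 64) tensor with unpadded shapes, 16-wide
# blocks run along the last axis of the rowwise data and of the transposed
# colwise data, e.g. (assumed result):
#   ScalingMode.NVFP4_1D_SCALING.get_scale_shape_2x((128, 64), is_padded=False)
#   -> ((128, 4), (64, 8))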
def get_scale_shape(
self, data_shape, is_colwise, is_padded=True, flatten_axis=-1
self,
data_shape,
data_layout="N",
is_colwise=False,
is_padded=True,
flatten_axis=-1,
broadcast_2d_scale_shape_to_1d=False,
) -> Tuple[int]:
"""Get the shape for scale tensors in this mode.
Args:
data_shape: Shape of the data tensor
data_layout: Layout of the data shape, either "N" (default) or "T" for transposed.
is_colwise: Whether to use column-wise scaling
is_padded: Whether to use padded shapes
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False.
Returns:
The shape for scale tensors
"""
return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded, flatten_axis)
return self._get_impl().get_scale_shape(
data_shape,
data_layout=data_layout,
is_colwise=is_colwise,
is_padded=is_padded,
flatten_axis=flatten_axis,
broadcast_2d_scale_shape_to_1d=broadcast_2d_scale_shape_to_1d,
)
def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
"""Get the quantize layout for the tensor usage.
......@@ -713,6 +851,7 @@ class ScalingMode(Enum):
input_shape,
unique_var,
flatten_axis=-1,
broadcast_2d_scale_shape_to_1d=False,
) -> Tuple[Tuple[str]]:
"""Sharding rules for the input and (row, col)wise scale tensors.
......@@ -720,11 +859,14 @@ class ScalingMode(Enum):
input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False.
Returns:
The Shardy rules for the scaling mode
"""
return self._get_impl().get_shardy_sharding_rules(input_shape, unique_var, flatten_axis)
return self._get_impl().get_shardy_sharding_rules(
input_shape, unique_var, flatten_axis, broadcast_2d_scale_shape_to_1d
)
def get_grouped_scale_shape_2x(
self, data_shape, n_groups, group_axis, is_padded=True, flatten_axis=-1
......@@ -798,8 +940,64 @@ class ScalingMode(Enum):
Returns:
True if the scaling mode is 1D block scaling, False otherwise
"""
# Both 1D and 2D NVFP4 scaling are treated as 1D block scaling since the 2D scales are broadcast to 1D, as required for the GEMM.
return self == ScalingMode.MXFP8_1D_SCALING or self.is_nvfp4_scaling
@property
def is_block_scaling(self) -> bool:
"""Check if this scaling mode is block scaling.
Returns:
True if the scaling mode is block scaling, False otherwise
"""
# Currently we only have 1D block scaling modes
return self.is_1d_block_scaling()
def get_compatible_q_dtypes(self) -> set[jnp.dtype]:
"""Returns a set of compatible quantized data types for this scaling mode.
Returns:
A set of compatible quantized data types
"""
if self in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
ScalingMode.MXFP8_1D_SCALING,
):
return {jnp.float8_e5m2, jnp.float8_e4m3fn}
if self in (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_2D_SCALING):
return {jnp.float4_e2m1fn}
if self == ScalingMode.NO_SCALING:
return {jnp.float16, jnp.bfloat16, jnp.float32}
raise ValueError(f"Invalid scaling mode: {self}")
@property
def is_nvfp4_scaling(self) -> bool:
"""Check if this scaling mode is NVFP4 scaling.
Returns:
True if the scaling mode is NVFP4 scaling, False otherwise
"""
return self in (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_2D_SCALING)
@property
def is_mxfp8_scaling(self) -> bool:
"""Check if this scaling mode is NVFP4 scaling.
Returns:
True if the scaling mode is NVFP4 scaling, False otherwise
"""
return self == ScalingMode.MXFP8_1D_SCALING
@property
def is_colwise_transposed(self) -> bool:
"""Check if this scaling mode uses transposed layout for column-wise scaling.
Returns:
True if the scaling mode uses transposed layout for column-wise scaling, False otherwise
"""
return self.is_tensor_scaling() or self.is_nvfp4_scaling
def __eq__(self, other):
"""Compare this scaling mode with another.
......@@ -836,9 +1034,20 @@ class ScalingMode(Enum):
SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = {
ScalingMode.NO_SCALING: NoScalingModeMetadataImpl(),
ScalingMode.DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(),
ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)),
# WAR
ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(
block_dims=(1, 32),
scale_dtype=jnp.float8_e8m0fnu,
data_layout="NN",
),
ScalingMode.CURRENT_TENSOR_SCALING: CurrentScalingModeMetadataImpl(),
ScalingMode.NO_SCALING: NoScalingModeMetadataImpl(),
ScalingMode.NVFP4_1D_SCALING: BlockScalingModeMetadataImpl(
block_dims=(1, 16),
scale_dtype=jnp.float8_e4m3fn,
data_layout="NT",
),
ScalingMode.NVFP4_2D_SCALING: BlockScalingModeMetadataImpl(
block_dims=(16, 16), scale_dtype=jnp.float8_e4m3fn, data_layout="NT"
),
}
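# Illustrative check of the registry above (a sketch, not part of the module):
# both NVFP4 modes report E4M3 block scales, while MXFP8 keeps E8M0 scales.
assert ScalingMode.NVFP4_1D_SCALING.get_scale_dtype() == jnp.float8_e4m3fn
assert ScalingMode.MXFP8_1D_SCALING.get_scale_dtype() == jnp.float8_e8m0fnu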
......@@ -201,13 +201,32 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
else:
unpadded_scale_shape = self.scaling_mode.get_scale_shape(
self.data.shape,
data_layout=self.data_layout,
is_colwise=self.is_colwise,
is_padded=False,
flatten_axis=self.flatten_axis,
# expect the flatten_axis wrt the N layout
flatten_axis=(
self.flatten_axis
if self.data_layout == "N"
else self.data.ndim - self.flatten_axis
),
)
assert self.scale_inv.shape == unpadded_scale_shape, (
"Unpadded inverse scale factor has wrong shape, expected"
f" {unpadded_scale_shape} but got {self.scale_inv.shape}."
unpadded_scale_shape_broadcast = self.scaling_mode.get_scale_shape(
self.data.shape,
data_layout=self.data_layout,
is_colwise=self.is_colwise,
is_padded=False,
# expect the flatten_axis wrt the N layout
flatten_axis=(
self.flatten_axis
if self.data_layout == "N"
else self.data.ndim - self.flatten_axis
),
broadcast_2d_scale_shape_to_1d=True,
)
assert self.scale_inv.shape in (unpadded_scale_shape, unpadded_scale_shape_broadcast), (
f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} or"
f" {unpadded_scale_shape_broadcast} but got {self.scale_inv.shape}."
)
def tree_flatten(self):
......@@ -583,6 +602,7 @@ class ScaledTensorFactory:
colwise_data,
colwise_scale_inv,
amax=None,
colwise_amax=None,
scaling_mode=ScalingMode.NO_SCALING,
dq_dtype=jnp.bfloat16,
data_layout="NN",
......@@ -612,6 +632,8 @@ class ScaledTensorFactory:
"""
if amax is None:
amax = jnp.empty((1,), dtype=jnp.float32)
if colwise_amax is None:
colwise_amax = amax
assert len(data_layout) == 2, f"Expect 2 layouts, got {data_layout}"
rowwise_tensor = ScaledTensorFactory.create_1x(
......@@ -630,10 +652,10 @@ class ScaledTensorFactory:
colwise_tensor = ScaledTensorFactory.create_1x(
colwise_data,
colwise_scale_inv,
amax,
colwise_amax,
scaling_mode,
dq_dtype,
is_colwise=True,
is_colwise=True, # TODO(Phuong): set this correctly
data_layout=data_layout[1],
flatten_axis=flatten_axis,
group_sizes=group_sizes,
......@@ -649,6 +671,7 @@ class ScaledTensorFactory:
colwise_data: jnp.ndarray,
colwise_scale_inv: jnp.ndarray,
amax=None,
colwise_amax=None,
scaling_mode: ScalingMode = ScalingMode.NO_SCALING,
dq_dtype: jnp.dtype = jnp.bfloat16,
data_layout: str = "NN",
......@@ -684,6 +707,7 @@ class ScaledTensorFactory:
colwise_data,
colwise_scale_inv,
amax,
colwise_amax,
scaling_mode,
dq_dtype,
data_layout=data_layout,
......@@ -698,7 +722,7 @@ class ScaledTensorFactory:
return ScaledTensorFactory.create_1x(
colwise_data,
colwise_scale_inv,
amax,
colwise_amax if colwise_amax is not None else amax,
scaling_mode,
dq_dtype,
is_colwise=is_colwise,
......