Unverified Commit 8a7ab3dd authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[JAX] NVFP4 support in TE/JAX (#2254)


Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent e99be1b6
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime.h>
#include <iostream>
#include "../extensions.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/hadamard_transform.h"
#include "transformer_engine/recipe.h"
#include "transformer_engine/transformer_engine.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
namespace jax {
// Computes the absolute-maximum (amax) values needed for NVFP4 quantization:
// the amax of the input after a randomized Hadamard transform (columnwise
// slot) and, optionally, the regular amax of the raw input (rowwise slot).
//
// Args:
//   stream: CUDA stream the TE kernel is launched on.
//   input_buf: bf16 input tensor; collapsed to 2D around flatten_axis.
//   amax_buf: single-float output for the regular amax; only validated and
//     written when produce_regular_amax is true.
//   post_rht_amax_buf: single-float output for the post-RHT amax.
//   rht_matrix_random_sign_mask_t: random-sign mask for the RHT matrix used
//     in the columnwise (post-RHT) amax computation.
//   produce_regular_amax: when false, no rowwise amax is produced (null ptr).
//   flatten_axis: axis at which the N-D input is split into a 2D view; must
//     lie strictly inside the dimension range.
Error_Type RHTAmaxCalculationFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type amax_buf,
                                 Result_Type post_rht_amax_buf,
                                 int64_t rht_matrix_random_sign_mask_t, bool produce_regular_amax,
                                 int64_t flatten_axis) {
  NVTE_CHECK(input_buf.untyped_data() != nullptr,
             "Input must be provided for RHT Amax calculation");
  NVTE_CHECK(convert_ffi_datatype_to_te_dtype(input_buf.element_type()) == DType::kBFloat16,
             "Input must be of type bfloat16 for RHT Amax calculation");
  const auto &in_dims = input_buf.dimensions();
  NVTE_CHECK(flatten_axis > 0 && flatten_axis < static_cast<int64_t>(in_dims.size()),
             "Flatten axis is out of bounds");

  // Collapse the N-D input into a 2D view split at flatten_axis.
  const std::vector<size_t> input_2d_shape{product(in_dims, 0, flatten_axis),
                                           product(in_dims, flatten_axis, in_dims.size())};
  TensorWrapper input_tensor(input_buf.untyped_data(), input_2d_shape,
                             convert_ffi_datatype_to_te_dtype(input_buf.element_type()));

  // The regular (pre-RHT, rowwise) amax output is optional.
  float *regular_amax = nullptr;
  if (produce_regular_amax) {
    regular_amax = reinterpret_cast<float *>(amax_buf->untyped_data());
    NVTE_CHECK(regular_amax != nullptr, "Amax output must be provided for RHT Amax calculation");
    NVTE_CHECK(convert_ffi_datatype_to_te_dtype(amax_buf->element_type()) == DType::kFloat32,
               "Amax output must be of type float32 for RHT Amax calculation");
    NVTE_CHECK(amax_buf->dimensions().size() == 1 && amax_buf->dimensions()[0] == 1,
               "Amax output must be a single float for RHT Amax calculation");
  }

  // The post-RHT (columnwise) amax output is mandatory.
  float *post_rht_amax = reinterpret_cast<float *>(post_rht_amax_buf->untyped_data());
  NVTE_CHECK(post_rht_amax != nullptr,
             "Post-RHT Amax output must be provided for RHT Amax calculation");
  NVTE_CHECK(convert_ffi_datatype_to_te_dtype(post_rht_amax_buf->element_type()) == DType::kFloat32,
             "Post-RHT Amax output must be of type float32 for RHT Amax calculation");
  NVTE_CHECK(post_rht_amax_buf->dimensions().size() == 1 &&
                 post_rht_amax_buf->dimensions()[0] == 1,
             "Post-RHT Amax output must be a single float for RHT Amax calculation");

  // Rowwise amax slot carries the regular amax (may stay null when
  // produce_regular_amax is false); columnwise slot carries the post-RHT amax.
  TensorWrapper out_tensor{};
  out_tensor.set_amax(regular_amax, DType::kFloat32, std::vector<size_t>{1});
  out_tensor.set_columnwise_amax(post_rht_amax, DType::kFloat32, std::vector<size_t>{1});

  // Zeroing of the amaxes is handled by TE common inside nvte_hadamard_transform_amax.
  nvte_hadamard_transform_amax(input_tensor.data(), out_tensor.data(),
                               0,  // Regular amax for rowwise does not apply RHT so mask is 0
                               rht_matrix_random_sign_mask_t, stream);
  return ffi_with_cuda_error_check();
}
// XLA FFI registration for the execute-stage RHT amax handler. The argument
// and attribute order here must match the Python-side lowering of this
// custom call. FFI_CudaGraph_Traits marks the handler as safe to capture in a
// CUDA graph.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
RHTAmaxCalculationHandler, RHTAmaxCalculationFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // post_rht_amax
.Attr<int64_t>("rht_matrix_random_sign_mask_t") // rht_matrix_random_sign_mask_t
.Attr<bool>("produce_regular_amax") // produce_regular_amax
.Attr<int64_t>("flatten_axis"), // flatten_axis
FFI_CudaGraph_Traits);
// Initialization-stage variant of RHTAmaxCalculationFFI: delegates to the
// same implementation through wrapInStreamCapture, forwarding all arguments
// unchanged. NOTE(review): presumably this performs the launch under stream
// capture so the work can be recorded for CUDA-graph initialization —
// confirm against the wrapInStreamCapture implementation.
Error_Type RHTAmaxCalculationInitializeFFI(cudaStream_t stream, Buffer_Type input_buf,
Result_Type amax_buf, Result_Type post_rht_amax_buf,
int64_t rht_matrix_random_sign_mask_t,
bool produce_regular_amax, int64_t flatten_axis) {
return wrapInStreamCapture(std::function(RHTAmaxCalculationFFI), stream, input_buf, amax_buf,
post_rht_amax_buf, rht_matrix_random_sign_mask_t, produce_regular_amax,
flatten_axis);
}
// XLA FFI registration for the initialize-stage RHT amax handler. Bound with
// FFI::Bind<FFI_Initialize>() so XLA invokes it during custom-call
// initialization; the binding order mirrors the execute-stage handler above
// and must match the Python-side lowering.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
RHTAmaxCalculationInitializeHandler, RHTAmaxCalculationInitializeFFI,
FFI::Bind<FFI_Initialize>()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // post_rht_amax
.Attr<int64_t>("rht_matrix_random_sign_mask_t") // rht_matrix_random_sign_mask_t
.Attr<bool>("produce_regular_amax") // produce_regular_amax
.Attr<int64_t>("flatten_axis")); // flatten_axis
} // namespace jax
} // namespace transformer_engine
......@@ -41,6 +41,9 @@ DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) {
case xla::ffi::DataType::F8E8M0FNU:
return DType::kFloat8E8M0;
break;
case xla::ffi::DataType::F4E2M1FN:
return DType::kFloat4E2M1;
break;
default:
auto type_num = static_cast<XLA_FFI_DataType>(type);
NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d",
......
......@@ -102,6 +102,8 @@ inline static size_t te_dtype_bytes(const DType& type) {
return 1;
case DType::kFloat8E8M0:
return 1;
case DType::kFloat4E2M1:
return 1;
default:
NVTE_ERROR("Unsupported DType: ", static_cast<int>(type));
}
......
......@@ -51,7 +51,8 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
// Set scaling factor for quantized tensors
if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) {
NVTE_CHECK(typeToSize(input_dtype) == 1, "Quantized GEMM requires 8-bit operands.");
NVTE_CHECK(is_nvfp4_scaling(scaling_mode) || typeToSize(input_dtype) == 1,
"Quantized GEMM requires 4-bit or 8-bit operands.");
NVTE_CHECK(scale_inv.element_count() > 0, "Missing inverse scaling factor for quantized GEMM.");
std::vector<size_t> scale_shape = {1};
......@@ -74,7 +75,8 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
Error_Type CollectiveGemmInitFFI(Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
Buffer_Type rhs_scale_inv, Buffer_Type bias,
Buffer_Type gelu_input, Result_Type output, Result_Type bias_grad,
Buffer_Type gelu_input, Buffer_Type alpha, Buffer_Type beta,
Result_Type output, Result_Type bias_grad,
Result_Type pre_gelu_out, Result_Type workspace,
JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
int64_t rhs_axis_boundary, bool lhs_transposed,
......@@ -119,6 +121,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CollectiveGemmInitHandler, CollectiveGemmInitFFI,
.Arg<Buffer_Type>() // rhs_scale_inv
.Arg<Buffer_Type>() // bias
.Arg<Buffer_Type>() // gelu_input
.Arg<Buffer_Type>() // alpha
.Arg<Buffer_Type>() // beta
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // bias_grad
.Ret<Buffer_Type>() // pre_gelu_out
......@@ -136,11 +140,11 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CollectiveGemmInitHandler, CollectiveGemmInitFFI,
Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input,
Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out,
Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed,
bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator,
JAXX_Collective_Op collective_op) {
Buffer_Type alpha, Buffer_Type beta, Result_Type output, Result_Type bias_grad,
Result_Type pre_gelu_out, Result_Type workspace, JAXX_Scaling_Mode scaling_mode,
int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed,
bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad,
bool use_split_accumulator, JAXX_Collective_Op collective_op) {
// NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when
// device supports non-TN layouts (compute capability >= 10.0, excluding 12.x)
bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
......@@ -192,10 +196,31 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr);
std::vector<size_t> workspace_shape = {static_cast<size_t>(workspace->element_count()) - 256};
auto workspace_ = TensorWrapper(workspace_ptr, workspace_shape, DType::kByte);
// Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order
auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
float one = 1.;
float zero = 0.;
// alpha, beta
float *alpha_ptr = &one, *beta_ptr = &zero;
if (is_nvfp4_scaling(scaling_mode)) {
NVTE_CHECK(alpha.element_count() == 1 &&
convert_ffi_datatype_to_te_dtype(alpha.element_type()) == DType::kFloat32);
alpha_ptr = reinterpret_cast<float *>(alpha.untyped_data());
NVTE_CHECK(beta.element_count() == 1 &&
convert_ffi_datatype_to_te_dtype(beta.element_type()) == DType::kFloat32);
beta_ptr = reinterpret_cast<float *>(beta.untyped_data());
}
// Construct GEMM config
transformer_engine::MatmulConfigWrapper config;
config.set_use_split_accumulator(use_split_accumulator);
config.set_sm_count(num_math_sm);
if (fuse_bias) config.set_bias_tensor(bias_.data());
if (fuse_gelu) {
config.set_with_gelu_epilogue(true);
config.set_epilogue_aux_tensor(pre_gelu_.data());
}
if (collective_op == JAXX_Collective_Op::NONE) {
auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype);
NVTE_CHECK(out_.numel() == output->element_count(),
......@@ -205,9 +230,10 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
NVTE_CHECK(!fuse_bias || bias_size == out_shape[1], "bias_size=", bias_size,
", out_shape[1]=", out_shape[1]);
nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(),
rhs_transposed, lhs_transposed, grad, workspace_.data(), false,
use_split_accumulator, num_math_sm, stream);
// Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order
nvte_cublas_gemm_v2(rhs_transposed /*transa*/, lhs_transposed /*transb*/, alpha_ptr,
rhs_.data() /*A*/, lhs_.data() /*B*/, beta_ptr, out_.data() /*C*/,
out_.data() /*D*/, workspace_.data(), config, stream);
} else {
std::vector<size_t> buffer_shape{0, 0};
DType buffer_dtype = out_dtype;
......@@ -268,6 +294,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI,
.Arg<Buffer_Type>() // rhs_scale_inv
.Arg<Buffer_Type>() // bias
.Arg<Buffer_Type>() // gelu_input
.Arg<Buffer_Type>() // alpha
.Arg<Buffer_Type>() // beta
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // bias_grad
.Ret<Buffer_Type>() // pre_gelu_out
......@@ -599,9 +627,9 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
// point to swizzled scale_inv data (store on workspace, only used for GEMM).
// Note: even if is_empty_gemm is true, sinv are still non-empty, need to move the pointers
auto lhs_sinv_shape_i =
get_mxfp8_scale_shape(lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise);
get_block_scale_shape(scaling_mode, lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise);
auto rhs_sinv_shape_i =
get_mxfp8_scale_shape(rhs_shape_i[0], rhs_shape_i[1], rhs_use_colwise);
get_block_scale_shape(scaling_mode, rhs_shape_i[0], rhs_shape_i[1], rhs_use_colwise);
lhs_sinv_size_i = lhs_sinv_shape_i[0] * lhs_sinv_shape_i[1];
rhs_sinv_size_i = rhs_sinv_shape_i[0] * rhs_sinv_shape_i[1];
if (lhs_use_colwise) {
......
......@@ -26,11 +26,21 @@ std::vector<size_t> Shape::to_vector() const {
return shape;
}
std::vector<size_t> get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise) {
auto block_x = is_colwise ? MXFP8_BLOCK_SIZE.y : MXFP8_BLOCK_SIZE.x;
auto block_y = is_colwise ? MXFP8_BLOCK_SIZE.x : MXFP8_BLOCK_SIZE.y;
auto alignment_x = is_colwise ? MXFP8_ALIGNMENT.y : MXFP8_ALIGNMENT.x;
auto alignment_y = is_colwise ? MXFP8_ALIGNMENT.x : MXFP8_ALIGNMENT.y;
std::vector<size_t> get_block_scale_shape(JAXX_Scaling_Mode scaling_mode, size_t M, size_t N,
bool is_colwise) {
auto block_size = BLOCK_SIZE(1, 1);
if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
block_size = MXFP8_BLOCK_SIZE;
} else if (scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING ||
scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING) {
block_size = NVFP4_BLOCK_SIZE;
} else {
NVTE_ERROR("Unsupported scaling_mode = ", static_cast<int>(scaling_mode));
}
auto block_x = is_colwise ? block_size.y : block_size.x;
auto block_y = is_colwise ? block_size.x : block_size.y;
auto alignment_x = is_colwise ? BLOCK_SCALE_ALIGNMENT.y : BLOCK_SCALE_ALIGNMENT.x;
auto alignment_y = is_colwise ? BLOCK_SCALE_ALIGNMENT.x : BLOCK_SCALE_ALIGNMENT.y;
NVTE_CHECK(M % block_x == 0, "M must be divisble by %zu (got %zu)", block_x, M);
NVTE_CHECK(N % block_y == 0, "N must be divisble by %zu (got %zu)", block_y, N);
......
......@@ -45,6 +45,8 @@ enum class JAXX_Scaling_Mode : int64_t {
DELAYED_TENSOR_SCALING = 1,
MXFP8_1D_SCALING = 2,
CURRENT_TENSOR_SCALING = 3,
NVFP4_1D_SCALING = 4,
NVFP4_2D_SCALING = 5,
};
inline bool is_tensor_scaling(const JAXX_Scaling_Mode &mode) {
......@@ -56,6 +58,11 @@ inline bool is_block_scaling(const JAXX_Scaling_Mode &mode) {
return (mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING);
}
// Returns true iff `mode` is one of the NVFP4 block-scaling variants
// (1D or 2D).
inline bool is_nvfp4_scaling(const JAXX_Scaling_Mode &mode) {
  switch (mode) {
    case JAXX_Scaling_Mode::NVFP4_1D_SCALING:
    case JAXX_Scaling_Mode::NVFP4_2D_SCALING:
      return true;
    default:
      return false;
  }
}
static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
switch (mode) {
case JAXX_Scaling_Mode::NO_SCALING:
......@@ -70,22 +77,32 @@ static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
case JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING:
return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING;
break;
case JAXX_Scaling_Mode::NVFP4_1D_SCALING:
return NVTEScalingMode::NVTE_NVFP4_1D_SCALING;
break;
case JAXX_Scaling_Mode::NVFP4_2D_SCALING:
// TE common uses the same enum value for 1D and 2D fp4 scaling and instead differentiates them via quant_config.nvfp4_2d_quantization
return NVTEScalingMode::NVTE_NVFP4_1D_SCALING;
break;
default:
NVTE_ERROR("Invalid Scaling Mode ", static_cast<int>(mode));
break;
}
}
constexpr struct BlockSize {
struct BLOCK_SIZE {
size_t x;
size_t y;
} MXFP8_BLOCK_SIZE{1, 32};
constexpr struct Alignment {
size_t x;
size_t y;
} MXFP8_ALIGNMENT{128, 4};
constexpr BLOCK_SIZE(int _x, int _y) : x(_x), y(_y) {}
};
constexpr BLOCK_SIZE MXFP8_BLOCK_SIZE{1, 32};
constexpr BLOCK_SIZE NVFP4_BLOCK_SIZE{1, 16};
constexpr BLOCK_SIZE BLOCK_SCALE_ALIGNMENT{128, 4};
std::vector<size_t> get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise);
std::vector<size_t> get_block_scale_shape(JAXX_Scaling_Mode scaling_mode, size_t M, size_t N,
bool is_colwise);
template <typename T, typename... Rest>
void hash_combine(int64_t &seed, const T &v, Rest... rest) {
......
......@@ -76,6 +76,11 @@ pybind11::dict Registrations() {
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler));
// Amax
dict["te_rht_amax_ffi"] = pybind11::dict(
pybind11::arg("initialize") = EncapsulateFFI(RHTAmaxCalculationInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(RHTAmaxCalculationHandler));
return dict;
}
......@@ -106,7 +111,9 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("kFloat16", DType::kFloat16)
.value("kBFloat16", DType::kBFloat16)
.value("kFloat8E4M3", DType::kFloat8E4M3)
.value("kFloat8E5M2", DType::kFloat8E5M2);
.value("kFloat8E5M2", DType::kFloat8E5M2)
.value("kFloat8E8M0", DType::kFloat8E8M0)
.value("kFloat4E2M1", DType::kFloat4E2M1);
pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local())
.value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS)
......@@ -165,6 +172,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("DELAYED_TENSOR_SCALING", JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
.value("MXFP8_1D_SCALING", JAXX_Scaling_Mode::MXFP8_1D_SCALING)
.value("CURRENT_TENSOR_SCALING", JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING)
.value("NVFP4_1D_SCALING", JAXX_Scaling_Mode::NVFP4_1D_SCALING)
.value("NVFP4_2D_SCALING", JAXX_Scaling_Mode::NVFP4_2D_SCALING)
.export_values();
pybind11::enum_<transformer_engine::jax::QuantizeLayout>(m, "QuantizeLayout",
......
......@@ -5,8 +5,11 @@
************************************************************************/
#include <cuda_runtime.h>
#include <iostream>
#include "../extensions.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/hadamard_transform.h"
#include "transformer_engine/recipe.h"
#include "transformer_engine/transformer_engine.h"
#include "xla/ffi/api/c_api.h"
......@@ -15,7 +18,7 @@ namespace transformer_engine {
namespace jax {
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype,
DType in_dtype, DType out_dtype, DType scale_dtype,
JAXX_Scaling_Mode scaling_mode,
QuantizeLayout q_layout) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
......@@ -30,16 +33,22 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
// this function. We pass a dummy pointer as a workaround.
int temp = 0;
bool const is_nvfp4 = scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING ||
scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING;
auto input_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), input_shape, in_dtype);
auto dbias_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), dbias_shape, in_dtype);
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
auto scale_shape = std::vector<size_t>{1};
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::ROWWISE) {
output_tensor.set_rowwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_shape);
if (is_fp8_dtype(out_dtype)) {
output_tensor.set_rowwise_scale_inv(reinterpret_cast<void *>(&temp), DType::kFloat32,
std::vector<size_t>{1});
if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) {
if (is_nvfp4)
scale_shape = get_block_scale_shape(scaling_mode, batch_size, hidden_size, false);
output_tensor.set_rowwise_scale_inv(reinterpret_cast<void *>(&temp), scale_dtype,
scale_shape);
}
}
......@@ -49,13 +58,16 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (is_fp8_dtype(out_dtype)) {
output_tensor.set_columnwise_scale_inv(reinterpret_cast<void *>(&temp), DType::kFloat32,
std::vector<size_t>{1});
if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) {
if (is_nvfp4)
scale_shape =
get_block_scale_shape(scaling_mode, hidden_size, batch_size, false); //Transpose
output_tensor.set_columnwise_scale_inv(reinterpret_cast<void *>(&temp), scale_dtype,
scale_shape);
}
}
if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || is_nvfp4) {
output_tensor.set_amax(reinterpret_cast<void *>(&temp), DType::kFloat32,
std::vector<size_t>{1});
output_tensor.set_scale(reinterpret_cast<void *>(&temp), DType::kFloat32,
......@@ -72,17 +84,20 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
}
Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Buffer_Type amax_buf, Result_Type output_buf,
Result_Type output_trans_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type dbias_buf, Result_Type workspace_buf,
JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum,
bool is_dbias, int64_t flatten_axis) {
Buffer_Type amax_buf, Buffer_Type sr_rng_state,
Buffer_Type post_rht_amax_buf, Buffer_Type rht_matrix_buf,
Result_Type output_buf, Result_Type output_trans_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type updated_amax_buf, Result_Type dbias_buf,
Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
int64_t quantize_layout_enum, bool is_dbias, int64_t flatten_axis,
bool stochastic_rounding, bool use_rht) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
NVTE_CHECK(is_fp8_dtype(out_dtype), "Output datatype must be FP8 for quantization.");
NVTE_CHECK(is_fp8_dtype(out_dtype) || is_fp4_dtype(out_dtype),
"Output datatype must be FP8 or FP4 for quantization.");
auto *input = input_buf.untyped_data();
......@@ -112,25 +127,27 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
bool const is_mxfp8 = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
bool const is_nvfp4 = scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING ||
scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING;
NVTE_CHECK(!stochastic_rounding || is_nvfp4, "Stochastic rounding is only supported for NVFP4.");
NVTE_CHECK(!use_rht || is_nvfp4, "RHT is only supported for NVFP4 scaling");
if (quantize_layout == QuantizeLayout::ROWWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
output_tensor.set_rowwise_data(output, out_dtype, output_shape);
if (is_fp8_dtype(out_dtype)) {
if (is_tensor_scaling) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
float *updated_amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
float *amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax == updated_amax && amax != nullptr,
"amax must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{1});
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
} else {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
......@@ -140,13 +157,76 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
scale_inv_buf->dimensions().size())});
}
}
if (is_nvfp4) {
float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
NVTE_CHECK(amax != nullptr, "amax must be provided for NVFP4");
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
}
QuantizationConfigWrapper quant_config{};
if (scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING) {
quant_config.set_nvfp4_2d_quantization(true);
}
// Stochastic rounding
quant_config.set_stochastic_rounding(stochastic_rounding);
TensorWrapper sr_rng_state_tensor(sr_rng_state.untyped_data(), std::vector<size_t>{2},
DType::kInt64);
if (stochastic_rounding) {
NVTE_CHECK(sr_rng_state.size_bytes() == 2 * sizeof(uint64_t),
"rng_state must be of type int64[2]");
NVTE_CHECK(sr_rng_state.untyped_data() != nullptr, "rng_state must be provided for SR");
quant_config.set_rng_state(sr_rng_state_tensor.data());
}
if (quantize_layout == QuantizeLayout::COLWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
? output_trans_shape
: output_shape;
if (is_nvfp4 && use_rht) {
if (quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
// Do regular rowwise quantization without RHT
nvte_quantize_v2(input_tensor.data(), output_tensor.data(), quant_config, stream);
}
TensorWrapper out_transpose(get_nvte_scaling_mode(scaling_mode));
// nvte_hadamard_transform_cast_fusion_columnwise expects the colwise data to be populated in the rowwise buffers on TensorWrapper
out_transpose.set_rowwise_data(output_trans, out_dtype, output_trans_shape);
auto const colwise_flatten_axis = output_trans_buf->dimensions().size() - flatten_axis;
out_transpose.set_rowwise_scale_inv(
colwise_scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0, colwise_flatten_axis),
product(colwise_scale_inv_buf->dimensions(), colwise_flatten_axis,
colwise_scale_inv_buf->dimensions().size())});
float *post_rht_amax = reinterpret_cast<float *>(post_rht_amax_buf.untyped_data());
NVTE_CHECK(post_rht_amax != nullptr, "Post-RHT colwise amax must be provided for NVFP4");
out_transpose.set_amax(post_rht_amax, DType::kFloat32, std::vector<size_t>{1});
bool const eligible_for_rht_cast_fusion =
input_tensor.dtype() == DType::kBFloat16 && m % 64 == 0 && n % 128 == 0;
NVTE_CHECK(eligible_for_rht_cast_fusion, "RHT cast fusion conditions not met");
NVTE_CHECK(
convert_ffi_datatype_to_te_dtype(rht_matrix_buf.element_type()) == DType::kBFloat16,
"RHT matrix must be bf16");
NVTE_CHECK(rht_matrix_buf.dimensions().size() == 2 && rht_matrix_buf.dimensions()[0] == 16 &&
rht_matrix_buf.dimensions()[1] == 16,
"RHT matrix must be 16x16");
TensorWrapper rht_matrix_tensor(rht_matrix_buf.untyped_data(), std::vector<size_t>{16, 16},
DType::kBFloat16);
nvte_hadamard_transform_cast_fusion_columnwise(input_tensor.data(), out_transpose.data(),
rht_matrix_tensor.data(), quant_config,
stream);
return ffi_with_cuda_error_check();
}
bool const is_colwise_transposed =
scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || is_nvfp4;
auto &tmp_shape = is_colwise_transposed ? output_trans_shape : output_shape;
output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape);
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &tmp_buf = is_tensor_scaling ? scale_inv_buf : colwise_scale_inv_buf;
......@@ -156,26 +236,30 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{1});
} else {
auto colwise_flatten_axis = flatten_axis;
if (is_colwise_transposed) {
// convert flatten_axis from N layout to T layout
colwise_flatten_axis = tmp_buf->dimensions().size() - flatten_axis;
}
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{
product(tmp_buf->dimensions(), 0, flatten_axis),
product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
}
product(tmp_buf->dimensions(), 0, colwise_flatten_axis),
product(tmp_buf->dimensions(), colwise_flatten_axis, tmp_buf->dimensions().size())});
}
if (scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) {
output_tensor.set_amax(nullptr, DType::kFloat32, std::vector<size_t>{1});
}
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);
if (is_dbias) {
NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::NVFP4_2D_SCALING,
"DBias quantization is not supported for NVFP4_2D_SCALING as fused dbias API cannot "
"take quant_config as input.");
nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(),
workspace_tensor.data(), stream);
} else {
nvte_quantize(input_tensor.data(), output_tensor.data(), stream);
nvte_quantize_v2(input_tensor.data(), output_tensor.data(), quant_config, stream);
}
return ffi_with_cuda_error_check();
}
......@@ -186,6 +270,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // sr_rng_state
.Arg<Buffer_Type>() // colwise amax
.Arg<Buffer_Type>() // rht matrix
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise output
.Ret<Buffer_Type>() // scale_inv
......@@ -196,7 +283,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("q_layout")
.Attr<bool>("is_dbias")
.Attr<int64_t>("flatten_axis"),
.Attr<int64_t>("flatten_axis")
.Attr<bool>("stochastic_rounding")
.Attr<bool>("use_rht"),
FFI_CudaGraph_Traits);
Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
......@@ -346,7 +435,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
sinv_size = 1;
} else {
const bool is_colwise = false;
auto sinv_shape_i = get_mxfp8_scale_shape(m_i, n, is_colwise);
auto sinv_shape_i = get_block_scale_shape(scaling_mode, m_i, n, is_colwise);
out_i.set_rowwise_scale_inv(static_cast<void *>(sinv_ptr), sinv_dtype, sinv_shape_i);
sinv_size = product(sinv_shape_i);
}
......@@ -365,7 +454,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
colwise_sinv_size = 1;
} else {
const bool is_colwise = true;
auto sinv_shape_i = get_mxfp8_scale_shape(m_i, n, is_colwise);
auto sinv_shape_i = get_block_scale_shape(scaling_mode, m_i, n, is_colwise);
out_i.set_columnwise_scale_inv(static_cast<void *>(colwise_sinv_ptr), sinv_dtype,
sinv_shape_i);
colwise_sinv_size = product(sinv_shape_i);
......
......@@ -16,7 +16,7 @@ import jax
import jax.numpy as jnp
from . import cpp_extensions as tex
from .cpp_extensions.quantization import AmaxScope
from .cpp_extensions.amax import AmaxScope
from .quantize import (
ScaledTensorFactory,
ScalingMode,
......
......@@ -15,7 +15,6 @@ from jax import lax
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name
from transformer_engine.common import recipe
from ..dense import dense
......@@ -35,10 +34,9 @@ from ..cpp_extensions import (
from ..quantize import (
QuantizerFactory,
get_quantize_config,
QuantizeMeta,
QuantizeMetaSet,
ScalingMode,
TensorSource,
get_quantize_config_with_recipe,
)
PRNGKey = Any
......@@ -353,40 +351,32 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
Generate a set of FP8 meta for a GEMM.
"""
def generate_quantize_meta(quantizer_name: str):
collection_name = (
variable_collection
if variable_collection is not None
else get_quantize_config().COLLECTION_NAME
)
scale = self.variable(
collection_name,
f"{quantizer_name}{postfix}_scale",
jnp.ones,
(1,),
jnp.float32,
).value
amax_history = self.variable(
collection_name,
f"{quantizer_name}{postfix}_amax_history",
jnp.zeros,
(get_quantize_config().AMAX_HISTORY_LEN,),
jnp.float32,
).value
return QuantizeMeta(scale=scale, amax_history=amax_history)
if get_quantize_config().get_scaling_mode(
TensorSource.X
) == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(fp8_recipe, recipe.DelayedScaling):
x_meta = generate_quantize_meta("x")
kernel_meta = generate_quantize_meta("kernel")
grad_meta = generate_quantize_meta("grad")
quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)
kwargs = {"quantize_meta_set": quantize_meta_set}
if fp8_recipe is None:
quantize_config = get_quantize_config()
else:
kwargs = {}
quantize_config = get_quantize_config_with_recipe(fp8_recipe)
quantizer_set = QuantizerFactory.create_set(fp8_recipe=fp8_recipe, **kwargs)
x_meta = quantize_config.get_quantize_flax_meta(
self, collection_name, postfix, TensorSource.X, "x"
)
kernel_meta = quantize_config.get_quantize_flax_meta(
self, collection_name, postfix, TensorSource.KERNEL, "kernel"
)
grad_meta = quantize_config.get_quantize_flax_meta(
self, collection_name, postfix, TensorSource.DGRAD, "grad"
)
quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)
quantizer_set = QuantizerFactory.create_set(
fp8_recipe=fp8_recipe, quantize_meta_set=quantize_meta_set
)
return quantizer_set
......
......@@ -16,7 +16,7 @@ import jax
import jax.numpy as jnp
from . import cpp_extensions as tex
from .cpp_extensions.quantization import AmaxScope
from .cpp_extensions.amax import AmaxScope
from .quantize import (
QuantizerSet,
......
......@@ -21,7 +21,7 @@ import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name
from . import cpp_extensions as tex
from .cpp_extensions.quantization import AmaxScope
from .cpp_extensions.amax import AmaxScope
from .layernorm import canonicalize_norm_type
from .quantize import (
with_sharding_constraint_by_logical_axes,
......
......@@ -14,5 +14,6 @@ from .quantizer import *
from .dequantizer import *
from .scaling_modes import *
from .metadata import *
from .hadamard import *
from .helper import *
from .device_utils import *
......@@ -15,6 +15,8 @@ import jax
import jax.numpy as jnp
from .scaling_modes import ScalingMode
from .hadamard import apply_rht, should_use_rht
__all__ = ["ScalingModeToDequantizerMap"]
......@@ -119,7 +121,7 @@ class BlockScaleDequantizer(Dequantizer):
0 < flatten_axis < len(data_shape)
), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
scale_shape = scaling_mode.get_scale_shape(
data_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis
data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis
)
data = data.reshape(
......@@ -161,10 +163,99 @@ class BlockScaleDequantizer(Dequantizer):
)
class NVFP4Dequantizer(Dequantizer):
    """NVFP4 Dequantizer Class.

    This class provides static methods for dequantizing tensors that have been
    quantized using NVFP4 scaling modes.
    """

    @staticmethod
    def _dequantize_func(data, scale_inv, amax, dq_dtype, scaling_mode, is_colwise, flatten_axis):
        """Dequantize a tensor using block scaling.

        Args:
            data: The quantized tensor data
            scale_inv: The inverse scaling factors
            amax: The maximum absolute value of the tensor
            dq_dtype: The data type for dequantized values
            scaling_mode: The scaling mode used for quantization
            is_colwise: Whether the scaling is column-wise
            flatten_axis: The axis along which the tensor could be flattened to 2D

        Returns:
            The dequantized tensor
        """
        # Two-level scaling: a per-tensor scale derived from amax and the dtype
        # maxima is folded into the per-block scale_inv below, so a single
        # multiply recovers the original values.
        DATA_DTYPE_MAX = jnp.finfo(data.dtype).max.astype(jnp.float32)
        SCALE_DTYPE_MAX = jnp.finfo(scale_inv.dtype).max.astype(jnp.float32)
        tensor_scale_inv = amax / (DATA_DTYPE_MAX * SCALE_DTYPE_MAX)
        data = data.astype(jnp.float32)
        scale_inv = scale_inv.astype(jnp.float32) * tensor_scale_inv

        data_layout = "T" if is_colwise else "N"
        data_shape = data.shape
        # Normalize a negative flatten_axis to its positive equivalent.
        flatten_axis = len(data_shape) + flatten_axis if flatten_axis < 0 else flatten_axis
        assert (
            0 < flatten_axis < len(data_shape)
        ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"

        scale_shape = scaling_mode.get_scale_shape(
            data_shape,
            data_layout=data_layout,
            is_colwise=is_colwise,
            is_padded=False,
            # expect the flatten_axis wrt the N layout
            flatten_axis=flatten_axis if data_layout == "N" else len(data_shape) - flatten_axis,
            broadcast_2d_scale_shape_to_1d=True,
        )
        # Split the two scaled axes into (num_blocks, block_size) pairs so the
        # block scales can broadcast against each block of data.
        data = data.reshape(
            *data_shape[: flatten_axis - 1],
            scale_shape[flatten_axis - 1],
            int(data_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]),
            *data_shape[flatten_axis:-1],
            scale_shape[-1],
            int(data_shape[-1] / scale_shape[-1]),
        )
        # Insert singleton block-size axes into scale_inv so it lines up with the
        # reshaped data above.
        # NOTE(review): `flatten_axis + 2 - 2` == `flatten_axis`; looks like a
        # leftover from a refactor — confirm the intended axis.
        scale_inv = jnp.expand_dims(scale_inv, axis=(flatten_axis + 2 - 2, -1))
        out = jnp.asarray(data * scale_inv, dq_dtype).reshape(data_shape)

        # Apply inverse of RHT if needed
        use_rht = should_use_rht(scaling_mode, is_colwise=is_colwise)
        if use_rht:
            out = apply_rht(out, inverse=True)
        return out

    @staticmethod
    def dequantize(scaled_tensor):
        """Dequantize a tensor using block scaling.

        Args:
            scaled_tensor: The quantized tensor to dequantize

        Returns:
            The dequantized tensor
        """
        return NVFP4Dequantizer._dequantize_func(
            scaled_tensor.data,
            scaled_tensor.scale_inv,
            scaled_tensor.amax,
            scaled_tensor.dq_dtype,
            scaled_tensor.scaling_mode,
            scaled_tensor.is_colwise,
            scaled_tensor.flatten_axis,
        )
# Dispatch table mapping each scaling mode to the dequantizer class that
# implements it. Both NVFP4 modes share the NVFP4 dequantizer; both
# tensor-scaling modes share the tensor-scale dequantizer.
ScalingModeToDequantizerMap = dict(
    (
        (ScalingMode.DELAYED_TENSOR_SCALING, TensorScaleDequantizer),
        (ScalingMode.CURRENT_TENSOR_SCALING, TensorScaleDequantizer),
        (ScalingMode.MXFP8_1D_SCALING, BlockScaleDequantizer),
        (ScalingMode.NVFP4_1D_SCALING, NVFP4Dequantizer),
        (ScalingMode.NVFP4_2D_SCALING, NVFP4Dequantizer),
        (ScalingMode.NO_SCALING, NoopDequantizer),
    )
)
......@@ -210,13 +301,13 @@ def _grouped_dequantize(grouped_scaled_tensor):
)
padded_scale_shape_i = scaling_mode.get_scale_shape(
data_shape_i,
grouped_scaled_tensor.is_colwise,
is_colwise=grouped_scaled_tensor.is_colwise,
is_padded=True,
flatten_axis=flatten_axis,
)
unpadded_scale_shape_i = scaling_mode.get_scale_shape(
data_shape_i,
grouped_scaled_tensor.is_colwise,
is_colwise=grouped_scaled_tensor.is_colwise,
is_padded=False,
flatten_axis=flatten_axis,
)
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Randomized Hadamard Transform (RHT) utilities for JAX."""
import jax.numpy as jnp
from .scaling_modes import ScalingMode
def should_use_rht(scaling_mode, is_colwise=None, q_layout=None) -> bool:
    """Determine if RHT (Randomized Hadamard Transform) should be used.

    RHT only applies to the column-wise copy of NVFP4 1D-scaled tensors.

    Args:
        scaling_mode: The scaling mode of the tensor.
        is_colwise: Whether the tensor is column-wise. Mutually exclusive with q_layout.
        q_layout: The quantization layout of the tensor. Mutually exclusive with is_colwise.

    Returns:
        bool: True if RHT should be used, False otherwise.
    """
    # Delayed import to avoid circular dependencies
    from .quantizer import QuantizeLayout

    assert (is_colwise is None) != (
        q_layout is None
    ), "Exactly one of is_colwise or q_layout must be provided."
    if scaling_mode != ScalingMode.NVFP4_1D_SCALING:
        return False
    if q_layout is not None:
        return q_layout in {QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE}
    return is_colwise
def get_wgrad_sign_vector() -> list[int]:
    """Get a fixed sign vector for the RHT used in NVFP4 weight gradient quantization."""
    # '+' -> +1, '-' -> -1; one character per position of the length-16 block.
    return [1 if sign == "+" else -1 for sign in "+++-+------+-+--"]
def get_sign_from_vector(vector: list[int]) -> int:
    """Convert a sign vector to a bitmask integer.

    Bit i of the result is set exactly when vector[i] == -1.
    """
    return sum(1 << i for i, v in enumerate(vector) if v == -1)
def apply_rht(x: jnp.ndarray, inverse=False) -> jnp.ndarray:
    """Apply the Randomized Hadamard Transform (RHT) to the input tensor.

    Args:
        x: Input tensor; its total element count must be a multiple of 16,
           since the transform acts on contiguous blocks of 16 elements.
        inverse: If True, apply the inverse transform instead.

    Returns:
        A tensor with the same shape as x.
    """
    h = get_rht_matrix()
    block_size = 16
    if inverse:
        # The RHT matrix is orthogonal by construction (rows are signed
        # Hadamard rows scaled by 1/sqrt(16), so h @ h.T == I), hence its
        # inverse is exactly its transpose. This avoids an fp32 matrix
        # inversion plus bf16 round-trip on every call.
        h = h.T
    # TODO(jberchtold): These reshapes will break partitioning, fixme
    return (x.reshape(-1, block_size) @ h).reshape(x.shape)
def get_rht_matrix() -> jnp.ndarray:
    """Get the Randomized Hadamard Transform (RHT) matrix used in NVFP4 weight gradient quantization.

    Returns:
        A (16, 16) bfloat16 matrix representing the RHT. This matrix is pre-multiplied by the random sign mask.
    """
    # Import the submodule explicitly: a bare `import scipy` only guarantees
    # access to `scipy.linalg` on SciPy versions with lazy subpackage loading
    # (or when another module already imported it).
    import scipy.linalg

    block_size = 16
    h = jnp.array(scipy.linalg.hadamard(block_size))
    # Apply the random sign mask: flip the sign of each row selected by the
    # fixed wgrad sign vector.
    s = jnp.array(get_wgrad_sign_vector(), dtype=jnp.int32)
    h = jnp.diag(s) @ h
    # Normalize by sqrt(block_size) so the transform is orthonormal.
    return (h / jnp.sqrt(block_size)).astype(jnp.bfloat16)
This diff is collapsed.
......@@ -9,23 +9,29 @@ This module provides classes for managing quantization metadata, including
scale factors and amax history for different tensor types.
"""
from dataclasses import dataclass
import jax.numpy as jnp
__all__ = ["QuantizeMeta", "QuantizeMetaSet"]
@dataclass
class QuantizeMeta:
"""Metadata for quantization parameters.
Attributes:
For Delayed Scaling recipe:
scale: The scaling factor for quantization
amax_history: History of maximum absolute values
For NVFP4 recipe with Stochastic Rounding:
sr_rng_state: The state of the stochastic rounding RNG
"""
scale: jnp.ndarray
amax_history: jnp.ndarray
def __init__(self, **kwargs):
self._kwargs = kwargs
def get_kwargs_dictionary(self):
"""Get the metadata as a dictionary."""
return self._kwargs
@dataclass
......
......@@ -201,13 +201,32 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
else:
unpadded_scale_shape = self.scaling_mode.get_scale_shape(
self.data.shape,
data_layout=self.data_layout,
is_colwise=self.is_colwise,
is_padded=False,
flatten_axis=self.flatten_axis,
# expect the flatten_axis wrt the N layout
flatten_axis=(
self.flatten_axis
if self.data_layout == "N"
else self.data.ndim - self.flatten_axis
),
)
assert self.scale_inv.shape == unpadded_scale_shape, (
"Unpadded inverse scale factor has wrong shape, expected"
f" {unpadded_scale_shape} but got {self.scale_inv.shape}."
unpadded_scale_shape_broadcast = self.scaling_mode.get_scale_shape(
self.data.shape,
data_layout=self.data_layout,
is_colwise=self.is_colwise,
is_padded=False,
# expect the flatten_axis wrt the N layout
flatten_axis=(
self.flatten_axis
if self.data_layout == "N"
else self.data.ndim - self.flatten_axis
),
broadcast_2d_scale_shape_to_1d=True,
)
assert self.scale_inv.shape in (unpadded_scale_shape, unpadded_scale_shape_broadcast), (
f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} or"
f" {unpadded_scale_shape_broadcast} but got {self.scale_inv.shape}."
)
def tree_flatten(self):
......@@ -583,6 +602,7 @@ class ScaledTensorFactory:
colwise_data,
colwise_scale_inv,
amax=None,
colwise_amax=None,
scaling_mode=ScalingMode.NO_SCALING,
dq_dtype=jnp.bfloat16,
data_layout="NN",
......@@ -612,6 +632,8 @@ class ScaledTensorFactory:
"""
if amax is None:
amax = jnp.empty((1,), dtype=jnp.float32)
if colwise_amax is None:
colwise_amax = amax
assert len(data_layout) == 2, f"Expect 2 layouts, got {data_layout}"
rowwise_tensor = ScaledTensorFactory.create_1x(
......@@ -630,10 +652,10 @@ class ScaledTensorFactory:
colwise_tensor = ScaledTensorFactory.create_1x(
colwise_data,
colwise_scale_inv,
amax,
colwise_amax,
scaling_mode,
dq_dtype,
is_colwise=True,
is_colwise=True, # TODO(Phuong): set this correctly
data_layout=data_layout[1],
flatten_axis=flatten_axis,
group_sizes=group_sizes,
......@@ -649,6 +671,7 @@ class ScaledTensorFactory:
colwise_data: jnp.ndarray,
colwise_scale_inv: jnp.ndarray,
amax=None,
colwise_amax=None,
scaling_mode: ScalingMode = ScalingMode.NO_SCALING,
dq_dtype: jnp.dtype = jnp.bfloat16,
data_layout: str = "NN",
......@@ -684,6 +707,7 @@ class ScaledTensorFactory:
colwise_data,
colwise_scale_inv,
amax,
colwise_amax,
scaling_mode,
dq_dtype,
data_layout=data_layout,
......@@ -698,7 +722,7 @@ class ScaledTensorFactory:
return ScaledTensorFactory.create_1x(
colwise_data,
colwise_scale_inv,
amax,
colwise_amax if colwise_amax is not None else amax,
scaling_mode,
dq_dtype,
is_colwise=is_colwise,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment