Commit 063ef88d authored by wenjh

Merge nv main up to v2.10.0.dev0


Signed-off-by: wenjh <wenjh@sugon.com>
parents 91670b05 5624dbb4
@@ -6,13 +6,19 @@
#include "transformer_engine/gemm.h"
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string_view>
#include <tuple>
#include "../extensions.h"
#include "cgemm_helper.h"
#include "common.h"
#include "common/util/cuda_runtime.h"
#include "common/util/string.h"
#include "common/util/system.h"
#include "cuda_runtime.h"
#include "nccl.h"
#include "transformer_engine/swizzle.h"
#include "xla/ffi/api/c_api.h"
@@ -45,7 +51,8 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
// Set scaling factor for quantized tensors
if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) {
NVTE_CHECK(is_nvfp4_scaling(scaling_mode) || typeToSize(input_dtype) == 1,
"Quantized GEMM requires 4-bit or 8-bit operands.");
NVTE_CHECK(scale_inv.element_count() > 0, "Missing inverse scaling factor for quantized GEMM.");
std::vector<size_t> scale_shape = {1};
@@ -66,12 +73,78 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
return std::make_tuple(std::move(input), input_shape);
}
Error_Type CollectiveGemmInitFFI(Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
Buffer_Type rhs_scale_inv, Buffer_Type bias,
Buffer_Type gelu_input, Buffer_Type alpha, Buffer_Type beta,
Result_Type output, Result_Type bias_grad,
Result_Type pre_gelu_out, Result_Type workspace,
JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
int64_t rhs_axis_boundary, bool lhs_transposed,
bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad,
bool use_split_accumulator, JAXX_Collective_Op collective_op) {
nvte_cublas_handle_init();
// Init UB buffer
if (collective_op != JAXX_Collective_Op::NONE) {
auto &comm_handler = CommunicatorHandler::get();
std::vector<size_t> lhs_shape = {
product(lhs.dimensions(), 0, lhs_axis_boundary),
product(lhs.dimensions(), lhs_axis_boundary, lhs.dimensions().size())};
std::vector<size_t> rhs_shape = {
product(rhs.dimensions(), 0, rhs_axis_boundary),
product(rhs.dimensions(), rhs_axis_boundary, rhs.dimensions().size())};
std::vector<size_t> out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0],
(rhs_transposed) ? rhs_shape[0] : rhs_shape[1]};
std::vector<size_t> buffer_shape{0, 0};
DType buffer_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
if (collective_op == JAXX_Collective_Op::ALL_GATHER) {
buffer_shape[0] = lhs_shape[0] * comm_handler.tp_size;
buffer_shape[1] = lhs_shape[1];
buffer_dtype = convert_ffi_datatype_to_te_dtype(lhs.element_type());
} else if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) {
buffer_shape[0] = out_shape[0];
buffer_shape[1] = out_shape[1];
}
auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor(buffer_shape, buffer_dtype,
collective_op);
}
return ffi_with_cuda_error_check();
}
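// Illustrative sketch of the buffer sizing above, assuming hypothetical values tp_size = 4,
// a flattened LHS of [1024, 512] and a flattened RHS of [512, 256] with no transposes:
//   ALL_GATHER:     buffer_shape = {1024 * 4, 512}  (room for the LHS gathered over the TP group),
//                   buffer_dtype = LHS dtype
//   REDUCE_SCATTER: buffer_shape = {1024, 256}      (the full GEMM output before scattering),
//                   buffer_dtype = output dtype
// The prepare-stage call only requests an executor for that (shape, dtype, op) combination so
// that the communication buffer exists before GemmFFI runs; the returned handle is discarded.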
XLA_FFI_DEFINE_HANDLER_SYMBOL(CollectiveGemmInitHandler, CollectiveGemmInitFFI,
FFI::Bind<FFI_Prepare>()
.Arg<Buffer_Type>() // lhs
.Arg<Buffer_Type>() // lhs_scale_inv
.Arg<Buffer_Type>() // rhs
.Arg<Buffer_Type>() // rhs_scale_inv
.Arg<Buffer_Type>() // bias
.Arg<Buffer_Type>() // gelu_input
.Arg<Buffer_Type>() // alpha
.Arg<Buffer_Type>() // beta
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // bias_grad
.Ret<Buffer_Type>() // pre_gelu_out
.Ret<Buffer_Type>() // workspace
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("lhs_axis_boundary")
.Attr<int64_t>("rhs_axis_boundary")
.Attr<bool>("lhs_transposed")
.Attr<bool>("rhs_transposed")
.Attr<bool>("fuse_bias")
.Attr<bool>("fuse_gelu")
.Attr<bool>("grad")
.Attr<bool>("use_split_accumulator")
.Attr<JAXX_Collective_Op>("collective_op"));
Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input,
Buffer_Type alpha, Buffer_Type beta, Result_Type output, Result_Type bias_grad,
Result_Type pre_gelu_out, Result_Type workspace, JAXX_Scaling_Mode scaling_mode,
int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed,
bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad,
bool use_split_accumulator, JAXX_Collective_Op collective_op) {
// NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when
// device supports non-TN layouts (compute capability >= 10.0, excluding 12.x)
bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
@@ -83,31 +156,24 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, rhs, rhs_scale_inv, scaling_mode,
rhs_axis_boundary, make_rhs_rowwise);
std::vector<size_t> out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0],
(rhs_transposed) ? rhs_shape[0] : rhs_shape[1]};
auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
// Bias input to forward pass or bias gradient output from backward pass
void *bias_ptr = nullptr;
size_t bias_size = 0;
DType bias_dtype = out_dtype;
if (fuse_bias) {
if (grad) {
NVTE_CHECK(bias_grad->untyped_data() == bias.untyped_data(),
"Missing operand-output aliasing in GemmPrimitive: bias <-> bias_grad");
}
bias_ptr = bias.untyped_data();
bias_size = product(bias.dimensions());
bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type());
}
auto bias_ = TensorWrapper(bias_ptr, std::vector<size_t>{bias_size}, bias_dtype);
// Pre-GeLU output from forward pass or input to backward pass
void *pre_gelu_ptr = nullptr;
@@ -130,12 +196,91 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr);
std::vector<size_t> workspace_shape = {static_cast<size_t>(workspace->element_count()) - 256};
auto workspace_ = TensorWrapper(workspace_ptr, workspace_shape, DType::kByte);
auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
float one = 1.;
float zero = 0.;
// alpha, beta
float *alpha_ptr = &one, *beta_ptr = &zero;
if (is_nvfp4_scaling(scaling_mode)) {
NVTE_CHECK(alpha.element_count() == 1 &&
convert_ffi_datatype_to_te_dtype(alpha.element_type()) == DType::kFloat32);
alpha_ptr = reinterpret_cast<float *>(alpha.untyped_data());
NVTE_CHECK(beta.element_count() == 1 &&
convert_ffi_datatype_to_te_dtype(beta.element_type()) == DType::kFloat32);
beta_ptr = reinterpret_cast<float *>(beta.untyped_data());
}
// Construct GEMM config
transformer_engine::MatmulConfigWrapper config;
config.set_use_split_accumulator(use_split_accumulator);
config.set_sm_count(num_math_sm);
if (fuse_bias) config.set_bias_tensor(bias_.data());
if (fuse_gelu) {
config.set_with_gelu_epilogue(true);
config.set_epilogue_aux_tensor(pre_gelu_.data());
}
if (collective_op == JAXX_Collective_Op::NONE) {
auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype);
NVTE_CHECK(out_.numel() == output->element_count(),
"cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(), " elements ",
to_string_like(out_shape), " but got ", output->element_count(), " elements ",
to_string_like(output->dimensions()));
NVTE_CHECK(!fuse_bias || bias_size == out_shape[1], "bias_size=", bias_size,
", out_shape[1]=", out_shape[1]);
// Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order
nvte_cublas_gemm_v2(rhs_transposed /*transa*/, lhs_transposed /*transb*/, alpha_ptr,
rhs_.data() /*A*/, lhs_.data() /*B*/, beta_ptr, out_.data() /*C*/,
out_.data() /*D*/, workspace_.data(), config, stream);
} else {
std::vector<size_t> buffer_shape{0, 0};
DType buffer_dtype = out_dtype;
auto &comm_handler = CommunicatorHandler::get();
if (collective_op == JAXX_Collective_Op::ALL_GATHER) {
buffer_shape[0] = lhs_shape[0] * comm_handler.tp_size;
buffer_shape[1] = lhs_shape[1];
out_shape[0] = out_shape[0] * comm_handler.tp_size;
buffer_dtype = convert_ffi_datatype_to_te_dtype(lhs.element_type());
} else if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) {
buffer_shape[0] = out_shape[0];
buffer_shape[1] = out_shape[1];
out_shape[0] = out_shape[0] / comm_handler.tp_size;
}
NVTE_CHECK(!fuse_bias || bias_size == out_shape[1], "bias_size=", bias_size,
", out_shape[1]=", out_shape[1]);
auto executor = CollectiveGemmPlanRegistry::getInstance().get_executor(
buffer_shape, buffer_dtype, collective_op);
if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) {
auto ubuf_out_ = TensorWrapper(executor->get_ubuf_dptr(), buffer_shape, out_dtype);
// Prepare the auxiliary buffer for the reduce-scattered GEMM output
auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype);
NVTE_CHECK(out_.numel() == output->element_count(),
"cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(),
" elements ", to_string_like(out_shape), " but got ", output->element_count(),
" elements ", to_string_like(output->dimensions()));
// Launch GEMM+RS
executor->split_overlap_rs(rhs_, rhs_transposed, lhs_, lhs_transposed, ubuf_out_, bias_,
pre_gelu_, workspace_, grad, false, use_split_accumulator, out_,
stream);
} else if (collective_op == JAXX_Collective_Op::ALL_GATHER) {
auto aux_out_ = TensorWrapper(nullptr, std::vector<size_t>{0}, out_dtype); // Empty
auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype);
NVTE_CHECK(out_.numel() == output->element_count(),
"cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(),
" elements ", to_string_like(out_shape), " but got ", output->element_count(),
" elements ", to_string_like(output->dimensions()));
// Copy the distributed LHS operand into the local chunk of the communication buffer
executor->copy_into_buffer(stream, lhs_, true, make_lhs_rowwise);
// Launch AG+GEMM
executor->split_overlap_ag(rhs_, rhs_transposed, lhs_, lhs_transposed, out_, bias_, pre_gelu_,
workspace_, grad, false, use_split_accumulator, aux_out_, stream);
}
}
return ffi_with_cuda_error_check();
}
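// Worked example of the collective shape bookkeeping in GemmFFI, assuming hypothetical values
// tp_size = 4, local lhs_shape = {M/4, K}, rhs_shape = {K, N}, and no transposes:
//   ALL_GATHER:     ubuf holds the gathered {M, K} LHS, out_shape[0] is multiplied by tp_size,
//                   and split_overlap_ag writes the full {M, N} output on every rank.
//   REDUCE_SCATTER: ubuf holds the full {M, N} accumulator, out_shape[0] is divided by tp_size,
//                   and split_overlap_rs leaves each rank with its {M/4, N} shard in `output`.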
@@ -149,6 +294,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI,
.Arg<Buffer_Type>() // rhs_scale_inv
.Arg<Buffer_Type>() // bias
.Arg<Buffer_Type>() // gelu_input
.Arg<Buffer_Type>() // alpha
.Arg<Buffer_Type>() // beta
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // bias_grad
.Ret<Buffer_Type>() // pre_gelu_out
@@ -161,15 +308,75 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI,
.Attr<bool>("fuse_bias")
.Attr<bool>("fuse_gelu")
.Attr<bool>("grad")
.Attr<bool>("use_split_accumulator")
.Attr<JAXX_Collective_Op>("collective_op"),
FFI_CudaGraph_Traits);
size_t GroupedGemmGetGroupSizes(cudaStream_t stream, size_t num_gemms, int32_t *dev_group_sizes,
int32_t *host_group_sizes) {
static std::once_flag init_flag;
static cudaEvent_t d2h_event;
static size_t host_num_gemms;
static const size_t max_num_gemms = 1024;
//static int32_t host_group_sizes_internal[max_num_gemms];
static int32_t *host_group_sizes_internal = nullptr;
auto init = [&]() {
NVTE_CHECK_CUDA(cudaEventCreate(&d2h_event));
NVTE_CHECK_CUDA(cudaMallocHost(&host_group_sizes_internal, sizeof(int32_t) * max_num_gemms));
};
std::call_once(init_flag, init);
NVTE_CHECK(dev_group_sizes == nullptr || host_group_sizes == nullptr,
"Only one of dev_group_sizes and host_group_sizes can be non-nullptr.");
if (dev_group_sizes != nullptr) {
NVTE_CHECK(num_gemms <= max_num_gemms, "num_gemms ", num_gemms, " exceeds the maximum ",
"supported number ", max_num_gemms, " to be downloaded in advance.");
host_num_gemms = num_gemms;
// Wait for current compute stream to finish
cudaStream_t compute_stream_0 = nvte_get_compute_stream(0);
NVTE_CHECK_CUDA(cudaEventRecord(d2h_event, stream));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_stream_0, d2h_event));
// Async copy group_sizes from device to host
size_t copy_bytes = sizeof(int32_t) * num_gemms;
NVTE_CHECK_CUDA(cudaMemcpyAsync(host_group_sizes_internal, dev_group_sizes, copy_bytes,
cudaMemcpyDeviceToHost, compute_stream_0));
NVTE_CHECK_CUDA(cudaEventRecord(d2h_event, compute_stream_0));
return num_gemms;
}
if (host_group_sizes != nullptr) {
if (host_num_gemms == 0) return 0;
NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms,
" does not match the previous value ", host_num_gemms, ".");
// Wait for the async copy to finish, then copy group_sizes to user buffer
// Note: This may break cudaGraph.
NVTE_CHECK_CUDA(cudaEventSynchronize(d2h_event));
memcpy(host_group_sizes, host_group_sizes_internal, sizeof(int32_t) * host_num_gemms);
return host_num_gemms;
}
}
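// Usage sketch for the two-phase protocol above (both phases must pass the same num_gemms):
//   // Phase 1 -- issued ahead of time so the D2H copy overlaps other work:
//   GroupedGemmGetGroupSizes(stream, num_gemms, dev_group_sizes, /*host_group_sizes=*/nullptr);
//   // Phase 2 -- later, when the sizes are needed on the host:
//   std::vector<int32_t> sizes(num_gemms);
//   GroupedGemmGetGroupSizes(stream, num_gemms, /*dev_group_sizes=*/nullptr, sizes.data());
// Phase 1 fences compute stream 0 behind `stream`, starts the async copy into pinned memory and
// returns immediately; phase 2 blocks on the recorded event and copies the staged values out.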
Error_Type GroupedGemmD2HGroupSizesFFI(cudaStream_t stream, Buffer_Type group_sizes,
Result_Type dummy_output, size_t num_gemms) {
int32_t *dev_group_sizes = reinterpret_cast<int32_t *>(group_sizes.untyped_data());
GroupedGemmGetGroupSizes(stream, num_gemms, dev_group_sizes, nullptr);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler, GroupedGemmD2HGroupSizesFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // group_sizes
.Ret<Buffer_Type>() // dummy_output
.Attr<int64_t>("num_gemms"));
Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv,
Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias,
Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output,
Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans,
bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias,
bool is_grouped_dense_wgrad, bool use_async_d2h_group_sizes) {
// Notes on matrix layouts and transpose:
// Jax uses row-major data_layout, on entering this function, each input matrix pair:
// A: row-major [m, k] for N - [k, m] for T
@@ -290,11 +497,18 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
size_t dim_list_bytes = sizeof(int32_t) * num_gemms;
std::vector<int32_t> dim_list_host(num_gemms);
size_t host_num_gemms = 0;
if (use_async_d2h_group_sizes) {
host_num_gemms = GroupedGemmGetGroupSizes(stream, num_gemms, nullptr, dim_list_host.data());
NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms,
" does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, ".");
} else {
auto dim_list_ptr = reinterpret_cast<int32_t *>(group_sizes.untyped_data());
cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost,
stream);
// Note: This may break cudaGraph.
cudaStreamSynchronize(stream);
}
size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0);
if (!is_grouped_dense_wgrad) {
NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m,
@@ -413,9 +627,9 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
// point to swizzled scale_inv data (store on workspace, only used for GEMM).
// Note: even if is_empty_gemm is true, sinv are still non-empty, need to move the pointers
auto lhs_sinv_shape_i =
    get_block_scale_shape(scaling_mode, lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise);
auto rhs_sinv_shape_i =
    get_block_scale_shape(scaling_mode, rhs_shape_i[0], rhs_shape_i[1], rhs_use_colwise);
lhs_sinv_size_i = lhs_sinv_shape_i[0] * lhs_sinv_shape_i[1];
rhs_sinv_size_i = rhs_sinv_shape_i[0] * rhs_sinv_shape_i[1];
if (lhs_use_colwise) {
@@ -553,7 +767,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
.Attr<bool>("rhs_is_trans")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("has_bias")
.Attr<bool>("is_grouped_dense_wgrad")
.Attr<bool>("use_async_d2h_group_sizes"));
} // namespace jax
} // namespace transformer_engine
@@ -26,11 +26,21 @@ std::vector<size_t> Shape::to_vector() const {
return shape;
}
std::vector<size_t> get_block_scale_shape(JAXX_Scaling_Mode scaling_mode, size_t M, size_t N,
bool is_colwise) {
auto block_size = BLOCK_SIZE(1, 1);
if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
block_size = MXFP8_BLOCK_SIZE;
} else if (scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING ||
scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING) {
block_size = NVFP4_BLOCK_SIZE;
} else {
NVTE_ERROR("Unsupported scaling_mode = ", static_cast<int>(scaling_mode));
}
auto block_x = is_colwise ? block_size.y : block_size.x;
auto block_y = is_colwise ? block_size.x : block_size.y;
auto alignment_x = is_colwise ? BLOCK_SCALE_ALIGNMENT.y : BLOCK_SCALE_ALIGNMENT.x;
auto alignment_y = is_colwise ? BLOCK_SCALE_ALIGNMENT.x : BLOCK_SCALE_ALIGNMENT.y;
NVTE_CHECK(M % block_x == 0, "M must be divisible by %zu (got %zu)", block_x, M);
NVTE_CHECK(N % block_y == 0, "N must be divisible by %zu (got %zu)", block_y, N);
......
@@ -45,6 +45,8 @@ enum class JAXX_Scaling_Mode : int64_t {
DELAYED_TENSOR_SCALING = 1,
MXFP8_1D_SCALING = 2,
CURRENT_TENSOR_SCALING = 3,
NVFP4_1D_SCALING = 4,
NVFP4_2D_SCALING = 5,
};
inline bool is_tensor_scaling(const JAXX_Scaling_Mode &mode) {
@@ -56,6 +58,11 @@ inline bool is_block_scaling(const JAXX_Scaling_Mode &mode) {
return (mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING);
}
inline bool is_nvfp4_scaling(const JAXX_Scaling_Mode &mode) {
return (mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING ||
mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING);
}
static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
switch (mode) {
case JAXX_Scaling_Mode::NO_SCALING:
@@ -70,22 +77,58 @@ static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
case JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING:
return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING;
break;
case JAXX_Scaling_Mode::NVFP4_1D_SCALING:
return NVTEScalingMode::NVTE_NVFP4_1D_SCALING;
break;
case JAXX_Scaling_Mode::NVFP4_2D_SCALING:
// TE common uses the same enum value for 1D and 2D fp4 scaling and instead differentiates them via quant_config.nvfp4_2d_quantization
return NVTEScalingMode::NVTE_NVFP4_1D_SCALING;
break;
default:
NVTE_ERROR("Invalid Scaling Mode ", static_cast<int>(mode));
break;
}
}
struct BLOCK_SIZE {
size_t x;
size_t y;
constexpr BLOCK_SIZE(int _x, int _y) : x(_x), y(_y) {}
};
constexpr BLOCK_SIZE MXFP8_BLOCK_SIZE{1, 32};
constexpr BLOCK_SIZE NVFP4_BLOCK_SIZE{1, 16};
constexpr BLOCK_SIZE BLOCK_SCALE_ALIGNMENT{128, 4};
std::vector<size_t> get_block_scale_shape(JAXX_Scaling_Mode scaling_mode, size_t M, size_t N,
bool is_colwise);
template <typename T, typename... Rest>
void hash_combine(int64_t &seed, const T &v, Rest... rest) {
seed ^= std::hash<T>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
(hash_combine(seed, rest), ...);
}
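// Minimal usage sketch for the boost-style hash_combine above (hypothetical key derivation;
// the actual registry keying is defined elsewhere, presumably in cgemm_helper.h):
//   int64_t seed = 0;
//   hash_combine(seed, buffer_shape[0], buffer_shape[1],
//                static_cast<int64_t>(buffer_dtype), static_cast<int64_t>(collective_op));
//   // seed now identifies a single (shape, dtype, collective op) executor plan.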
enum class JAXX_Collective_Op : int64_t {
NONE = 0,
ALL_GATHER = 1,
REDUCE_SCATTER = 2,
};
static CommOverlapType get_nvte_collective_op(const JAXX_Collective_Op &op) {
switch (op) {
case JAXX_Collective_Op::ALL_GATHER:
return CommOverlapType::AG;
break;
case JAXX_Collective_Op::REDUCE_SCATTER:
return CommOverlapType::RS;
break;
default:
NVTE_ERROR("Invalid Collective Op ", static_cast<int>(op));
break;
}
}
} // namespace jax
} // namespace transformer_engine
@@ -29,6 +29,7 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(nullptr, out_dtype, input_shape);
output_tensor.set_amax(nullptr, DType::kFloat32, std::vector<size_t>{1});
// WAR: NVTE Norms query is_training from whether columnwise_data is allocated
if (is_training && scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
@@ -59,12 +60,13 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
}
Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf,
Buffer_Type amax_buf, Buffer_Type gamma_buf, Buffer_Type beta_buf,
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type updated_amax_buf, Result_Type mu_buf, Result_Type rsigma_buf,
Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma,
double epsilon, int64_t sm_margin, JAXX_Scaling_Mode scaling_mode,
bool is_2x, bool output_amax_when_no_scaling) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type());
@@ -77,9 +79,12 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
auto *output = output_buf->untyped_data();
auto *rsigma = rsigma_buf->untyped_data();
auto *mu = mu_buf->untyped_data();
auto *workspace = wkspace_buf->untyped_data();
auto *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
auto *updated_amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased");
auto _norm_type = static_cast<NVTE_Norm_Type>(norm_type);
auto _is_2x = static_cast<bool>(is_2x);
@@ -106,6 +111,10 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), input_shape);
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
(scaling_mode == JAXX_Scaling_Mode::NO_SCALING && output_amax_when_no_scaling)) {
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
}
NVTE_CHECK(
scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
@@ -123,8 +132,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) {
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
}
if (_is_2x) {
@@ -162,13 +169,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI,
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // gamma
.Arg<Buffer_Type>() // beta
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise_output
.Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // colwise_scale_inv
.Ret<Buffer_Type>() // updated_amax
.Ret<Buffer_Type>() // mu
.Ret<Buffer_Type>() // rsigma
.Ret<Buffer_Type>() // wkspace
@@ -177,9 +185,51 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI,
.Attr<double>("epsilon")
.Attr<int64_t>("sm_margin")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x")
.Attr<bool>("output_amax_when_no_scaling"),
FFI_CudaGraph_Traits);
Error_Type NormForwardInitializeFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf,
Buffer_Type amax_buf, Buffer_Type gamma_buf,
Buffer_Type beta_buf, Result_Type output_buf,
Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type mu_buf, Result_Type rsigma_buf,
Result_Type wkspace_buf, int norm_type,
bool zero_centered_gamma, double epsilon, int64_t sm_margin,
JAXX_Scaling_Mode scaling_mode, bool is_2x,
bool output_amax_when_no_scaling) {
return wrapInStreamCapture(std::function(NormForwardFFI), stream, x_buf, scale_buf, amax_buf,
gamma_buf, beta_buf, output_buf, colwise_output_buf, scale_inv_buf,
colwise_scale_inv_buf, updated_amax_buf, mu_buf, rsigma_buf,
wkspace_buf, norm_type, zero_centered_gamma, epsilon, sm_margin,
scaling_mode, is_2x, output_amax_when_no_scaling);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializeFFI,
FFI::Bind<FFI_Initialize>()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // gamma
.Arg<Buffer_Type>() // beta
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise_output
.Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // colwise_scale_inv
.Ret<Buffer_Type>() // updated_amax
.Ret<Buffer_Type>() // mu
.Ret<Buffer_Type>() // rsigma
.Ret<Buffer_Type>() // wkspace
.Attr<int64_t>("norm_type")
.Attr<bool>("zero_centered_gamma")
.Attr<double>("epsilon")
.Attr<int64_t>("sm_margin")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x")
.Attr<bool>("output_amax_when_no_scaling"));
pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
DType w_dtype, NVTE_Norm_Type norm_type,
bool zero_centered_gamma, int sm_margin) {
@@ -305,5 +355,32 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormBackwardHandler, NormBackwardFFI,
.Attr<int64_t>("sm_margin"),
FFI_CudaGraph_Traits);
Error_Type NormBackwardInitializeFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf,
Buffer_Type mu_buf, Buffer_Type rsigma_buf,
Buffer_Type gamma_buf, Result_Type xgrad_buf,
Result_Type wgrad_buf, Result_Type dbeta_buf,
Result_Type wkspace_buf, int64_t norm_type,
bool zero_centered_gamma, int64_t sm_margin) {
return wrapInStreamCapture(std::function(NormBackwardFFI), stream, dz_buf, x_buf, mu_buf,
rsigma_buf, gamma_buf, xgrad_buf, wgrad_buf, dbeta_buf, wkspace_buf,
norm_type, zero_centered_gamma, sm_margin);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(NormBackwardInitializeHandler, NormBackwardInitializeFFI,
FFI::Bind<FFI_Initialize>()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // dz
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // mu
.Arg<Buffer_Type>() // rsigma
.Arg<Buffer_Type>() // gamma
.Ret<Buffer_Type>() // xgrad
.Ret<Buffer_Type>() // wgrad
.Ret<Buffer_Type>() // dbeta
.Ret<Buffer_Type>() // wkspace
.Attr<int64_t>("norm_type")
.Attr<bool>("zero_centered_gamma")
.Attr<int64_t>("sm_margin"));
} // namespace jax
} // namespace transformer_engine
@@ -5,6 +5,8 @@
************************************************************************/
#include "../extensions.h"
#include "cgemm_helper.h"
#include "common/util/cuda_runtime.h"
namespace transformer_engine {
namespace jax {
@@ -20,8 +22,12 @@ pybind11::dict Registrations() {
pybind11::dict dict;
// Activation
dict["te_act_lu_ffi"] =
    pybind11::dict(pybind11::arg("initialize") = EncapsulateFFI(ActLuInitializeHandler),
                   pybind11::arg("execute") = EncapsulateFFI(ActLuHandler));
dict["te_dact_dbias_quantize_ffi"] = pybind11::dict(
    pybind11::arg("initialize") = EncapsulateFFI(DActLuDBiasQuantizeInitializeHandler),
    pybind11::arg("execute") = EncapsulateFFI(DActLuDBiasQuantizeHandler));
// Quantization
dict["te_dbias_quantize_ffi"] = EncapsulateFFI(DBiasQuantizeHandler);
@@ -42,9 +48,11 @@ pybind11::dict Registrations() {
// Normalization
dict["te_norm_forward_ffi"] =
    pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
                   pybind11::arg("initialize") = EncapsulateFFI(NormForwardInitializeHandler),
                   pybind11::arg("execute") = EncapsulateFFI(NormForwardHandler));
dict["te_norm_backward_ffi"] =
    pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
                   pybind11::arg("initialize") = EncapsulateFFI(NormBackwardInitializeHandler),
                   pybind11::arg("execute") = EncapsulateFFI(NormBackwardHandler));
// Attention
@@ -57,14 +65,22 @@ pybind11::dict Registrations() {
// GEMM
dict["te_gemm_ffi"] =
    pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CollectiveGemmInitHandler),
                   pybind11::arg("execute") = EncapsulateFFI(GemmHandler));
// Grouped GEMM
dict["te_grouped_gemm_d2h_group_sizes_ffi"] =
    pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
                   pybind11::arg("execute") = EncapsulateFFI(GroupedGemmD2HGroupSizesHandler));
dict["te_grouped_gemm_ffi"] =
    pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
                   pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler));
// Amax
dict["te_rht_amax_ffi"] = pybind11::dict(
pybind11::arg("initialize") = EncapsulateFFI(RHTAmaxCalculationInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(RHTAmaxCalculationHandler));
return dict;
}
@@ -84,6 +100,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
m.def("nvte_get_qkv_format", &nvte_get_qkv_format);
m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported);
m.def("initialize_cgemm_communicator", &InitializeCgemmCommunicator);
m.def("get_cgemm_num_max_streams", &GetCgemmNumMaxStreams);
pybind11::enum_<DType>(m, "DType", pybind11::module_local())
    .value("kByte", DType::kByte)
@@ -93,7 +111,9 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
    .value("kFloat16", DType::kFloat16)
    .value("kBFloat16", DType::kBFloat16)
    .value("kFloat8E4M3", DType::kFloat8E4M3)
    .value("kFloat8E5M2", DType::kFloat8E5M2)
    .value("kFloat8E8M0", DType::kFloat8E8M0)
    .value("kFloat4E2M1", DType::kFloat4E2M1);
pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local())
    .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS)
@@ -133,6 +153,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
    .value("QGEGLU", NVTE_Activation_Type::QGEGLU)
    .value("SRELU", NVTE_Activation_Type::SRELU)
    .value("SREGLU", NVTE_Activation_Type::SREGLU)
    .value("CLAMPED_SWIGLU", NVTE_Activation_Type::CLAMPED_SWIGLU)
    .export_values();
pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local())
@@ -151,6 +172,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
    .value("DELAYED_TENSOR_SCALING", JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
    .value("MXFP8_1D_SCALING", JAXX_Scaling_Mode::MXFP8_1D_SCALING)
    .value("CURRENT_TENSOR_SCALING", JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING)
    .value("NVFP4_1D_SCALING", JAXX_Scaling_Mode::NVFP4_1D_SCALING)
    .value("NVFP4_2D_SCALING", JAXX_Scaling_Mode::NVFP4_2D_SCALING)
    .export_values();
pybind11::enum_<transformer_engine::jax::QuantizeLayout>(m, "QuantizeLayout",
@@ -159,6 +182,12 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
    .value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE)
    .value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE)
    .export_values();
pybind11::enum_<JAXX_Collective_Op>(m, "JAXX_Collective_Op", pybind11::module_local())
.value("NONE", JAXX_Collective_Op::NONE)
.value("ALL_GATHER", JAXX_Collective_Op::ALL_GATHER)
.value("REDUCE_SCATTER", JAXX_Collective_Op::REDUCE_SCATTER)
.export_values();
}
} // namespace jax
......
@@ -5,8 +5,11 @@
************************************************************************/
#include <cuda_runtime.h>
#include <iostream>
#include "../extensions.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/hadamard_transform.h"
#include "transformer_engine/recipe.h"
#include "transformer_engine/transformer_engine.h"
#include "xla/ffi/api/c_api.h"
@@ -15,7 +18,7 @@ namespace transformer_engine {
namespace jax {
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                               DType in_dtype, DType out_dtype, DType scale_dtype,
                                               JAXX_Scaling_Mode scaling_mode,
                                               QuantizeLayout q_layout) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
@@ -30,16 +33,22 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
// this function. We pass a dummy pointer as a workaround.
int temp = 0;
bool const is_nvfp4 = scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING ||
scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING;
auto input_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), input_shape, in_dtype);
auto dbias_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), dbias_shape, in_dtype);
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
auto scale_shape = std::vector<size_t>{1};
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::ROWWISE) {
  output_tensor.set_rowwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_shape);
  if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) {
    if (is_nvfp4)
      scale_shape = get_block_scale_shape(scaling_mode, batch_size, hidden_size, false);
    output_tensor.set_rowwise_scale_inv(reinterpret_cast<void *>(&temp), scale_dtype,
                                        scale_shape);
  }
}
@@ -49,13 +58,16 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
  output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);
  // Only the pointers will be checked for scale_inv, thus the shapes do not matter
  if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) {
    if (is_nvfp4)
      scale_shape =
          get_block_scale_shape(scaling_mode, hidden_size, batch_size, false);  // Transpose
    output_tensor.set_columnwise_scale_inv(reinterpret_cast<void *>(&temp), scale_dtype,
                                           scale_shape);
  }
}
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || is_nvfp4) {
  output_tensor.set_amax(reinterpret_cast<void *>(&temp), DType::kFloat32,
                         std::vector<size_t>{1});
  output_tensor.set_scale(reinterpret_cast<void *>(&temp), DType::kFloat32,
@@ -72,17 +84,20 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
}
Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
                            Buffer_Type amax_buf, Buffer_Type sr_rng_state,
                            Buffer_Type post_rht_amax_buf, Buffer_Type rht_matrix_buf,
                            Result_Type output_buf, Result_Type output_trans_buf,
                            Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
                            Result_Type updated_amax_buf, Result_Type dbias_buf,
                            Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
                            int64_t quantize_layout_enum, bool is_dbias, int64_t flatten_axis,
                            bool stochastic_rounding, bool use_rht) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
NVTE_CHECK(is_fp8_dtype(out_dtype) || is_fp4_dtype(out_dtype),
           "Output datatype must be FP8 or FP4 for quantization.");
auto *input = input_buf.untyped_data();
@@ -112,12 +127,17 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
                               scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
bool const is_mxfp8 = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
bool const is_nvfp4 = scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING ||
scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING;
NVTE_CHECK(!stochastic_rounding || is_nvfp4, "Stochastic rounding is only supported for NVFP4.");
NVTE_CHECK(!use_rht || is_nvfp4, "RHT is only supported for NVFP4 scaling");
if (quantize_layout == QuantizeLayout::ROWWISE ||
    quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
  output_tensor.set_rowwise_data(output, out_dtype, output_shape);
  if (is_tensor_scaling) {
    float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
    float *amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
@@ -127,8 +147,7 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
    output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
    output_tensor.set_rowwise_scale_inv(
        scale_inv_buf->untyped_data(),
        convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
  } else {
    output_tensor.set_rowwise_scale_inv(
        scale_inv_buf->untyped_data(),
@@ -138,13 +157,76 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
                            scale_inv_buf->dimensions().size())});
  }
if (is_nvfp4) {
float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
NVTE_CHECK(amax != nullptr, "amax must be provided for NVFP4");
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
}
QuantizationConfigWrapper quant_config{};
if (scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING) {
quant_config.set_nvfp4_2d_quantization(true);
}
// Stochastic rounding
quant_config.set_stochastic_rounding(stochastic_rounding);
TensorWrapper sr_rng_state_tensor(sr_rng_state.untyped_data(), std::vector<size_t>{2},
DType::kInt64);
if (stochastic_rounding) {
NVTE_CHECK(sr_rng_state.size_bytes() == 2 * sizeof(uint64_t),
"rng_state must be of type int64[2]");
NVTE_CHECK(sr_rng_state.untyped_data() != nullptr, "rng_state must be provided for SR");
quant_config.set_rng_state(sr_rng_state_tensor.data());
}
if (quantize_layout == QuantizeLayout::COLWISE ||
    quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
  if (is_nvfp4 && use_rht) {
    if (quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
      // Do regular rowwise quantization without RHT
nvte_quantize_v2(input_tensor.data(), output_tensor.data(), quant_config, stream);
}
TensorWrapper out_transpose(get_nvte_scaling_mode(scaling_mode));
// nvte_hadamard_transform_cast_fusion_columnwise expects the colwise data to be populated in the rowwise buffers on TensorWrapper
out_transpose.set_rowwise_data(output_trans, out_dtype, output_trans_shape);
auto const colwise_flatten_axis = output_trans_buf->dimensions().size() - flatten_axis;
out_transpose.set_rowwise_scale_inv(
colwise_scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0, colwise_flatten_axis),
product(colwise_scale_inv_buf->dimensions(), colwise_flatten_axis,
colwise_scale_inv_buf->dimensions().size())});
float *post_rht_amax = reinterpret_cast<float *>(post_rht_amax_buf.untyped_data());
NVTE_CHECK(post_rht_amax != nullptr, "Post-RHT colwise amax must be provided for NVFP4");
out_transpose.set_amax(post_rht_amax, DType::kFloat32, std::vector<size_t>{1});
bool const eligible_for_rht_cast_fusion =
input_tensor.dtype() == DType::kBFloat16 && m % 64 == 0 && n % 128 == 0;
NVTE_CHECK(eligible_for_rht_cast_fusion, "RHT cast fusion conditions not met");
NVTE_CHECK(
convert_ffi_datatype_to_te_dtype(rht_matrix_buf.element_type()) == DType::kBFloat16,
"RHT matrix must be bf16");
NVTE_CHECK(rht_matrix_buf.dimensions().size() == 2 && rht_matrix_buf.dimensions()[0] == 16 &&
rht_matrix_buf.dimensions()[1] == 16,
"RHT matrix must be 16x16");
TensorWrapper rht_matrix_tensor(rht_matrix_buf.untyped_data(), std::vector<size_t>{16, 16},
DType::kBFloat16);
nvte_hadamard_transform_cast_fusion_columnwise(input_tensor.data(), out_transpose.data(),
rht_matrix_tensor.data(), quant_config,
stream);
return ffi_with_cuda_error_check();
}
bool const is_colwise_transposed =
scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || is_nvfp4;
auto &tmp_shape = is_colwise_transposed ? output_trans_shape : output_shape;
output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape); output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape);
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &tmp_buf = is_tensor_scaling ? scale_inv_buf : colwise_scale_inv_buf; auto &tmp_buf = is_tensor_scaling ? scale_inv_buf : colwise_scale_inv_buf;
...@@ -154,26 +236,30 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -154,26 +236,30 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{1}); std::vector<size_t>{1});
} else { } else {
auto colwise_flatten_axis = flatten_axis;
if (is_colwise_transposed) {
// convert flatten_axis from N layout to T layout
colwise_flatten_axis = tmp_buf->dimensions().size() - flatten_axis;
}
output_tensor.set_columnwise_scale_inv( output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{ std::vector<size_t>{
product(tmp_buf->dimensions(), 0, flatten_axis), product(tmp_buf->dimensions(), 0, colwise_flatten_axis),
product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())}); product(tmp_buf->dimensions(), colwise_flatten_axis, tmp_buf->dimensions().size())});
}
} }
if (scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) {
output_tensor.set_amax(nullptr, DType::kFloat32, std::vector<size_t>{1});
} }
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype); auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype); auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);
if (is_dbias) { if (is_dbias) {
NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::NVFP4_2D_SCALING,
"DBias quantization is not supported for NVFP4_2D_SCALING as fused dbias API cannot "
"take quant_config as input.");
nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(), nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(),
workspace_tensor.data(), stream); workspace_tensor.data(), stream);
} else { } else {
nvte_quantize(input_tensor.data(), output_tensor.data(), stream); nvte_quantize_v2(input_tensor.data(), output_tensor.data(), quant_config, stream);
} }
return ffi_with_cuda_error_check(); return ffi_with_cuda_error_check();
} }
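The branch structure above can be summarized as follows: with NVFP4 scaling and use_rht, the columnwise output is produced by the fused Hadamard-transform cast (preceded by a plain rowwise cast for ROWWISE_COLWISE), otherwise a single nvte_quantize_v2 / nvte_quantize_dbias call covers the requested layouts. A runnable Python sketch of just that dispatch (flag values are illustrative; no real quantization is performed):

def quantize_path_sketch(quantize_layout, is_nvfp4, use_rht, is_dbias):
    # Which NVTE entry points the FFI above ends up calling for a given
    # combination of flags (illustration only, not the C++ implementation).
    calls = []
    if is_nvfp4 and use_rht and quantize_layout in ("COLWISE", "ROWWISE_COLWISE"):
        if quantize_layout == "ROWWISE_COLWISE":
            calls.append("nvte_quantize_v2")  # plain rowwise cast, no RHT
        calls.append("nvte_hadamard_transform_cast_fusion_columnwise")
        return calls
    if is_dbias:
        calls.append("nvte_quantize_dbias")   # fused dbias path, no quant_config
    else:
        calls.append("nvte_quantize_v2")      # rowwise and/or columnwise in one call
    return calls

print(quantize_path_sketch("ROWWISE_COLWISE", is_nvfp4=True, use_rht=True, is_dbias=False))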
...@@ -184,6 +270,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, ...@@ -184,6 +270,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
.Arg<Buffer_Type>() // input .Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // scale .Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax .Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // sr_rng_state
.Arg<Buffer_Type>() // colwise amax
.Arg<Buffer_Type>() // rht matrix
.Ret<Buffer_Type>() // output .Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise output .Ret<Buffer_Type>() // colwise output
.Ret<Buffer_Type>() // scale_inv .Ret<Buffer_Type>() // scale_inv
...@@ -194,7 +283,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, ...@@ -194,7 +283,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("q_layout") .Attr<int64_t>("q_layout")
.Attr<bool>("is_dbias") .Attr<bool>("is_dbias")
.Attr<int64_t>("flatten_axis"), .Attr<int64_t>("flatten_axis")
.Attr<bool>("stochastic_rounding")
.Attr<bool>("use_rht"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
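The new stochastic_rounding attribute and the sr_rng_state operand enable stochastic rounding during the cast. As a purely illustrative sketch of the idea (not the kernel's actual algorithm or RNG), rounding to a scaled grid with zero-mean error looks like this in JAX:

import jax
import jax.numpy as jnp

def stochastic_round_to_grid(x, scale, key):
    # Round x/scale up with probability equal to its fractional part, so the
    # rounding error is zero in expectation.
    y = x / scale
    lo = jnp.floor(y)
    frac = y - lo
    r = jax.random.uniform(key, y.shape, dtype=y.dtype)
    return (lo + (r < frac).astype(y.dtype)) * scale

x = jnp.array([0.30, -1.25, 2.70])
print(stochastic_round_to_grid(x, 0.5, jax.random.PRNGKey(0)))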
Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
...@@ -344,7 +435,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty ...@@ -344,7 +435,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
sinv_size = 1; sinv_size = 1;
} else { } else {
const bool is_colwise = false; const bool is_colwise = false;
auto sinv_shape_i = get_mxfp8_scale_shape(m_i, n, is_colwise); auto sinv_shape_i = get_block_scale_shape(scaling_mode, m_i, n, is_colwise);
out_i.set_rowwise_scale_inv(static_cast<void *>(sinv_ptr), sinv_dtype, sinv_shape_i); out_i.set_rowwise_scale_inv(static_cast<void *>(sinv_ptr), sinv_dtype, sinv_shape_i);
sinv_size = product(sinv_shape_i); sinv_size = product(sinv_shape_i);
} }
...@@ -363,7 +454,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty ...@@ -363,7 +454,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
colwise_sinv_size = 1; colwise_sinv_size = 1;
} else { } else {
const bool is_colwise = true; const bool is_colwise = true;
auto sinv_shape_i = get_mxfp8_scale_shape(m_i, n, is_colwise); auto sinv_shape_i = get_block_scale_shape(scaling_mode, m_i, n, is_colwise);
out_i.set_columnwise_scale_inv(static_cast<void *>(colwise_sinv_ptr), sinv_dtype, out_i.set_columnwise_scale_inv(static_cast<void *>(colwise_sinv_ptr), sinv_dtype,
sinv_shape_i); sinv_shape_i);
colwise_sinv_size = product(sinv_shape_i); colwise_sinv_size = product(sinv_shape_i);
......
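get_block_scale_shape replaces the MXFP8-only helper so the grouped quantize path covers both MXFP8 and NVFP4 block scaling. A hedged sketch of the underlying shape computation, assuming the commonly documented block sizes (32 elements for MXFP8, 16 for NVFP4) and ignoring the alignment/padding the real helper applies:

import math

def block_scale_shape_sketch(scaling_mode, rows, cols, is_colwise):
    # One scale-inverse entry per contiguous block along the quantized axis:
    # rowwise blocks run along columns, columnwise blocks run along rows.
    block = 32 if scaling_mode == "MXFP8_1D_SCALING" else 16  # NVFP4 assumed to use 16
    if is_colwise:
        return (math.ceil(rows / block), cols)
    return (rows, math.ceil(cols / block))

print(block_scale_shape_sketch("MXFP8_1D_SCALING", 128, 256, is_colwise=False))  # (128, 8)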
...@@ -11,10 +11,12 @@ customizable contracting dimensions for flexible tensor operations. ...@@ -11,10 +11,12 @@ customizable contracting dimensions for flexible tensor operations.
from typing import Tuple, Sequence from typing import Tuple, Sequence
from functools import partial from functools import partial
import warnings
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from . import cpp_extensions as tex from . import cpp_extensions as tex
from .cpp_extensions.amax import AmaxScope
from .quantize import ( from .quantize import (
ScaledTensorFactory, ScaledTensorFactory,
ScalingMode, ScalingMode,
...@@ -61,8 +63,11 @@ def dense( ...@@ -61,8 +63,11 @@ def dense(
kernel: jnp.ndarray, kernel: jnp.ndarray,
bias: jnp.ndarray = None, bias: jnp.ndarray = None,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
transpose_batch_sequence: bool = False,
input_axes: Tuple[str, ...] = None, input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None,
output_axes: Tuple[str, ...] = None,
collective_op_set: tex.CollectiveOpSet = tex.noop_collective_op_set,
quantizer_set: QuantizerSet = noop_quantizer_set, quantizer_set: QuantizerSet = noop_quantizer_set,
): ):
"""Perform dense layer transformation with optional quantization. """Perform dense layer transformation with optional quantization.
...@@ -76,11 +81,19 @@ def dense( ...@@ -76,11 +81,19 @@ def dense(
kernel: Weight matrix for the dense layer transformation kernel: Weight matrix for the dense layer transformation
bias: Optional bias tensor to add after the transformation bias: Optional bias tensor to add after the transformation
contracting_dims: Tuple of sequences specifying which dimensions to contract contracting_dims: Tuple of sequences specifying which dimensions to contract
transpose_batch_sequence: Transpose the batch and sequence dimensions of the input tensor.
input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix
output_axes: Logical axes for sharding the output
collective_op_set: A set of CollectiveOp objects for forward and backward passes.
quantizer_set: QuantizerSet which contains quantizers for different tensor types quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns: Returns:
Transformed output tensor Transformed output tensor
""" """
if transpose_batch_sequence:
warnings.warn("transpose_batch_sequence is not well tested, use with caution!")
if not get_quantize_config().is_fp8_enabled(): if not get_quantize_config().is_fp8_enabled():
input_dtype = x.dtype input_dtype = x.dtype
kernel = kernel.astype(input_dtype) kernel = kernel.astype(input_dtype)
...@@ -90,29 +103,28 @@ def dense( ...@@ -90,29 +103,28 @@ def dense(
kernel, kernel,
bias, bias,
contracting_dims, contracting_dims,
transpose_batch_sequence,
input_axes, input_axes,
kernel_axes, kernel_axes,
output_axes,
collective_op_set,
quantizer_set, quantizer_set,
) )
return output return output
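A minimal usage sketch of the extended dense() signature with the new output_axes and collective_op_set arguments. The import paths and logical axis names are assumptions for illustration; the default no-op collective set disables communication/GEMM overlap:

import jax.numpy as jnp
from transformer_engine.jax import cpp_extensions as tex
from transformer_engine.jax.dense import dense
from transformer_engine.jax.quantize import noop_quantizer_set

x = jnp.zeros((8, 128, 512), dtype=jnp.bfloat16)   # (batch, seq, hidden)
w = jnp.zeros((512, 1024), dtype=jnp.bfloat16)     # (hidden, out)

y = dense(
    x, w,
    contracting_dims=((2,), (0,)),
    input_axes=("nvte_batch", "nvte_seqlen", "nvte_hidden"),      # assumed axis names
    kernel_axes=("nvte_hidden", "nvte_hidden_tp"),
    output_axes=("nvte_batch", "nvte_seqlen", "nvte_hidden_tp"),
    collective_op_set=tex.noop_collective_op_set,
    quantizer_set=noop_quantizer_set,
)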
@partial( @partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8))
jax.custom_vjp,
nondiff_argnums=(
3,
4,
5,
),
)
def _dense( def _dense(
x, x,
kernel, kernel,
bias, bias,
contracting_dims, contracting_dims,
transpose_batch_sequence,
input_axes, input_axes,
kernel_axes, kernel_axes,
quantizer_set, output_axes,
collective_op_set,
quantizer_set,  # needs to be a diff arg for DelayedScaling state management
): ):
"""Internal implementation of dense layer transformation with custom VJP. """Internal implementation of dense layer transformation with custom VJP.
...@@ -124,8 +136,11 @@ def _dense( ...@@ -124,8 +136,11 @@ def _dense(
kernel: Weight matrix kernel: Weight matrix
bias: Optional bias tensor bias: Optional bias tensor
contracting_dims: Contracting dimensions specification contracting_dims: Contracting dimensions specification
transpose_batch_sequence: Transpose the batch and sequence dimensions of the input tensor.
input_axes: Logical axes for sharding the activation input input_axes: Logical axes for sharding the activation input
output_axes: Logical axes for sharding the output
kernel_axes: Logical axes for sharding the weight matrix kernel_axes: Logical axes for sharding the weight matrix
collective_op_set: A set of CollectiveOp objects for forward and backward passes.
quantizer_set: QuantizerSet which contains quantizers for different tensor types quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns: Returns:
...@@ -136,8 +151,11 @@ def _dense( ...@@ -136,8 +151,11 @@ def _dense(
kernel, kernel,
bias, bias,
contracting_dims, contracting_dims,
transpose_batch_sequence,
input_axes, input_axes,
kernel_axes, kernel_axes,
output_axes,
collective_op_set,
quantizer_set, quantizer_set,
) )
return output return output
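quantizer_set stays a differentiable argument, while the structural options (contracting dims, axes, collective ops) are listed in nondiff_argnums, because DelayedScaling threads updated scale/amax state through the VJP and returns it from the backward rule. A stripped-down sketch of that pattern (not the TE code):

import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.custom_vjp, nondiff_argnums=(2,))  # `flag` is static, `state` is not
def scale_fn(x, state, flag):
    return x * state["scale"]

def scale_fn_fwd(x, state, flag):
    return scale_fn(x, state, flag), (x, state)

def scale_fn_bwd(flag, res, g):
    x, state = res
    new_state = {"scale": state["scale"]}       # the "gradient" slot carries updated state
    return g * state["scale"], new_state

scale_fn.defvjp(scale_fn_fwd, scale_fn_bwd)

x = jnp.ones(3)
state = {"scale": jnp.float32(2.0)}
y, pullback = jax.vjp(lambda x, s: scale_fn(x, s, True), x, state)
dx, dstate = pullback(jnp.ones(3))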
...@@ -148,8 +166,11 @@ def _dense_fwd_rule( ...@@ -148,8 +166,11 @@ def _dense_fwd_rule(
kernel, kernel,
bias, bias,
contracting_dims, contracting_dims,
transpose_batch_sequence,
input_axes, input_axes,
kernel_axes, kernel_axes,
output_axes,
collective_op_set,
quantizer_set, quantizer_set,
): ):
"""Forward pass rule for dense layer transformation. """Forward pass rule for dense layer transformation.
...@@ -175,6 +196,8 @@ def _dense_fwd_rule( ...@@ -175,6 +196,8 @@ def _dense_fwd_rule(
x, x,
flatten_axis=flatten_axis_x, flatten_axis=flatten_axis_x,
quantizer=quantizer_set.x, quantizer=quantizer_set.x,
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
) )
casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)
...@@ -182,6 +205,7 @@ def _dense_fwd_rule( ...@@ -182,6 +205,7 @@ def _dense_fwd_rule(
kernel, kernel,
flatten_axis=flatten_axis_k, flatten_axis=flatten_axis_k,
quantizer=quantizer_set.kernel, quantizer=quantizer_set.kernel,
amax_scope=AmaxScope.FSDP,
) )
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
...@@ -191,9 +215,12 @@ def _dense_fwd_rule( ...@@ -191,9 +215,12 @@ def _dense_fwd_rule(
casted_x.get_tensor(usage=TensorUsage.LHS), casted_x.get_tensor(usage=TensorUsage.LHS),
casted_kernel.get_tensor(usage=TensorUsage.RHS), casted_kernel.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims), contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=transpose_batch_sequence,
bias=bias if not tex.gemm_uses_jax_dot() else None, bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
collective_op=collective_op_set.forward,
) )
output = with_sharding_constraint_by_logical_axes(output, output_axes)
if use_bias and tex.gemm_uses_jax_dot(): if use_bias and tex.gemm_uses_jax_dot():
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
...@@ -212,8 +239,15 @@ def _dense_fwd_rule( ...@@ -212,8 +239,15 @@ def _dense_fwd_rule(
def _dense_bwd_rule( def _dense_bwd_rule(
contracting_dims, input_axes, kernel_axes, ctx, grad contracting_dims,
): # pylint: disable=unused-argument transpose_batch_sequence,
input_axes,
kernel_axes,
output_axes,
collective_op_set,
ctx,
grad,
):
"""Backward pass rule for dense layer transformation. """Backward pass rule for dense layer transformation.
Returns: Returns:
...@@ -228,6 +262,7 @@ def _dense_bwd_rule( ...@@ -228,6 +262,7 @@ def _dense_bwd_rule(
quantizer_set, quantizer_set,
flatten_axis_k, flatten_axis_k,
) = ctx ) = ctx
grad = with_sharding_constraint_by_logical_axes(grad, output_axes)
fwd_x_contracting_dims, fwd_k_contracting_dims = map( fwd_x_contracting_dims, fwd_k_contracting_dims = map(
tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims
...@@ -238,6 +273,8 @@ def _dense_bwd_rule( ...@@ -238,6 +273,8 @@ def _dense_bwd_rule(
is_dbias=use_bias, is_dbias=use_bias,
flatten_axis=flatten_axis_k, flatten_axis=flatten_axis_k,
quantizer=quantizer_set.dgrad, quantizer=quantizer_set.dgrad,
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
) )
# GEMM NT # GEMM NT
...@@ -254,8 +291,9 @@ def _dense_bwd_rule( ...@@ -254,8 +291,9 @@ def _dense_bwd_rule(
casted_grad.get_tensor(usage=TensorUsage.LHS), casted_grad.get_tensor(usage=TensorUsage.LHS),
casted_kernel_rhs, casted_kernel_rhs,
contracting_dims=(g_contracting_dim, k_contracting_dim), contracting_dims=(g_contracting_dim, k_contracting_dim),
transpose_batch_sequence=transpose_batch_sequence,
collective_op=collective_op_set.backward,
) )
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
# GEMM TN # GEMM TN
# x_non_contracting_dims # x_non_contracting_dims
...@@ -267,7 +305,10 @@ def _dense_bwd_rule( ...@@ -267,7 +305,10 @@ def _dense_bwd_rule(
casted_x_lhs, casted_x_lhs,
casted_grad.get_tensor(usage=TensorUsage.RHS), casted_grad.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dim, g_contracting_dim), contracting_dims=(x_contracting_dim, g_contracting_dim),
transpose_batch_sequence=transpose_batch_sequence,
) )
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
return dgrad, wgrad, dbias, quantizer_set return dgrad, wgrad, dbias, quantizer_set
......
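The two backward GEMMs above ("GEMM NT" for dgrad, "GEMM TN" for wgrad) are the usual dense-layer gradients; for a 3-D activation and a 2-D kernel the contracting-dimension bookkeeping reduces to the following non-quantized sketch:

import jax.numpy as jnp

x = jnp.zeros((8, 128, 512))    # (batch, seq, hidden_in)
w = jnp.zeros((512, 1024))      # (hidden_in, hidden_out)
g = jnp.zeros((8, 128, 1024))   # gradient of the output

dgrad = jnp.einsum("bsf,hf->bsh", g, w)   # contract over hidden_out ("GEMM NT")
wgrad = jnp.einsum("bsh,bsf->hf", x, g)   # contract over batch/seq ("GEMM TN")

assert dgrad.shape == x.shape and wgrad.shape == w.shape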
...@@ -15,7 +15,6 @@ from jax import lax ...@@ -15,7 +15,6 @@ from jax import lax
from jax import random as jax_random from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name from jax.ad_checkpoint import checkpoint_name
from transformer_engine.common import recipe
from ..dense import dense from ..dense import dense
...@@ -35,10 +34,9 @@ from ..cpp_extensions import ( ...@@ -35,10 +34,9 @@ from ..cpp_extensions import (
from ..quantize import ( from ..quantize import (
QuantizerFactory, QuantizerFactory,
get_quantize_config, get_quantize_config,
QuantizeMeta,
QuantizeMetaSet, QuantizeMetaSet,
ScalingMode,
TensorSource, TensorSource,
get_quantize_config_with_recipe,
) )
PRNGKey = Any PRNGKey = Any
...@@ -353,40 +351,32 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method ...@@ -353,40 +351,32 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
Generate a set of FP8 meta for a GEMM. Generate a set of FP8 meta for a GEMM.
""" """
def generate_quantize_meta(quantizer_name: str):
collection_name = ( collection_name = (
variable_collection variable_collection
if variable_collection is not None if variable_collection is not None
else get_quantize_config().COLLECTION_NAME else get_quantize_config().COLLECTION_NAME
) )
scale = self.variable(
collection_name, if fp8_recipe is None:
f"{quantizer_name}{postfix}_scale", quantize_config = get_quantize_config()
jnp.ones,
(1,),
jnp.float32,
).value
amax_history = self.variable(
collection_name,
f"{quantizer_name}{postfix}_amax_history",
jnp.zeros,
(get_quantize_config().AMAX_HISTORY_LEN,),
jnp.float32,
).value
return QuantizeMeta(scale=scale, amax_history=amax_history)
if get_quantize_config().get_scaling_mode(
TensorSource.X
) == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(fp8_recipe, recipe.DelayedScaling):
x_meta = generate_quantize_meta("x")
kernel_meta = generate_quantize_meta("kernel")
grad_meta = generate_quantize_meta("grad")
quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)
kwargs = {"quantize_meta_set": quantize_meta_set}
else: else:
kwargs = {} quantize_config = get_quantize_config_with_recipe(fp8_recipe)
quantizer_set = QuantizerFactory.create_set(fp8_recipe=fp8_recipe, **kwargs) x_meta = quantize_config.get_quantize_flax_meta(
self, collection_name, postfix, TensorSource.X, "x"
)
kernel_meta = quantize_config.get_quantize_flax_meta(
self, collection_name, postfix, TensorSource.KERNEL, "kernel"
)
grad_meta = quantize_config.get_quantize_flax_meta(
self, collection_name, postfix, TensorSource.DGRAD, "grad"
)
quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)
quantizer_set = QuantizerFactory.create_set(
fp8_recipe=fp8_recipe, quantize_meta_set=quantize_meta_set
)
return quantizer_set return quantizer_set
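generate_quantizer_set now asks the recipe-specific quantize config to create the Flax metadata via get_quantize_flax_meta, instead of building scale/amax_history variables inline for the delayed-scaling case only. A hedged sketch of what such metadata amounts to for delayed scaling (collection and variable names are assumptions modelled on the removed code):

import jax
import jax.numpy as jnp
import flax.linen as nn

class DelayedScalingMetaSketch(nn.Module):
    amax_history_len: int = 1024

    @nn.compact
    def __call__(self):
        # One scale and one amax history per quantizer, stored in a dedicated collection.
        scale = self.variable("fp8_metas", "x_scale", jnp.ones, (1,), jnp.float32)
        amax_history = self.variable(
            "fp8_metas", "x_amax_history", jnp.zeros, (self.amax_history_len,), jnp.float32
        )
        return scale.value, amax_history.value

variables = DelayedScalingMetaSketch().init(jax.random.PRNGKey(0))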
...@@ -432,6 +422,8 @@ class DenseGeneral(TransformerEngineBase): ...@@ -432,6 +422,8 @@ class DenseGeneral(TransformerEngineBase):
----------------------- -----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32 dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters. The data type used to allocate the initial parameters.
transpose_batch_sequence: bool, default = False
Indicate whether to transpose the batch and sequence dimensions of the input tensor.
""" """
features: Union[Iterable[int], int] features: Union[Iterable[int], int]
...@@ -446,6 +438,7 @@ class DenseGeneral(TransformerEngineBase): ...@@ -446,6 +438,7 @@ class DenseGeneral(TransformerEngineBase):
axis: Union[Iterable[int], int] = -1 axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32 dtype: DType = jnp.float32
input_axes: Tuple[str, ...] = () input_axes: Tuple[str, ...] = ()
transpose_batch_sequence: bool = False
def __post_init__(self): def __post_init__(self):
if self.kernel_init is None: if self.kernel_init is None:
...@@ -512,6 +505,7 @@ class DenseGeneral(TransformerEngineBase): ...@@ -512,6 +505,7 @@ class DenseGeneral(TransformerEngineBase):
input_axes=self.input_axes, input_axes=self.input_axes,
kernel_axes=self.kernel_axes, kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set, quantizer_set=quantizer_set,
transpose_batch_sequence=self.transpose_batch_sequence,
) )
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
...@@ -632,6 +626,8 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -632,6 +626,8 @@ class LayerNormDenseGeneral(TransformerEngineBase):
depth_scaling: float, default = None depth_scaling: float, default = None
The factor to scale the output from `DenseGeneral`. It should be a float The factor to scale the output from `DenseGeneral`. It should be a float
value or None. When None is set, then no scaling is applied. value or None. When None is set, then no scaling is applied.
transpose_batch_sequence: bool, default = False
Indicate whether to transpose the batch and sequence dimensions of the input tensor.
""" """
features: Union[Iterable[int], int] features: Union[Iterable[int], int]
...@@ -657,6 +653,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -657,6 +653,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
layernorm_input_axes: Tuple[str, ...] = None layernorm_input_axes: Tuple[str, ...] = None
dot_input_axes: Tuple[str, ...] = None dot_input_axes: Tuple[str, ...] = None
depth_scaling: float = None depth_scaling: float = None
transpose_batch_sequence: bool = False
def __post_init__(self): def __post_init__(self):
if self.kernel_init is None: if self.kernel_init is None:
...@@ -768,6 +765,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -768,6 +765,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
dot_input_axes=self.dot_input_axes, dot_input_axes=self.dot_input_axes,
kernel_axes=self.kernel_axes, kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set, quantizer_set=quantizer_set,
transpose_batch_sequence=self.transpose_batch_sequence,
) )
else: else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes) y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
...@@ -775,6 +773,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -775,6 +773,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
y, y,
kernel, kernel,
contracting_dims=(axis, contract_ind), contracting_dims=(axis, contract_ind),
transpose_batch_sequence=self.transpose_batch_sequence,
input_axes=self.dot_input_axes, input_axes=self.dot_input_axes,
kernel_axes=self.kernel_axes, kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set, quantizer_set=quantizer_set,
...@@ -898,6 +897,10 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -898,6 +897,10 @@ class LayerNormMLP(TransformerEngineBase):
activations: Sequence[Union[str, Callable]], default = ('relu',) activations: Sequence[Union[str, Callable]], default = ('relu',)
The sequence of activation functions to apply after the first dense layer transformation. The sequence of activation functions to apply after the first dense layer transformation.
Each activation has its own transformation layer. Each activation has its own transformation layer.
activation_params: dict, default = None
The parameters needed (if any) by the activation functions specified in :attr:`activations`.
At the moment only ('clamped_silu', 'clamped_linear'), i.e. the clamped SwiGLU used in GPT-OSS,
requires additional parameters.
intermediate_dropout_rng_name: str, default = 'dropout' intermediate_dropout_rng_name: str, default = 'dropout'
The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks. The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks.
intermediate_dropout_rate: float, default = 0.1 intermediate_dropout_rate: float, default = 0.1
...@@ -936,6 +939,8 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -936,6 +939,8 @@ class LayerNormMLP(TransformerEngineBase):
----------------------- -----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32 dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters. The data type used to allocate the initial parameters.
transpose_batch_sequence: bool, default = False
Indicate whether to transpose the batch and sequence dimensions of the input tensor.
""" """
intermediate_dim: int = 2048 intermediate_dim: int = 2048
...@@ -956,6 +961,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -956,6 +961,7 @@ class LayerNormMLP(TransformerEngineBase):
bias_axes_2: Tuple[str, ...] = ("embed",) bias_axes_2: Tuple[str, ...] = ("embed",)
return_layernorm_output: bool = True return_layernorm_output: bool = True
activations: Sequence[Union[str, Callable]] = ("relu",) activations: Sequence[Union[str, Callable]] = ("relu",)
activation_params: dict = None
intermediate_dropout_rng_name: str = "dropout" intermediate_dropout_rng_name: str = "dropout"
intermediate_dropout_rate: float = 0.1 intermediate_dropout_rate: float = 0.1
intermediate_hidden_dropout_dims: Sequence[int] = () intermediate_hidden_dropout_dims: Sequence[int] = ()
...@@ -969,6 +975,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -969,6 +975,7 @@ class LayerNormMLP(TransformerEngineBase):
dot_2_input_axes: Tuple[str, ...] = None dot_2_input_axes: Tuple[str, ...] = None
ffn1_ckpt_name: str = "ffn1" ffn1_ckpt_name: str = "ffn1"
ffn2_ckpt_name: str = "ffn2" ffn2_ckpt_name: str = "ffn2"
transpose_batch_sequence: bool = False
def __post_init__(self): def __post_init__(self):
if self.kernel_init is None: if self.kernel_init is None:
...@@ -1023,6 +1030,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1023,6 +1030,7 @@ class LayerNormMLP(TransformerEngineBase):
("relu", "linear"), ("relu", "linear"),
("quick_gelu", "linear"), ("quick_gelu", "linear"),
("squared_relu", "linear"), ("squared_relu", "linear"),
("clamped_silu", "clamped_linear"),
] ]
act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)] act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
normalized_acts = [] normalized_acts = []
...@@ -1031,7 +1039,9 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1031,7 +1039,9 @@ class LayerNormMLP(TransformerEngineBase):
return False return False
normalized_acts.append(act.lower()) normalized_acts.append(act.lower())
normalized_acts = tuple( normalized_acts = tuple(
reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts reversed(normalized_acts)
if (normalized_acts[0] == "linear" or normalized_acts[0] == "clamped_linear")
else normalized_acts
) )
is_act_implemented = normalized_acts in (gated_act_pool + act_pool) is_act_implemented = normalized_acts in (gated_act_pool + act_pool)
...@@ -1150,7 +1160,9 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1150,7 +1160,9 @@ class LayerNormMLP(TransformerEngineBase):
ffn1_ckpt_name=self.ffn1_ckpt_name, ffn1_ckpt_name=self.ffn1_ckpt_name,
ffn2_ckpt_name=self.ffn2_ckpt_name, ffn2_ckpt_name=self.ffn2_ckpt_name,
activation_type=normalized_acts, activation_type=normalized_acts,
activation_params=self.activation_params,
quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set), quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
transpose_batch_sequence=self.transpose_batch_sequence,
) )
out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple) out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)
...@@ -1169,6 +1181,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1169,6 +1181,7 @@ class LayerNormMLP(TransformerEngineBase):
dot_input_axes=self.dot_1_input_axes, dot_input_axes=self.dot_1_input_axes,
kernel_axes=self.kernel_axes_1, kernel_axes=self.kernel_axes_1,
quantizer_set=ffn1_quantizer_set, quantizer_set=ffn1_quantizer_set,
transpose_batch_sequence=self.transpose_batch_sequence,
) )
else: else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes) y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
...@@ -1179,6 +1192,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1179,6 +1192,7 @@ class LayerNormMLP(TransformerEngineBase):
input_axes=self.dot_1_input_axes, input_axes=self.dot_1_input_axes,
kernel_axes=self.kernel_axes_1, kernel_axes=self.kernel_axes_1,
quantizer_set=ffn1_quantizer_set, quantizer_set=ffn1_quantizer_set,
transpose_batch_sequence=self.transpose_batch_sequence,
) )
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
...@@ -1251,6 +1265,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1251,6 +1265,7 @@ class LayerNormMLP(TransformerEngineBase):
input_axes=self.dot_2_input_axes, input_axes=self.dot_2_input_axes,
kernel_axes=self.kernel_axes_2, kernel_axes=self.kernel_axes_2,
quantizer_set=ffn2_quantizer_set, quantizer_set=ffn2_quantizer_set,
transpose_batch_sequence=self.transpose_batch_sequence,
) )
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
...@@ -1287,4 +1302,4 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1287,4 +1302,4 @@ class LayerNormMLP(TransformerEngineBase):
out = checkpoint_name(out, self.ffn2_ckpt_name) out = checkpoint_name(out, self.ffn2_ckpt_name)
assert out.dtype == input_dtype assert out.dtype == input_dtype
return out, ln_output # Output, layner_norm_output return out, ln_output # Output, layer_norm_output
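The new ('clamped_silu', 'clamped_linear') pair is the clamped SwiGLU gate used in GPT-OSS, and activation_params carries its extra constants. A rough sketch of such an activation, assuming limit/alpha parameters; the exact formula implemented by Transformer Engine may differ:

import jax
import jax.numpy as jnp

def clamped_swiglu_sketch(gate, linear, limit=7.0, alpha=1.702):
    # Clamp both halves, then gate with a sigmoid of slope alpha (one published
    # variant also shifts the linear half by +1).
    gate = jnp.minimum(gate, limit)
    linear = jnp.clip(linear, -limit, limit)
    return gate * jax.nn.sigmoid(alpha * gate) * (linear + 1.0)

print(clamped_swiglu_sketch(jnp.array([1.0, 10.0]), jnp.array([0.5, -9.0])))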
...@@ -53,6 +53,7 @@ def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[ ...@@ -53,6 +53,7 @@ def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[
return drop_path_shape return drop_path_shape
# TODO(Phuong): move this function to sharding.py
def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
""" """
Extend the given Flax logical axis rules with the predefined TransformerLayer's Extend the given Flax logical axis rules with the predefined TransformerLayer's
...@@ -65,7 +66,7 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: ...@@ -65,7 +66,7 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
for 1D-sharding tensor parallelism. for 1D-sharding tensor parallelism.
.. warning:: .. warning::
Please make sure ShardingResource is set via fp8_autocast before calling this function. Please make sure ShardingResource is set via autocast before calling this function.
.. note:: .. note::
This function is only needed when using TransformerLayer. For other modules, such as This function is only needed when using TransformerLayer. For other modules, such as
...@@ -1206,6 +1207,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1206,6 +1207,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
layernorm_input_axes=inputs_logical_axes_maybe_sp, layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp, dot_input_axes=inputs_logical_axes_no_sp,
transpose_batch_sequence=self.transpose_batch_sequence,
name="qkv", name="qkv",
dtype=self.dtype, dtype=self.dtype,
)(inputs_q) )(inputs_q)
...@@ -1233,6 +1235,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1233,6 +1235,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
kernel_init=query_init, kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp, layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp, dot_input_axes=inputs_logical_axes_no_sp,
transpose_batch_sequence=self.transpose_batch_sequence,
name="query", name="query",
)(inputs_q) )(inputs_q)
...@@ -1251,6 +1254,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1251,6 +1254,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
enable_low_rank_adaptation=lora_scope.qkv_proj, enable_low_rank_adaptation=lora_scope.qkv_proj,
low_rank_adaptation_dim=self.low_rank_adaptation_dim, low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
transpose_batch_sequence=self.transpose_batch_sequence,
name="kv", name="kv",
dtype=self.dtype, dtype=self.dtype,
)(inputs_kv) )(inputs_kv)
...@@ -1291,6 +1295,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1291,6 +1295,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
kernel_init=query_init, kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp, layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp, dot_input_axes=inputs_logical_axes_no_sp,
transpose_batch_sequence=self.transpose_batch_sequence,
name="query", name="query",
)(inputs_q) )(inputs_q)
...@@ -1631,6 +1636,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1631,6 +1636,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
mlp_activations: Sequence[str], default = ('relu', ) mlp_activations: Sequence[str], default = ('relu', )
The sequence of activation functions to apply after the first linear transformation. The sequence of activation functions to apply after the first linear transformation.
Each activation has its own transformation layer. Each activation has its own transformation layer.
mlp_activation_params: dict, default = None
Parameters (if any) for the activation functions in :attr:`mlp_activations`. This is only used when
('clamped_silu', 'clamped_linear') is specified; at the moment clamped SwiGLU is the only activation
that requires parameters.
use_bias: bool, default = False use_bias: bool, default = False
Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2. Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
If set to False, the layer will not learn additive biases. If set to False, the layer will not learn additive biases.
...@@ -1751,6 +1759,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1751,6 +1759,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
mha_kernel_init: Initializer = None mha_kernel_init: Initializer = None
mlp_kernel_init: Initializer = None mlp_kernel_init: Initializer = None
mlp_activations: Sequence[str] = ("relu",) mlp_activations: Sequence[str] = ("relu",)
mlp_activation_params: dict = None
use_bias: bool = False use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros bias_init: Initializer = nn.initializers.zeros
apply_residual_connection_post_layernorm: bool = False apply_residual_connection_post_layernorm: bool = False
...@@ -2045,6 +2054,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -2045,6 +2054,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
return_layernorm_output=self.apply_residual_connection_post_layernorm, return_layernorm_output=self.apply_residual_connection_post_layernorm,
intermediate_dim=self.mlp_hidden_size, intermediate_dim=self.mlp_hidden_size,
activations=self.mlp_activations, activations=self.mlp_activations,
activation_params=self.mlp_activation_params,
intermediate_dropout_rng_name=self.dropout_rng_name, intermediate_dropout_rng_name=self.dropout_rng_name,
intermediate_dropout_rate=self.intermediate_dropout, intermediate_dropout_rate=self.intermediate_dropout,
intermediate_hidden_dropout_dims=self.intermediate_dropout_dims, intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
...@@ -2064,6 +2074,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -2064,6 +2074,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES), layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES),
dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES), dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES),
dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES), dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES),
transpose_batch_sequence=self.transpose_batch_sequence,
name="mlp", name="mlp",
)(mlp_input, deterministic=deterministic) )(mlp_input, deterministic=deterministic)
......
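A minimal usage sketch of the new knob at the TransformerLayer level; the keys inside mlp_activation_params (limit, alpha) are assumptions, so pass whatever the chosen activation actually expects:

from transformer_engine.jax.flax import TransformerLayer

layer = TransformerLayer(
    hidden_size=1024,
    mlp_hidden_size=4096,
    num_attention_heads=16,
    mlp_activations=("clamped_silu", "clamped_linear"),
    mlp_activation_params={"limit": 7.0, "alpha": 1.702},  # assumed parameter names
    transpose_batch_sequence=False,
)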
...@@ -16,6 +16,7 @@ import jax ...@@ -16,6 +16,7 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
from . import cpp_extensions as tex from . import cpp_extensions as tex
from .cpp_extensions.amax import AmaxScope
from .quantize import ( from .quantize import (
QuantizerSet, QuantizerSet,
...@@ -35,6 +36,7 @@ def layernorm_dense( ...@@ -35,6 +36,7 @@ def layernorm_dense(
norm_type: str = "layernorm", norm_type: str = "layernorm",
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
epsilon: float = 1e-6, epsilon: float = 1e-6,
transpose_batch_sequence: bool = False,
layernorm_input_axes: Tuple[str, ...] = None, layernorm_input_axes: Tuple[str, ...] = None,
dot_input_axes: Tuple[str, ...] = None, dot_input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None,
...@@ -55,6 +57,7 @@ def layernorm_dense( ...@@ -55,6 +57,7 @@ def layernorm_dense(
norm_type: Type of normalization ("layernorm" or "rmsnorm") norm_type: Type of normalization ("layernorm" or "rmsnorm")
zero_centered_gamma: Whether to use zero-centered gamma for normalization zero_centered_gamma: Whether to use zero-centered gamma for normalization
epsilon: Small constant for numerical stability in normalization epsilon: Small constant for numerical stability in normalization
transpose_batch_sequence: Whether to transpose the batch and sequence dimensions
layernorm_input_axes: Logical axes for sharding the layernorm input layernorm_input_axes: Logical axes for sharding the layernorm input
dot_input_axes: Logical axes for sharding the matrix multiplication input dot_input_axes: Logical axes for sharding the matrix multiplication input
kernel_axes: Logical axes for sharding the weight matrix kernel_axes: Logical axes for sharding the weight matrix
...@@ -83,6 +86,7 @@ def layernorm_dense( ...@@ -83,6 +86,7 @@ def layernorm_dense(
norm_type, norm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
transpose_batch_sequence,
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes, kernel_axes,
...@@ -100,6 +104,7 @@ def layernorm_dense( ...@@ -100,6 +104,7 @@ def layernorm_dense(
8, 8,
9, 9,
10, 10,
11,
), ),
) )
def _layernorm_dense( def _layernorm_dense(
...@@ -111,6 +116,7 @@ def _layernorm_dense( ...@@ -111,6 +116,7 @@ def _layernorm_dense(
norm_type: str, norm_type: str,
zero_centered_gamma: bool, zero_centered_gamma: bool,
epsilon: float, epsilon: float,
transpose_batch_sequence: bool,
layernorm_input_axes: Tuple[str, ...], layernorm_input_axes: Tuple[str, ...],
dot_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...],
kernel_axes: Tuple[str, ...], kernel_axes: Tuple[str, ...],
...@@ -131,6 +137,7 @@ def _layernorm_dense( ...@@ -131,6 +137,7 @@ def _layernorm_dense(
norm_type: Type of normalization norm_type: Type of normalization
zero_centered_gamma: Whether to use zero-centered gamma zero_centered_gamma: Whether to use zero-centered gamma
epsilon: Small constant for numerical stability epsilon: Small constant for numerical stability
transpose_batch_sequence: Whether to transpose the batch and sequence dimensions
layernorm_input_axes: Logical axes for layernorm sharding layernorm_input_axes: Logical axes for layernorm sharding
dot_input_axes: Logical axes for matrix multiplication sharding dot_input_axes: Logical axes for matrix multiplication sharding
quantizer_set: Set of quantizers quantizer_set: Set of quantizers
...@@ -147,6 +154,7 @@ def _layernorm_dense( ...@@ -147,6 +154,7 @@ def _layernorm_dense(
norm_type, norm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
transpose_batch_sequence,
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes, kernel_axes,
...@@ -164,6 +172,7 @@ def _layernorm_dense_fwd_rule( ...@@ -164,6 +172,7 @@ def _layernorm_dense_fwd_rule(
norm_type, norm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
transpose_batch_sequence,
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes, kernel_axes,
...@@ -194,6 +203,8 @@ def _layernorm_dense_fwd_rule( ...@@ -194,6 +203,8 @@ def _layernorm_dense_fwd_rule(
epsilon, epsilon,
norm_type, norm_type,
quantizer=quantizer_set.x, quantizer=quantizer_set.x,
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
) )
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)
...@@ -203,6 +214,8 @@ def _layernorm_dense_fwd_rule( ...@@ -203,6 +214,8 @@ def _layernorm_dense_fwd_rule(
kernel, kernel,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
quantizer=quantizer_set.kernel, quantizer=quantizer_set.kernel,
amax_scope=AmaxScope.FSDP,
transpose_batch_sequence=transpose_batch_sequence,
) )
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
...@@ -213,6 +226,7 @@ def _layernorm_dense_fwd_rule( ...@@ -213,6 +226,7 @@ def _layernorm_dense_fwd_rule(
casted_ln_out.get_tensor(TensorUsage.LHS), casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel.get_tensor(TensorUsage.RHS), casted_kernel.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims), contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=transpose_batch_sequence,
bias=bias if not tex.gemm_uses_jax_dot() else None, bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
) )
...@@ -245,6 +259,7 @@ def _layernorm_dense_bwd_rule( ...@@ -245,6 +259,7 @@ def _layernorm_dense_bwd_rule(
norm_type, norm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
transpose_batch_sequence,
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes, kernel_axes,
...@@ -285,6 +300,8 @@ def _layernorm_dense_bwd_rule( ...@@ -285,6 +300,8 @@ def _layernorm_dense_bwd_rule(
is_dbias=use_bias, is_dbias=use_bias,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
quantizer=quantizer_set.dgrad, quantizer=quantizer_set.dgrad,
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
) )
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
...@@ -301,6 +318,7 @@ def _layernorm_dense_bwd_rule( ...@@ -301,6 +318,7 @@ def _layernorm_dense_bwd_rule(
casted_grad.get_tensor(TensorUsage.LHS), casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel, casted_kernel,
contracting_dims=(g_constracting_dim, k_constracting_dim), contracting_dims=(g_constracting_dim, k_constracting_dim),
transpose_batch_sequence=transpose_batch_sequence,
) )
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
...@@ -314,6 +332,7 @@ def _layernorm_dense_bwd_rule( ...@@ -314,6 +332,7 @@ def _layernorm_dense_bwd_rule(
casted_ln_out, casted_ln_out,
casted_grad.get_tensor(TensorUsage.RHS), casted_grad.get_tensor(TensorUsage.RHS),
contracting_dims=(x_constracting_dim, g_constracting_dim), contracting_dims=(x_constracting_dim, g_constracting_dim),
transpose_batch_sequence=transpose_batch_sequence,
) )
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
......
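The quantize calls in this file now pass an AmaxScope (TPSP for activations and gradients, FSDP for kernels), which presumably selects the mesh axes over which the per-tensor amax is reduced before scales are computed, so every shard of the same logical tensor quantizes consistently. An illustrative amax reduction over a named device axis (not the TE implementation):

import jax
import jax.numpy as jnp

def global_amax(x_shard):
    # Local amax, then a max-reduction across the named axis so all shards agree.
    local = jnp.max(jnp.abs(x_shard))
    return jax.lax.pmax(local, axis_name="tp")

sharded_amax = jax.pmap(global_amax, axis_name="tp")
x_shards = jnp.ones((jax.local_device_count(), 4))  # one slice per device
print(sharded_amax(x_shards))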
...@@ -21,6 +21,7 @@ import jax.numpy as jnp ...@@ -21,6 +21,7 @@ import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name from jax.ad_checkpoint import checkpoint_name
from . import cpp_extensions as tex from . import cpp_extensions as tex
from .cpp_extensions.amax import AmaxScope
from .layernorm import canonicalize_norm_type from .layernorm import canonicalize_norm_type
from .quantize import ( from .quantize import (
with_sharding_constraint_by_logical_axes, with_sharding_constraint_by_logical_axes,
...@@ -40,6 +41,7 @@ def layernorm_mlp( ...@@ -40,6 +41,7 @@ def layernorm_mlp(
norm_type: str, norm_type: str,
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
epsilon: float = 1e-6, epsilon: float = 1e-6,
transpose_batch_sequence: bool = False,
norm_input_axes: Tuple[str, ...] = None, norm_input_axes: Tuple[str, ...] = None,
dot_1_input_axes: Tuple[str, ...] = None, dot_1_input_axes: Tuple[str, ...] = None,
dot_2_input_axes: Tuple[str, ...] = None, dot_2_input_axes: Tuple[str, ...] = None,
...@@ -48,6 +50,11 @@ def layernorm_mlp( ...@@ -48,6 +50,11 @@ def layernorm_mlp(
ffn1_ckpt_name: str = "ffn1", ffn1_ckpt_name: str = "ffn1",
ffn2_ckpt_name: str = "ffn2", ffn2_ckpt_name: str = "ffn2",
activation_type: Sequence[Union[str, Callable]] = ("gelu",), activation_type: Sequence[Union[str, Callable]] = ("gelu",),
activation_params: dict = None,
collective_op_sets: Tuple[tex.CollectiveOpSet] = (
tex.noop_collective_op_set,
tex.noop_collective_op_set,
),
quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set), quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
) -> jnp.ndarray: ) -> jnp.ndarray:
"""Apply layer normalization followed by MLP block. """Apply layer normalization followed by MLP block.
...@@ -71,6 +78,7 @@ def layernorm_mlp( ...@@ -71,6 +78,7 @@ def layernorm_mlp(
norm_type: Type of normalization ("layernorm" or "rmsnorm") norm_type: Type of normalization ("layernorm" or "rmsnorm")
zero_centered_gamma: Whether to use zero-centered gamma for normalization zero_centered_gamma: Whether to use zero-centered gamma for normalization
epsilon: Small constant for numerical stability in normalization epsilon: Small constant for numerical stability in normalization
transpose_batch_sequence: Whether to transpose the batch and sequence dimensions
norm_input_axes: Logical axes for sharding the layernorm input norm_input_axes: Logical axes for sharding the layernorm input
dot_1_input_axes: Logical axes for sharding the first matrix multiplication dot_1_input_axes: Logical axes for sharding the first matrix multiplication
dot_2_input_axes: Logical axes for sharding the second matrix multiplication dot_2_input_axes: Logical axes for sharding the second matrix multiplication
...@@ -79,6 +87,7 @@ def layernorm_mlp( ...@@ -79,6 +87,7 @@ def layernorm_mlp(
ffn1_ckpt_name: Name for checkpointing the first feed-forward network ffn1_ckpt_name: Name for checkpointing the first feed-forward network
ffn2_ckpt_name: Name for checkpointing the second feed-forward network ffn2_ckpt_name: Name for checkpointing the second feed-forward network
activation_type: Activation function(s) to apply after the first dense layer transformation activation_type: Activation function(s) to apply after the first dense layer transformation
collective_op_sets: Tuple of two CollectiveOpSet objects (forward/backward collective GEMM configuration) for the two dense layer transformations
quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations
Returns: Returns:
...@@ -121,6 +130,7 @@ def layernorm_mlp( ...@@ -121,6 +130,7 @@ def layernorm_mlp(
norm_type, norm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
transpose_batch_sequence,
norm_input_axes, norm_input_axes,
dot_1_input_axes, dot_1_input_axes,
dot_2_input_axes, dot_2_input_axes,
...@@ -129,12 +139,14 @@ def layernorm_mlp( ...@@ -129,12 +139,14 @@ def layernorm_mlp(
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
activation_params,
collective_op_sets,
quantizer_sets, quantizer_sets,
) )
return output return output
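The forward/backward rules below restrict the two CollectiveOpSets so that dot_1 may only all-gather in the forward pass and dot_2 may only reduce-scatter, with the roles swapped in the backward pass. That pairing is the standard sequence-parallel pattern: an input gathered along the sequence axis before the first GEMM yields partial sums after the second GEMM that are reduce-scattered back onto the same axis. A schematic shard_map sketch (shapes must divide evenly across the 'tp' axis; this is an illustration, not the fused collective GEMM):

import jax
import jax.numpy as jnp
from functools import partial
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), ("tp",))

@partial(shard_map, mesh=mesh,
         in_specs=(P("tp", None), P(None, "tp"), P("tp", None)),
         out_specs=P("tp", None))
def mlp_fwd(x_seq_shard, w1_col_shard, w2_row_shard):
    x = jax.lax.all_gather(x_seq_shard, "tp", axis=0, tiled=True)   # AG on seq (dot_1 forward)
    h = jax.nn.relu(x @ w1_col_shard)                               # hidden columns stay sharded
    partial_out = h @ w2_row_shard                                  # partial sums over hidden
    return jax.lax.psum_scatter(partial_out, "tp",                  # RS onto seq (dot_2 forward)
                                scatter_dimension=0, tiled=True)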
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)) @partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20))
def _layernorm_mlp( def _layernorm_mlp(
x: jnp.ndarray, x: jnp.ndarray,
gamma: jnp.ndarray, gamma: jnp.ndarray,
...@@ -146,6 +158,7 @@ def _layernorm_mlp( ...@@ -146,6 +158,7 @@ def _layernorm_mlp(
norm_type: str, norm_type: str,
zero_centered_gamma: bool, zero_centered_gamma: bool,
epsilon: float, epsilon: float,
transpose_batch_sequence: bool,
norm_input_axes: Tuple[str, ...], norm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...], dot_1_input_axes: Tuple[str, ...],
dot_2_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
...@@ -154,6 +167,8 @@ def _layernorm_mlp( ...@@ -154,6 +167,8 @@ def _layernorm_mlp(
ffn1_ckpt_name: str, ffn1_ckpt_name: str,
ffn2_ckpt_name: str, ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]], activation_type: Sequence[Union[str, Callable]],
activation_params: dict,
collective_op_sets: Tuple[tex.CollectiveOpSet],
quantizer_sets, quantizer_sets,
): ):
"""Internal implementation of layernorm_mlp with custom VJP. """Internal implementation of layernorm_mlp with custom VJP.
...@@ -173,12 +188,16 @@ def _layernorm_mlp( ...@@ -173,12 +188,16 @@ def _layernorm_mlp(
norm_type: Type of normalization norm_type: Type of normalization
zero_centered_gamma: Whether to use zero-centered gamma zero_centered_gamma: Whether to use zero-centered gamma
epsilon: Small constant for numerical stability epsilon: Small constant for numerical stability
transpose_batch_sequence: Whether to transpose the batch and sequence dimensions
norm_input_axes: Logical axes for layernorm sharding norm_input_axes: Logical axes for layernorm sharding
dot_1_input_axes: Logical axes for first matrix multiplication sharding dot_1_input_axes: Logical axes for first matrix multiplication sharding
dot_2_input_axes: Logical axes for second matrix multiplication sharding dot_2_input_axes: Logical axes for second matrix multiplication sharding
kernel_1_axes: Logical axes for first weight matrix sharding
kernel_2_axes: Logical axes for second weight matrix sharding
ffn1_ckpt_name: Name for first feed-forward network checkpointing ffn1_ckpt_name: Name for first feed-forward network checkpointing
ffn2_ckpt_name: Name for second feed-forward network checkpointing ffn2_ckpt_name: Name for second feed-forward network checkpointing
activation_type: Activation function(s) activation_type: Activation function(s)
collective_op_sets: Tuple of two CollectiveOpSet objects (forward/backward collective GEMM configuration) for the two dense layer transformations
quantizer_sets: Tuple of quantizer sets quantizer_sets: Tuple of quantizer sets
Returns: Returns:
...@@ -195,6 +214,7 @@ def _layernorm_mlp( ...@@ -195,6 +214,7 @@ def _layernorm_mlp(
norm_type, norm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
transpose_batch_sequence,
norm_input_axes, norm_input_axes,
dot_1_input_axes, dot_1_input_axes,
dot_2_input_axes, dot_2_input_axes,
...@@ -203,6 +223,8 @@ def _layernorm_mlp( ...@@ -203,6 +223,8 @@ def _layernorm_mlp(
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
activation_params,
collective_op_sets,
quantizer_sets, quantizer_sets,
) )
return output return output
...@@ -219,6 +241,7 @@ def _layernorm_mlp_fwd_rule( ...@@ -219,6 +241,7 @@ def _layernorm_mlp_fwd_rule(
norm_type, norm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
transpose_batch_sequence,
norm_input_axes, norm_input_axes,
dot_1_input_axes, dot_1_input_axes,
dot_2_input_axes, dot_2_input_axes,
...@@ -227,6 +250,8 @@ def _layernorm_mlp_fwd_rule( ...@@ -227,6 +250,8 @@ def _layernorm_mlp_fwd_rule(
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
activation_params,
collective_op_sets,
quantizer_sets, quantizer_sets,
): ):
"""Forward pass rule for layernorm_mlp. """Forward pass rule for layernorm_mlp.
...@@ -246,6 +271,10 @@ def _layernorm_mlp_fwd_rule( ...@@ -246,6 +271,10 @@ def _layernorm_mlp_fwd_rule(
del kernel_1_axes, kernel_2_axes del kernel_1_axes, kernel_2_axes
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
collective_op_set_1, collective_op_set_2 = collective_op_sets
assert not collective_op_set_1.forward.is_reduce_scatter
assert not collective_op_set_2.forward.is_all_gather
# x should be in shape of (batch..., hidden) # x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (hidden_in, activation_len, intermediate) # Kernel_1 should be in shape of (hidden_in, activation_len, intermediate)
...@@ -272,6 +301,8 @@ def _layernorm_mlp_fwd_rule( ...@@ -272,6 +301,8 @@ def _layernorm_mlp_fwd_rule(
epsilon, epsilon,
norm_type, norm_type,
quantizer=ffn1_quantizer_set.x, quantizer=ffn1_quantizer_set.x,
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
) )
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)
...@@ -279,6 +310,8 @@ def _layernorm_mlp_fwd_rule( ...@@ -279,6 +310,8 @@ def _layernorm_mlp_fwd_rule(
kernel_1, kernel_1,
flatten_axis=-2, flatten_axis=-2,
quantizer=ffn1_quantizer_set.kernel, quantizer=ffn1_quantizer_set.kernel,
amax_scope=AmaxScope.FSDP,
transpose_batch_sequence=transpose_batch_sequence,
) )
# NN GEMM # NN GEMM
...@@ -287,8 +320,10 @@ def _layernorm_mlp_fwd_rule( ...@@ -287,8 +320,10 @@ def _layernorm_mlp_fwd_rule(
casted_ln_out.get_tensor(TensorUsage.LHS), casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel_1.get_tensor(TensorUsage.RHS), casted_kernel_1.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims), contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=transpose_batch_sequence,
bias=bias_1 if not tex.gemm_uses_jax_dot() else None, bias=bias_1 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
collective_op=collective_op_set_1.forward,
) )
if use_bias_1 and tex.gemm_uses_jax_dot(): if use_bias_1 and tex.gemm_uses_jax_dot():
...@@ -310,6 +345,13 @@ def _layernorm_mlp_fwd_rule( ...@@ -310,6 +345,13 @@ def _layernorm_mlp_fwd_rule(
dot_1_output, dot_1_output,
activation_type, activation_type,
quantizer=ffn2_quantizer_set.x, quantizer=ffn2_quantizer_set.x,
act_params=(
tex.activation.ActivationParams.create(activation_type, **activation_params)
if activation_params
else None
),
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
) )
casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)
...@@ -317,6 +359,8 @@ def _layernorm_mlp_fwd_rule( ...@@ -317,6 +359,8 @@ def _layernorm_mlp_fwd_rule(
casted_kernel_2 = tex.quantize( casted_kernel_2 = tex.quantize(
kernel_2, kernel_2,
quantizer=ffn2_quantizer_set.kernel, quantizer=ffn2_quantizer_set.kernel,
amax_scope=AmaxScope.FSDP,
transpose_batch_sequence=transpose_batch_sequence,
) )
# NN GEMM # NN GEMM
...@@ -325,8 +369,10 @@ def _layernorm_mlp_fwd_rule( ...@@ -325,8 +369,10 @@ def _layernorm_mlp_fwd_rule(
casted_act_out.get_tensor(TensorUsage.LHS), casted_act_out.get_tensor(TensorUsage.LHS),
casted_kernel_2.get_tensor(TensorUsage.RHS), casted_kernel_2.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims), contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=transpose_batch_sequence,
bias=bias_2 if not tex.gemm_uses_jax_dot() else None, bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
collective_op=collective_op_set_2.forward,
) )
if use_bias_2 and tex.gemm_uses_jax_dot(): if use_bias_2 and tex.gemm_uses_jax_dot():
...@@ -334,6 +380,8 @@ def _layernorm_mlp_fwd_rule( ...@@ -334,6 +380,8 @@ def _layernorm_mlp_fwd_rule(
bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
dot_2_output += jnp.reshape(bias_2, bias_2_new_shape) dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)
# sharding of outputs should be the same as dot_1's input
dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_1_input_axes)
dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name) dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)
ctx = ( ctx = (
...@@ -363,6 +411,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -363,6 +411,7 @@ def _layernorm_mlp_bwd_rule(
norm_type, norm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
transpose_batch_sequence,
norm_input_axes, norm_input_axes,
dot_1_input_axes, dot_1_input_axes,
dot_2_input_axes, dot_2_input_axes,
...@@ -371,6 +420,8 @@ def _layernorm_mlp_bwd_rule( ...@@ -371,6 +420,8 @@ def _layernorm_mlp_bwd_rule(
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
activation_params,
collective_op_sets,
ctx, ctx,
grad, grad,
): ):
...@@ -409,6 +460,10 @@ def _layernorm_mlp_bwd_rule( ...@@ -409,6 +460,10 @@ def _layernorm_mlp_bwd_rule(
) = ctx ) = ctx
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
collective_op_set_1, collective_op_set_2 = collective_op_sets
assert not collective_op_set_1.backward.is_all_gather
assert not collective_op_set_2.backward.is_reduce_scatter
# Since the sharding of outputs should be the same as dot_1's input # Since the sharding of outputs should be the same as dot_1's input
grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
...@@ -417,6 +472,8 @@ def _layernorm_mlp_bwd_rule( ...@@ -417,6 +472,8 @@ def _layernorm_mlp_bwd_rule(
grad, grad,
is_dbias=use_bias_2, is_dbias=use_bias_2,
quantizer=ffn1_quantizer_set.dgrad, quantizer=ffn1_quantizer_set.dgrad,
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
) )
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
...@@ -434,6 +491,8 @@ def _layernorm_mlp_bwd_rule( ...@@ -434,6 +491,8 @@ def _layernorm_mlp_bwd_rule(
casted_grad.get_tensor(TensorUsage.LHS), casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel_2, casted_kernel_2,
contracting_dims=(g_contracting_dims_2, k_contracting_dims_2), contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
transpose_batch_sequence=transpose_batch_sequence,
collective_op=collective_op_set_2.backward,
) )
dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
...@@ -448,6 +507,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -448,6 +507,7 @@ def _layernorm_mlp_bwd_rule(
casted_act_out, casted_act_out,
casted_grad.get_tensor(TensorUsage.RHS), casted_grad.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, g_contracting_dims), contracting_dims=(x_contracting_dims, g_contracting_dims),
transpose_batch_sequence=transpose_batch_sequence,
) )
wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)
...@@ -457,6 +517,13 @@ def _layernorm_mlp_bwd_rule( ...@@ -457,6 +517,13 @@ def _layernorm_mlp_bwd_rule(
activation_type=activation_type, activation_type=activation_type,
is_dbias=use_bias_1, is_dbias=use_bias_1,
quantizer=ffn2_quantizer_set.dgrad, quantizer=ffn2_quantizer_set.dgrad,
act_params=(
tex.activation.ActivationParams.create(activation_type, **activation_params)
if activation_params
else None
),
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
) )
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
...@@ -474,6 +541,8 @@ def _layernorm_mlp_bwd_rule( ...@@ -474,6 +541,8 @@ def _layernorm_mlp_bwd_rule(
casted_dact_out.get_tensor(TensorUsage.LHS), casted_dact_out.get_tensor(TensorUsage.LHS),
casted_kernel_1, casted_kernel_1,
contracting_dims=(g_contracting_dims_1, k_contracting_dims_1), contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
transpose_batch_sequence=transpose_batch_sequence,
collective_op=collective_op_set_1.backward,
) )
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)
...@@ -484,6 +553,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -484,6 +553,7 @@ def _layernorm_mlp_bwd_rule(
casted_ln_out, casted_ln_out,
casted_dact_out.get_tensor(TensorUsage.RHS), casted_dact_out.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, g_contracting_dims), contracting_dims=(x_contracting_dims, g_contracting_dims),
transpose_batch_sequence=transpose_batch_sequence,
) )
wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)
......
...@@ -14,5 +14,6 @@ from .quantizer import * ...@@ -14,5 +14,6 @@ from .quantizer import *
from .dequantizer import * from .dequantizer import *
from .scaling_modes import * from .scaling_modes import *
from .metadata import * from .metadata import *
from .hadamard import *
from .helper import * from .helper import *
from .device_utils import * from .device_utils import *
...@@ -15,6 +15,8 @@ import jax ...@@ -15,6 +15,8 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
from .scaling_modes import ScalingMode from .scaling_modes import ScalingMode
from .hadamard import apply_rht, should_use_rht
__all__ = ["ScalingModeToDequantizerMap"] __all__ = ["ScalingModeToDequantizerMap"]
...@@ -119,7 +121,7 @@ class BlockScaleDequantizer(Dequantizer): ...@@ -119,7 +121,7 @@ class BlockScaleDequantizer(Dequantizer):
0 < flatten_axis < len(data_shape) 0 < flatten_axis < len(data_shape)
), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}" ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
scale_shape = scaling_mode.get_scale_shape( scale_shape = scaling_mode.get_scale_shape(
data_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis
) )
data = data.reshape( data = data.reshape(
...@@ -161,10 +163,99 @@ class BlockScaleDequantizer(Dequantizer): ...@@ -161,10 +163,99 @@ class BlockScaleDequantizer(Dequantizer):
) )
class NVFP4Dequantizer(Dequantizer):
"""NVFP4 Dequantizer Class.
This class provides static methods for dequantizing tensors that have been
quantized using NVFP4 scaling modes.
"""
@staticmethod
def _dequantize_func(data, scale_inv, amax, dq_dtype, scaling_mode, is_colwise, flatten_axis):
"""Dequantize a tensor using block scaling.
Args:
data: The quantized tensor data
scale_inv: The inverse scaling factors
amax: The maximum absolute value of the tensor
dq_dtype: The data type for dequantized values
scaling_mode: The scaling mode used for quantization
is_colwise: Whether the scaling is column-wise
flatten_axis: The axis along which the tensor could be flattened to 2D
Returns:
The dequantized tensor
"""
DATA_DTYPE_MAX = jnp.finfo(data.dtype).max.astype(jnp.float32)
SCALE_DTYPE_MAX = jnp.finfo(scale_inv.dtype).max.astype(jnp.float32)
tensor_scale_inv = amax / (DATA_DTYPE_MAX * SCALE_DTYPE_MAX)
data = data.astype(jnp.float32)
scale_inv = scale_inv.astype(jnp.float32) * tensor_scale_inv
data_layout = "T" if is_colwise else "N"
data_shape = data.shape
flatten_axis = len(data_shape) + flatten_axis if flatten_axis < 0 else flatten_axis
assert (
0 < flatten_axis < len(data_shape)
), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
scale_shape = scaling_mode.get_scale_shape(
data_shape,
data_layout=data_layout,
is_colwise=is_colwise,
is_padded=False,
# expect the flatten_axis wrt the N layout
flatten_axis=flatten_axis if data_layout == "N" else len(data_shape) - flatten_axis,
broadcast_2d_scale_shape_to_1d=True,
)
data = data.reshape(
*data_shape[: flatten_axis - 1],
scale_shape[flatten_axis - 1],
int(data_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]),
*data_shape[flatten_axis:-1],
scale_shape[-1],
int(data_shape[-1] / scale_shape[-1]),
)
scale_inv = jnp.expand_dims(scale_inv, axis=(flatten_axis + 2 - 2, -1))
out = jnp.asarray(data * scale_inv, dq_dtype).reshape(data_shape)
# Apply inverse of RHT if needed
use_rht = should_use_rht(scaling_mode, is_colwise=is_colwise)
if use_rht:
out = apply_rht(out, inverse=True)
return out
@staticmethod
def dequantize(scaled_tensor):
"""Dequantize a tensor using block scaling.
Args:
scaled_tensor: The quantized tensor to dequantize
Returns:
The dequantized tensor
"""
return NVFP4Dequantizer._dequantize_func(
scaled_tensor.data,
scaled_tensor.scale_inv,
scaled_tensor.amax,
scaled_tensor.dq_dtype,
scaled_tensor.scaling_mode,
scaled_tensor.is_colwise,
scaled_tensor.flatten_axis,
)
ScalingModeToDequantizerMap = { ScalingModeToDequantizerMap = {
ScalingMode.DELAYED_TENSOR_SCALING: TensorScaleDequantizer, ScalingMode.DELAYED_TENSOR_SCALING: TensorScaleDequantizer,
ScalingMode.CURRENT_TENSOR_SCALING: TensorScaleDequantizer, ScalingMode.CURRENT_TENSOR_SCALING: TensorScaleDequantizer,
ScalingMode.MXFP8_1D_SCALING: BlockScaleDequantizer, ScalingMode.MXFP8_1D_SCALING: BlockScaleDequantizer,
ScalingMode.NVFP4_1D_SCALING: NVFP4Dequantizer,
ScalingMode.NVFP4_2D_SCALING: NVFP4Dequantizer,
ScalingMode.NO_SCALING: NoopDequantizer, ScalingMode.NO_SCALING: NoopDequantizer,
} }
...@@ -210,13 +301,13 @@ def _grouped_dequantize(grouped_scaled_tensor): ...@@ -210,13 +301,13 @@ def _grouped_dequantize(grouped_scaled_tensor):
) )
padded_scale_shape_i = scaling_mode.get_scale_shape( padded_scale_shape_i = scaling_mode.get_scale_shape(
data_shape_i, data_shape_i,
grouped_scaled_tensor.is_colwise, is_colwise=grouped_scaled_tensor.is_colwise,
is_padded=True, is_padded=True,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
) )
unpadded_scale_shape_i = scaling_mode.get_scale_shape( unpadded_scale_shape_i = scaling_mode.get_scale_shape(
data_shape_i, data_shape_i,
grouped_scaled_tensor.is_colwise, is_colwise=grouped_scaled_tensor.is_colwise,
is_padded=False, is_padded=False,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
) )
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Randomized Hadamard Transform (RHT) utilities for JAX."""
import jax.numpy as jnp
from .scaling_modes import ScalingMode
def should_use_rht(scaling_mode, is_colwise=None, q_layout=None) -> bool:
"""Determine if RHT (Randomized Hadamard Transform) should be used.
Args:
scaling_mode: The scaling mode of the tensor.
is_colwise: Whether the tensor is column-wise. Only one of is_colwise or q_layout should be provided.
q_layout: The quantization layout of the tensor. Only one of is_colwise or q_layout should be provided.
Returns:
bool: True if RHT should be used, False otherwise.
"""
# Delayed import to avoid circular dependencies
from .quantizer import QuantizeLayout
assert (is_colwise is None) != (
q_layout is None
), "Exactly one of is_colwise or q_layout must be provided."
if q_layout is not None:
is_colwise = q_layout in {QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE}
return scaling_mode == ScalingMode.NVFP4_1D_SCALING and is_colwise
def get_wgrad_sign_vector() -> list[int]:
"""Get a fixed sign vector for the RHT used in NVFP4 weight gradient quantization."""
return [1, 1, 1, -1, 1, -1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1]
def get_sign_from_vector(vector: list[int]) -> int:
"""Convert a sign vector to a bitmask integer."""
mask = 0
for i, v in enumerate(vector):
mask |= (v == -1) << i
return mask
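# Hand-worked example: for the fixed wgrad sign vector above, -1 appears at
# indices 3, 5-10, 12, 14 and 15, so the packed bitmask evaluates to 0xD7E8 (55272).
assert get_sign_from_vector(get_wgrad_sign_vector()) == 0xD7E8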
def apply_rht(x: jnp.ndarray, inverse=False) -> jnp.ndarray:
"""Apply the Randomized Hadamard Transform (RHT) to the input tensor."""
h = get_rht_matrix()
block_size = 16
if inverse:
h = jnp.linalg.inv(h.astype(jnp.float32)).astype(jnp.bfloat16)
# TODO(jberchtold): These reshapes will break partitioning, fixme
return (x.reshape(-1, block_size) @ h).reshape(x.shape)
def get_rht_matrix() -> jnp.ndarray:
"""Get the Randomized Hadamard Transform (RHT) matrix used in NVFP4 weight gradient quantization.
Returns:
A (16, 16) bfloat16 matrix representing the RHT. This matrix is pre-multiplied by the random sign mask.
"""
import scipy
block_size = 16
h = jnp.array(scipy.linalg.hadamard(block_size))
# Apply the random sign mask
s = jnp.array(get_wgrad_sign_vector(), dtype=jnp.int32)
h = jnp.diag(s) @ h
return (h / jnp.sqrt(block_size)).astype(jnp.bfloat16)
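# Round-trip sketch of the RHT helpers (import path inferred from this package
# layout; the 16x16 transform is orthogonal, so forward + inverse recovers the
# input up to bfloat16 rounding).
import jax.numpy as jnp
from transformer_engine.jax.quantize.hadamard import apply_rht

x = (jnp.arange(32, dtype=jnp.float32).reshape(2, 16) / 32).astype(jnp.bfloat16)
y = apply_rht(x)                    # mixes each 16-element block with the signed Hadamard matrix
x_rec = apply_rht(y, inverse=True)  # undoes the transform
# jnp.allclose(x.astype(jnp.float32), x_rec.astype(jnp.float32), atol=1e-2) is expected to hold.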
...@@ -7,45 +7,76 @@ Config module for quantization metadata management ...@@ -7,45 +7,76 @@ Config module for quantization metadata management
This module provides configuration and helper functions for managing quantization metadata This module provides configuration and helper functions for managing quantization metadata
in JAX, including support for different scaling modes and datatypes. in JAX, including support for different scaling modes and datatypes.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Optional, Tuple, Dict, Union, Sequence, Type from typing import Optional, Tuple, Dict, Union, Sequence, Type, List
from functools import reduce from functools import reduce, lru_cache
import operator import operator
from importlib.metadata import version as get_pkg_version
import warnings
from packaging.version import Version as PkgVersion
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict from flax.core.frozen_dict import FrozenDict
from transformer_engine_jax import DType, get_cublasLt_version, get_cuda_version from transformer_engine_jax import DType, get_cublasLt_version, get_cuda_version
from transformer_engine.common import recipe from transformer_engine.common.recipe import (
from transformer_engine.jax.sharding import global_shard_guard, MeshResource Recipe,
DelayedScaling,
Format,
MXFP8BlockScaling,
Float8CurrentScaling,
NVFP4BlockScaling,
)
from transformer_engine.jax.sharding import (
global_shard_guard,
MeshResource,
num_of_devices,
get_all_mesh_axes,
with_sharding_constraint,
)
from .metadata import QuantizeMeta
from .scaling_modes import ScalingMode from .scaling_modes import ScalingMode
from .. import cpp_extensions as tex
from .device_utils import get_device_compute_capability from .device_utils import get_device_compute_capability
__all__ = [ __all__ = [
"get_quantize_config", "get_quantize_config",
"get_quantize_config_with_recipe",
"autocast",
"fp8_autocast", "fp8_autocast",
"is_fp8_available", "is_fp8_available",
"is_scaling_mode_supported",
"get_supported_scaling_modes",
"get_supported_quantization_recipes",
"update_collections", "update_collections",
"get_delayed_scaling",
"apply_padding_to_scale_inv", "apply_padding_to_scale_inv",
"remove_padding_from_scale_inv", "remove_padding_from_scale_inv",
"NVTE_FP8_COLLECTION_NAME", "NVTE_FP8_COLLECTION_NAME",
"TensorSource", "TensorSource",
] ]
_is_fp8_available = None _is_scaling_mode_supported = None
_reason_for_no_fp8 = "" _reason_for_no_scaling_mode = ""
Collection = Union[Dict, FrozenDict] Collection = Union[Dict, FrozenDict]
NVTE_FP8_COLLECTION_NAME = "fp8_metas" NVTE_FP8_COLLECTION_NAME = "fp8_metas"
@lru_cache(maxsize=None)
def _jax_version_meet_requirement(version: str):
"""
Helper function checking whether the required JAX version is available
"""
jax_version = PkgVersion(get_pkg_version("jax"))
jax_version_required = PkgVersion(version)
return jax_version >= jax_version_required
def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]:
"""Check if delayed scaling FP8 is supported on the given GPU architecture. """Check if delayed scaling FP8 is supported on the given GPU architecture.
...@@ -55,8 +86,6 @@ def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: ...@@ -55,8 +86,6 @@ def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]:
Returns: Returns:
A tuple of (bool, str) indicating support and any error message A tuple of (bool, str) indicating support and any error message
""" """
if gpu_arch >= 90: # hopper and above
return True, ""
if gpu_arch < 89: # pre-ada if gpu_arch < 89: # pre-ada
return False, "Device compute capability 8.9 or higher required for FP8 execution." return False, "Device compute capability 8.9 or higher required for FP8 execution."
if get_cublasLt_version() < 120103: if get_cublasLt_version() < 120103:
...@@ -75,20 +104,31 @@ def _check_block_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: ...@@ -75,20 +104,31 @@ def _check_block_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]:
Returns: Returns:
A tuple of (bool, str) indicating support and any error message A tuple of (bool, str) indicating support and any error message
""" """
if gpu_arch >= 100: # blackwell and above
return True, ""
if gpu_arch < 99: # pre-blackwell if gpu_arch < 99: # pre-blackwell
return False, "Device compute capability 9.9 or higher required for MXFP8 execution." return False, "Device compute capability 9.9 or higher required for MXFP8 execution."
if get_cublasLt_version() < 120800: if get_cublasLt_version() < 120800:
return False, "CublasLt version 12.8.0 or higher required for MXFP8 execution." return False, "CublasLt version 12.8.0 or higher required for MXFP8 execution."
if get_cuda_version() < 12010: if get_cuda_version() < 12080:
return False, "Cuda version 12.8 or higher required for MXFP8 execution." return False, "Cuda version 12.8 or higher required for MXFP8 execution."
if not tex.jax_version_meet_requirement("0.5.3"): if not _jax_version_meet_requirement("0.5.3"):
return False, "Jax version 0.5.3 or higher required for MXFP8 execution." return False, "Jax version 0.5.3 or higher required for MXFP8 execution."
return True, "" return True, ""
def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]: def _check_fp4_support(gpu_arch) -> Tuple[bool, str]:
"""Check if FP4 is supported for the given GPU architecture."""
if gpu_arch < 100: # pre-blackwell
return False, "Device compute capability 10.0 or higher required for NVFP4 execution."
if get_cublasLt_version() < 120800:
return False, "CublasLt version 12.8.0 or higher required for NVFP4 execution."
if get_cuda_version() < 12080:
return False, "Cuda version 12.8 or higher required for NVFP4 execution."
if not _jax_version_meet_requirement("0.5.3"):
return False, "Jax version 0.5.3 or higher required for NVFP4 execution."
return True, ""
def _check_scaling_support(scaling_mode: ScalingMode, gpu_id: int) -> Tuple[bool, str]:
"""Check if FP8 is supported for the given scaling mode and GPU. """Check if FP8 is supported for the given scaling mode and GPU.
Args: Args:
...@@ -101,9 +141,35 @@ def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]: ...@@ -101,9 +141,35 @@ def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]:
gpu_arch = get_device_compute_capability(gpu_id) gpu_arch = get_device_compute_capability(gpu_id)
if scaling_mode.is_tensor_scaling(): if scaling_mode.is_tensor_scaling():
return _check_delayed_scaling_fp8_support(gpu_arch) return _check_delayed_scaling_fp8_support(gpu_arch)
if scaling_mode == ScalingMode.MXFP8_1D_SCALING: if scaling_mode.is_mxfp8_scaling:
return _check_block_scaling_fp8_support(gpu_arch) return _check_block_scaling_fp8_support(gpu_arch)
return (False, "Unsupported scaling_mode!") if scaling_mode.is_nvfp4_scaling:
return _check_fp4_support(gpu_arch)
return (True, "") # NO_SCALING is always supported
def is_scaling_mode_supported(
scaling_mode=ScalingMode.NO_SCALING,
gpu_id=None,
) -> Tuple[bool, str]:
"""Check if the given scaling mode is available for the given GPU."""
if gpu_id is not None:
return _check_scaling_support(scaling_mode, gpu_id)
global _is_scaling_mode_supported, _reason_for_no_scaling_mode
if _is_scaling_mode_supported is None:
_is_scaling_mode_supported = {}
_reason_for_no_scaling_mode = {}
if scaling_mode not in _is_scaling_mode_supported:
_is_scaling_mode_supported[scaling_mode] = True
_reason_for_no_scaling_mode[scaling_mode] = ""
for local_gpu_id in range(len(jax.local_devices())):
ret, msg = _check_scaling_support(scaling_mode, local_gpu_id)
if ret is False:
_is_scaling_mode_supported[scaling_mode] = ret
_reason_for_no_scaling_mode[scaling_mode] = msg
return ret, msg
return _is_scaling_mode_supported[scaling_mode], _reason_for_no_scaling_mode[scaling_mode]
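# Usage sketch (import path assumed from the package layout): query support for a
# scaling mode before enabling it.
from transformer_engine.jax.quantize import ScalingMode, is_scaling_mode_supported

supported, reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
if not supported:
    print(f"NVFP4 not available on this device: {reason}")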
def is_fp8_available( def is_fp8_available(
...@@ -119,29 +185,39 @@ def is_fp8_available( ...@@ -119,29 +185,39 @@ def is_fp8_available(
Returns: Returns:
A tuple of (bool, str) indicating availability and any error message A tuple of (bool, str) indicating availability and any error message
""" """
if gpu_id is not None: warnings.warn(
return _check_fp8_support(scaling_mode, gpu_id) "is_fp8_available is deprecated. Use is_scaling_mode_supported instead.", DeprecationWarning
)
global _is_fp8_available, _reason_for_no_fp8 return is_scaling_mode_supported(scaling_mode=scaling_mode, gpu_id=gpu_id)
if _is_fp8_available is None:
_is_fp8_available = {}
_reason_for_no_fp8 = {} # TODO(Phuong): make the infrastructure to support NO_SCALING
def get_supported_scaling_modes() -> List[ScalingMode]:
if scaling_mode not in _is_fp8_available: """Get all supported quantization scaling modes."""
_is_fp8_available[scaling_mode] = True return [
_reason_for_no_fp8[scaling_mode] = "" scaling_mode
# JAX doesn't provide the local GPU id. for scaling_mode in ScalingMode
for local_gpu_id in range(len(jax.local_devices())): if is_scaling_mode_supported(scaling_mode=scaling_mode)[0]
ret, msg = _check_fp8_support(scaling_mode, local_gpu_id) and scaling_mode != ScalingMode.NO_SCALING
if ret is False: ]
_is_fp8_available[scaling_mode] = ret
_reason_for_no_fp8[scaling_mode] = msg
return ret, msg def get_supported_quantization_recipes() -> List[Recipe]:
"""Get all supported quantization recipes."""
return _is_fp8_available[scaling_mode], _reason_for_no_fp8[scaling_mode] # We don't support all the recipes TE/Common supports yet
# return [get_quantize_config_class(recipe)() for recipe in recipe.Recipe.__subclasses__()]
all_recipes = [
def _format2dtypes(format_: recipe.Format): DelayedScaling(),
Float8CurrentScaling(),
MXFP8BlockScaling(),
NVFP4BlockScaling(),
]
return [
recipe for recipe in all_recipes if get_quantize_config_class(recipe)().is_supported()[0]
]
def _format2dtypes(format_: Format):
"""Convert recipe.Format.dtype to corresponding JAX dtypes. """Convert recipe.Format.dtype to corresponding JAX dtypes.
Args: Args:
...@@ -150,12 +226,14 @@ def _format2dtypes(format_: recipe.Format): ...@@ -150,12 +226,14 @@ def _format2dtypes(format_: recipe.Format):
Returns: Returns:
A tuple of (forward_dtype, backward_dtype) for the given format A tuple of (forward_dtype, backward_dtype) for the given format
""" """
if format_ == recipe.Format.E4M3: if format_ == Format.E4M3:
return jnp.float8_e4m3fn, jnp.float8_e4m3fn return jnp.float8_e4m3fn, jnp.float8_e4m3fn
if format_ == recipe.Format.E5M2: if format_ == Format.E5M2:
return jnp.float8_e5m2, jnp.float8_e5m2 return jnp.float8_e5m2, jnp.float8_e5m2
if format_ == recipe.Format.HYBRID: if format_ == Format.HYBRID:
return jnp.float8_e4m3fn, jnp.float8_e5m2 return jnp.float8_e4m3fn, jnp.float8_e5m2
if format_ == Format.E2M1:
return jnp.float4_e2m1fn, jnp.float4_e2m1fn
return jnp.bfloat16, jnp.bfloat16 return jnp.bfloat16, jnp.bfloat16
...@@ -193,7 +271,6 @@ class BaseQuantizeConfig(ABC): ...@@ -193,7 +271,6 @@ class BaseQuantizeConfig(ABC):
INITIALIZED: Whether the config has been initialized INITIALIZED: Whether the config has been initialized
MARGIN: Margin value for quantization MARGIN: Margin value for quantization
COLLECTION_NAME: Name of the collection for quantization metadata COLLECTION_NAME: Name of the collection for quantization metadata
FP8_FORMAT: FP8 format to use
FWD_DTYPE: Forward pass data type FWD_DTYPE: Forward pass data type
BWD_DTYPE: Backward pass data type BWD_DTYPE: Backward pass data type
FP8_2X_ACC_FPROP: Whether to use 2x accumulation for forward pass FP8_2X_ACC_FPROP: Whether to use 2x accumulation for forward pass
...@@ -207,28 +284,26 @@ class BaseQuantizeConfig(ABC): ...@@ -207,28 +284,26 @@ class BaseQuantizeConfig(ABC):
INITIALIZED = False INITIALIZED = False
MARGIN: float = 0.0 MARGIN: float = 0.0
COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME
FP8_FORMAT: recipe.Format = recipe.Format.HYBRID FWD_DTYPE: DType = None
FWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[0] BWD_DTYPE: DType = None
BWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[1]
FP8_2X_ACC_FPROP: bool = False FP8_2X_ACC_FPROP: bool = False
FP8_2X_ACC_DGRAD: bool = False FP8_2X_ACC_DGRAD: bool = False
FP8_2X_ACC_WGRAD: bool = False FP8_2X_ACC_WGRAD: bool = False
INFERENCE_MODE: bool = False INFERENCE_MODE: bool = False
# DelayedScaling # DelayedScaling
# TODO(Phuong): move these two into DelayedScalingQuantizeConfig
AMAX_HISTORY_LEN: int = 1024 AMAX_HISTORY_LEN: int = 1024
AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX
def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: def initialize_from_recipe(self, fp8_recipe: Recipe) -> None:
"""Initialize the quantization configuration. """Initialize the quantization configuration from a given recipe.
Args: Args:
fp8_recipe: The FP8 recipe to use for initialization fp8_recipe: The FP8 recipe to use for initialization
""" """
self.INITIALIZED = True self.INITIALIZED = True
self.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0 self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(fp8_recipe.fp8_format)
self.FP8_FORMAT = fp8_recipe.fp8_format
self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(self.FP8_FORMAT)
def is_fp8_enabled(self) -> bool: def is_fp8_enabled(self) -> bool:
"""Check if FP8 quantization is enabled. """Check if FP8 quantization is enabled.
...@@ -249,6 +324,27 @@ class BaseQuantizeConfig(ABC): ...@@ -249,6 +324,27 @@ class BaseQuantizeConfig(ABC):
The scaling mode for the specified usage type. The scaling mode for the specified usage type.
""" """
@abstractmethod
def get_quantize_flax_meta(
self,
module,
collection_name: str,
postfix: str,
tensor_source: TensorSource,
quantizer_name: str,
) -> QuantizeMeta:
"""Get the quantization metadata for a given Flax module.
Args:
module: The Flax module to get metadata for
collection_name: The name of the collection to store metadata in
postfix: Postfix to append to metadata names
tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD)
quantizer_name: The name of the quantizer within the module
Returns:
The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed.
"""
def is_supported(self) -> tuple[bool, str]: def is_supported(self) -> tuple[bool, str]:
"""Check if this QuantizeConfig class is supported on the available devices. """Check if this QuantizeConfig class is supported on the available devices.
...@@ -261,7 +357,7 @@ class BaseQuantizeConfig(ABC): ...@@ -261,7 +357,7 @@ class BaseQuantizeConfig(ABC):
kernel_scaling_mode = self.get_scaling_mode(TensorSource.KERNEL) kernel_scaling_mode = self.get_scaling_mode(TensorSource.KERNEL)
grad_scaling_mode = self.get_scaling_mode(TensorSource.DGRAD) grad_scaling_mode = self.get_scaling_mode(TensorSource.DGRAD)
for scaling_mode in [x_scaling_mode, kernel_scaling_mode, grad_scaling_mode]: for scaling_mode in [x_scaling_mode, kernel_scaling_mode, grad_scaling_mode]:
is_supported, reason = is_fp8_available(scaling_mode=scaling_mode) is_supported, reason = is_scaling_mode_supported(scaling_mode=scaling_mode)
if not is_supported: if not is_supported:
return is_supported, reason return is_supported, reason
return True, None return True, None
...@@ -270,7 +366,7 @@ class BaseQuantizeConfig(ABC): ...@@ -270,7 +366,7 @@ class BaseQuantizeConfig(ABC):
class NoOpQuantizeConfig(BaseQuantizeConfig): class NoOpQuantizeConfig(BaseQuantizeConfig):
"""Configuration class higher-precision non-quantized operation.""" """Configuration class higher-precision non-quantized operation."""
def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: def initialize_from_recipe(self, fp8_recipe: Recipe) -> None:
"""Initialize no-op configuration.""" """Initialize no-op configuration."""
raise NotImplementedError( raise NotImplementedError(
"NoOpQuantizeConfig cannot be initialize from a recipe as it represents" "NoOpQuantizeConfig cannot be initialize from a recipe as it represents"
...@@ -281,6 +377,27 @@ class NoOpQuantizeConfig(BaseQuantizeConfig): ...@@ -281,6 +377,27 @@ class NoOpQuantizeConfig(BaseQuantizeConfig):
"""Gets the scaling mode for a specific tensor's usage type.""" """Gets the scaling mode for a specific tensor's usage type."""
return ScalingMode.NO_SCALING return ScalingMode.NO_SCALING
def get_quantize_flax_meta(
self,
module,
collection_name: str,
postfix: str,
tensor_source: TensorSource,
quantizer_name: str,
) -> QuantizeMeta:
"""Get the quantization metadata for a given Flax module.
Args:
module: The Flax module to get metadata for
collection_name: The name of the collection to store metadata in
postfix: Postfix to append to metadata names
tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD)
quantizer_name: The name of the quantizer within the module
Returns:
The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed.
"""
return QuantizeMeta()
class DelayedScalingQuantizeConfig(BaseQuantizeConfig): class DelayedScalingQuantizeConfig(BaseQuantizeConfig):
"""Configuration class for delayed scaling FP8 recipe. """Configuration class for delayed scaling FP8 recipe.
...@@ -289,7 +406,7 @@ class DelayedScalingQuantizeConfig(BaseQuantizeConfig): ...@@ -289,7 +406,7 @@ class DelayedScalingQuantizeConfig(BaseQuantizeConfig):
FP8 quantization mode. FP8 quantization mode.
""" """
def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: def initialize_from_recipe(self, fp8_recipe: Recipe) -> None:
"""Initialize delayed scaling FP8 configuration. """Initialize delayed scaling FP8 configuration.
Args: Args:
...@@ -299,6 +416,7 @@ class DelayedScalingQuantizeConfig(BaseQuantizeConfig): ...@@ -299,6 +416,7 @@ class DelayedScalingQuantizeConfig(BaseQuantizeConfig):
AssertionError: If recipe parameters are not supported AssertionError: If recipe parameters are not supported
""" """
super().initialize_from_recipe(fp8_recipe) super().initialize_from_recipe(fp8_recipe)
self.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0
assert fp8_recipe.amax_compute_algo in [ assert fp8_recipe.amax_compute_algo in [
"max", "max",
...@@ -323,6 +441,41 @@ class DelayedScalingQuantizeConfig(BaseQuantizeConfig): ...@@ -323,6 +441,41 @@ class DelayedScalingQuantizeConfig(BaseQuantizeConfig):
"""Gets the scaling mode for a specific tensor's usage type.""" """Gets the scaling mode for a specific tensor's usage type."""
return ScalingMode.DELAYED_TENSOR_SCALING return ScalingMode.DELAYED_TENSOR_SCALING
def get_quantize_flax_meta(
self,
module,
collection_name: str,
postfix: str,
tensor_source: TensorSource,
quantizer_name: str,
) -> QuantizeMeta:
"""Get the quantization metadata for a given Flax module.
Args:
module: The Flax module to get metadata for
collection_name: The name of the collection to store metadata in
postfix: Postfix to append to metadata names
tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD)
quantizer_name: The name of the quantizer within the module
Returns:
The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed.
"""
scale = module.variable(
collection_name,
f"{quantizer_name}{postfix}_scale",
jnp.ones,
(1,),
jnp.float32,
).value
amax_history = module.variable(
collection_name,
f"{quantizer_name}{postfix}_amax_history",
jnp.zeros,
(self.AMAX_HISTORY_LEN,),
jnp.float32,
).value
return QuantizeMeta(scale=scale, amax_history=amax_history)
class CurrentScalingQuantizeConfig(BaseQuantizeConfig): class CurrentScalingQuantizeConfig(BaseQuantizeConfig):
"""Configuration class for current scaling FP8 recipe. """Configuration class for current scaling FP8 recipe.
...@@ -331,7 +484,7 @@ class CurrentScalingQuantizeConfig(BaseQuantizeConfig): ...@@ -331,7 +484,7 @@ class CurrentScalingQuantizeConfig(BaseQuantizeConfig):
FP8 quantization mode. FP8 quantization mode.
""" """
def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: def initialize_from_recipe(self, fp8_recipe: Recipe) -> None:
"""Initialize current scaling FP8 configuration. """Initialize current scaling FP8 configuration.
Args: Args:
...@@ -344,6 +497,27 @@ class CurrentScalingQuantizeConfig(BaseQuantizeConfig): ...@@ -344,6 +497,27 @@ class CurrentScalingQuantizeConfig(BaseQuantizeConfig):
"""Gets the scaling mode for a specific tensor's usage type.""" """Gets the scaling mode for a specific tensor's usage type."""
return ScalingMode.CURRENT_TENSOR_SCALING return ScalingMode.CURRENT_TENSOR_SCALING
def get_quantize_flax_meta(
self,
module,
collection_name: str,
postfix: str,
tensor_source: TensorSource,
quantizer_name: str,
) -> QuantizeMeta:
"""Get the quantization metadata for a given Flax module.
Args:
module: The Flax module to get metadata for
collection_name: The name of the collection to store metadata in
postfix: Postfix to append to metadata names
tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD)
quantizer_name: The name of the quantizer within the module
Returns:
The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed.
"""
return QuantizeMeta()
class BlockScalingQuantizeConfig(BaseQuantizeConfig): class BlockScalingQuantizeConfig(BaseQuantizeConfig):
"""Configuration class for block scaling FP8 recipe. """Configuration class for block scaling FP8 recipe.
...@@ -352,7 +526,7 @@ class BlockScalingQuantizeConfig(BaseQuantizeConfig): ...@@ -352,7 +526,7 @@ class BlockScalingQuantizeConfig(BaseQuantizeConfig):
FP8 quantization mode. FP8 quantization mode.
""" """
def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: def initialize_from_recipe(self, fp8_recipe: Recipe) -> None:
"""Initialize block scaling FP8 configuration. """Initialize block scaling FP8 configuration.
Args: Args:
...@@ -365,43 +539,137 @@ class BlockScalingQuantizeConfig(BaseQuantizeConfig): ...@@ -365,43 +539,137 @@ class BlockScalingQuantizeConfig(BaseQuantizeConfig):
"""Gets the scaling mode for a specific tensor's usage type.""" """Gets the scaling mode for a specific tensor's usage type."""
return ScalingMode.MXFP8_1D_SCALING return ScalingMode.MXFP8_1D_SCALING
def get_quantize_flax_meta(
self,
module,
collection_name: str,
postfix: str,
tensor_source: TensorSource,
quantizer_name: str,
) -> QuantizeMeta:
"""Get the quantization metadata for a given Flax module.
Args:
module: The Flax module to get metadata for
collection_name: The name of the collection to store metadata in
postfix: Postfix to append to metadata names
tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD)
quantizer_name: The name of the quantizer within the module
Returns:
The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed.
"""
return QuantizeMeta()
class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig):
"""Configuration class for NVFP4 scaling recipe.
This class provides specific initialization and finalization for NVFP4 scaling quantization mode.
"""
def initialize_from_recipe(self, fp8_recipe: Recipe) -> None:
"""Initialize block scaling FP8 configuration.
Args:
fp8_recipe: The NVFP4 recipe to use for initialization
"""
self.INITIALIZED = True
self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(fp8_recipe.fp4_format)
self.AMAX_HISTORY_LEN = 0
def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
"""Gets the scaling mode for a specific tensor's usage type."""
if tensor_source == TensorSource.KERNEL:
return ScalingMode.NVFP4_2D_SCALING
# for x and grad
return ScalingMode.NVFP4_1D_SCALING
def get_quantize_flax_meta(
self,
module,
collection_name: str,
postfix: str,
tensor_source: TensorSource,
quantizer_name: str,
) -> QuantizeMeta:
"""Get the quantization metadata for a given Flax module.
Args:
module: The Flax module to get metadata for
collection_name: The name of the collection to store metadata in
postfix: Postfix to append to metadata names
tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD)
quantizer_name: The name of the quantizer within the module
Returns:
The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed.
"""
if tensor_source != TensorSource.DGRAD:
# Only DGRAD uses stochastic rounding
return QuantizeMeta()
# TODO(jberchtold): This assumes SR is always enabled for NVFP4. Use flag from recipe to toggle it.
sr_jax_rng = module.make_rng("sr_rng")
# Get a unique key for this quantizer
sr_jax_rng = jax.jit(jax.random.fold_in)(
sr_jax_rng, hash(quantizer_name) % jnp.iinfo(jnp.int32).max
)
# Generate 4 random uint32 values per device from the JAX PRNG key
sr_jax_rng_state = jax.random.randint(
sr_jax_rng, (num_of_devices(), 4), 0, jnp.iinfo(jnp.int32).max, dtype=jnp.int32
).view(jnp.uint32)
sr_jax_rng_state = with_sharding_constraint(
sr_jax_rng_state, jax.sharding.PartitionSpec(get_all_mesh_axes(), None)
)
return QuantizeMeta(stochastic_rounding_rng_state=sr_jax_rng_state)
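# Minimal sketch of the weight/activation split implemented by this config: weights
# (KERNEL) use 2D NVFP4 block scaling while activations (X) and gradients (DGRAD)
# use 1D NVFP4 block scaling.
cfg = NVFP4ScalingQuantizeConfig()
assert cfg.get_scaling_mode(TensorSource.KERNEL) == ScalingMode.NVFP4_2D_SCALING
assert cfg.get_scaling_mode(TensorSource.X) == ScalingMode.NVFP4_1D_SCALING
assert cfg.get_scaling_mode(TensorSource.DGRAD) == ScalingMode.NVFP4_1D_SCALING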
_QUANTIZE_CONFIG = NoOpQuantizeConfig() _QUANTIZE_CONFIG = NoOpQuantizeConfig()
def get_quantize_config(): def get_quantize_config():
"""Global instance of BaseQuantizeConfig set by fp8_autocast context.""" """Global instance of BaseQuantizeConfig set by autocast context."""
return _QUANTIZE_CONFIG return _QUANTIZE_CONFIG
def get_quantize_config_class( def get_quantize_config_class(
fp8_recipe: recipe.Recipe, fp8_recipe: Recipe,
) -> Type[BaseQuantizeConfig]: ) -> Type[BaseQuantizeConfig]:
"""Get the quantization configuration based on the FP8 recipe. """Get the quantization configuration class based on the FP8 recipe.
Args: Args:
fp8_recipe: The FP8 recipe to use for initialization fp8_recipe: The FP8 recipe to use for initialization
Returns: Returns:
The quantization config class corresponding to the given recipe. The quantization config class corresponding to the given recipe.
""" """
if isinstance(fp8_recipe, recipe.DelayedScaling): if isinstance(fp8_recipe, DelayedScaling):
return DelayedScalingQuantizeConfig return DelayedScalingQuantizeConfig
if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): if isinstance(fp8_recipe, MXFP8BlockScaling):
return BlockScalingQuantizeConfig return BlockScalingQuantizeConfig
if isinstance(fp8_recipe, recipe.Float8CurrentScaling): if isinstance(fp8_recipe, Float8CurrentScaling):
return CurrentScalingQuantizeConfig return CurrentScalingQuantizeConfig
if isinstance(fp8_recipe, NVFP4BlockScaling):
return NVFP4ScalingQuantizeConfig
raise ValueError(f"Unsupported recipe type: {type(fp8_recipe)}") raise ValueError(f"Unsupported recipe type: {type(fp8_recipe)}")
def get_quantize_config_with_recipe(fp8_recipe: Recipe):
"""Get the quantization configuration object based on the FP8 recipe."""
config = get_quantize_config_class(fp8_recipe)()
config.initialize_from_recipe(fp8_recipe)
return config
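# Usage sketch: build an initialized config directly from a recipe, assuming the
# default HYBRID format of DelayedScaling.
config = get_quantize_config_with_recipe(DelayedScaling())
print(config.FWD_DTYPE, config.BWD_DTYPE)        # float8_e4m3fn, float8_e5m2
print(config.get_scaling_mode(TensorSource.X))   # ScalingMode.DELAYED_TENSOR_SCALING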
@contextmanager @contextmanager
def fp8_autocast( def autocast(
enabled: bool = False, enabled: bool = False,
fp8_recipe: Optional[recipe.Recipe] = None, recipe: Optional[Recipe] = None,
mesh_resource: Optional[MeshResource] = None, mesh_resource: Optional[MeshResource] = None,
) -> None: ) -> None:
r"""Context manager for FP8 automatic mixed precision. r"""Context manager for FP8 or FP4 usage.
This context manager enables FP8 quantization for the duration of its context. This context manager enables quantization for the duration of its context.
.. code-block:: python .. code-block:: python
mesh_shape = (4, 2) mesh_shape = (4, 2)
...@@ -412,7 +680,7 @@ def fp8_autocast( ...@@ -412,7 +680,7 @@ def fp8_autocast(
with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)): with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)):
mesh_resource=MeshResource(dp_mesh_axis_name, tp_mesh_axis_name) mesh_resource=MeshResource(dp_mesh_axis_name, tp_mesh_axis_name)
with fp8_autocast(enabled=True, mesh_resource=mesh_resource): with autocast(enabled=True, mesh_resource=mesh_resource):
rules = extend_logical_axis_rules(tuple()) rules = extend_logical_axis_rules(tuple())
transformer = TransformerLayer() transformer = TransformerLayer()
...@@ -429,15 +697,15 @@ def fp8_autocast( ...@@ -429,15 +697,15 @@ def fp8_autocast(
---------- ----------
enabled: bool, default = False enabled: bool, default = False
Whether or not to enable fp8 Whether or not to enable fp8
fp8_recipe: recipe.DelayedScaling, default = None recipe: recipe.DelayedScaling, default = None
Recipe used for FP8 training. recipe used for low precision quantization.
mesh_resource: MeshResource, default = None mesh_resource: MeshResource, default = None
Specify the mesh axes for data and tensor parallelism to shard along. Specify the mesh axes for data and tensor parallelism to shard along.
If set to None, then no data or tensor parallelism will be used. If set to None, then no data or tensor parallelism will be used.
""" """
if fp8_recipe is None: if recipe is None:
fp8_recipe = recipe.DelayedScaling() recipe = DelayedScaling()
global _QUANTIZE_CONFIG global _QUANTIZE_CONFIG
...@@ -448,39 +716,44 @@ def fp8_autocast( ...@@ -448,39 +716,44 @@ def fp8_autocast(
try: try:
with global_shard_guard(mesh_resource): with global_shard_guard(mesh_resource):
if enabled: if enabled:
_QUANTIZE_CONFIG = get_quantize_config_class(fp8_recipe)() _QUANTIZE_CONFIG = get_quantize_config_class(recipe)()
is_supported, reason = _QUANTIZE_CONFIG.is_supported() is_supported, reason = _QUANTIZE_CONFIG.is_supported()
assert is_supported, reason assert is_supported, reason
_QUANTIZE_CONFIG.initialize_from_recipe(fp8_recipe) _QUANTIZE_CONFIG.initialize_from_recipe(recipe)
yield yield
finally: finally:
_QUANTIZE_CONFIG = old_quantize_config _QUANTIZE_CONFIG = old_quantize_config
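# Usage sketch of the new autocast API with an NVFP4 recipe (default MeshResource
# assumed; requires a device that passes the NVFP4 support checks above).
with autocast(enabled=True, recipe=NVFP4BlockScaling(), mesh_resource=MeshResource()):
    ...  # modules built or called here pick up the NVFP4 quantize config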
def get_delayed_scaling(): @contextmanager
r""" def fp8_autocast(
Obtain an instance of DelayedScaling which is set via fp8_autocast. enabled: bool = False,
fp8_recipe: Optional[Recipe] = None,
mesh_resource: Optional[MeshResource] = None,
) -> None:
"""
.. warning::
.. note:: fp8_autocast is deprecated and will be removed in a future release.
We only store :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len` Use autocast(enabled=..., recipe=..., mesh_resource=...) instead.
, and :attr:`amax_compute_algo` via fp8_autocast. Other parameters in
recipe.DelayedScaling would be returned as the default values.
Returns
-------
delay_scaling : DelayedScaling
an instance of DelayedScaling which is set via fp8_autocast.
""" """
amax_compute_algo = (
"max" if get_quantize_config().AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX else "most_recent" warnings.warn(
) "fp8_autocast is deprecated and will be removed in a future release. "
return recipe.DelayedScaling( "Use autocast(enabled=..., recipe=..., mesh_resource=...) instead.",
margin=int(get_quantize_config().MARGIN), category=DeprecationWarning,
fp8_format=get_quantize_config().FP8_FORMAT, stacklevel=2,
amax_history_len=get_quantize_config().AMAX_HISTORY_LEN,
amax_compute_algo=amax_compute_algo,
) )
# Call new implementation.
with autocast(
enabled=enabled,
recipe=fp8_recipe,
mesh_resource=mesh_resource,
):
yield
def update_collections(new: Collection, original: Collection) -> Collection: def update_collections(new: Collection, original: Collection) -> Collection:
r"""Update collections with new values while preserving original structure. r"""Update collections with new values while preserving original structure.
......
...@@ -9,23 +9,29 @@ This module provides classes for managing quantization metadata, including ...@@ -9,23 +9,29 @@ This module provides classes for managing quantization metadata, including
scale factors and amax history for different tensor types. scale factors and amax history for different tensor types.
""" """
from dataclasses import dataclass from dataclasses import dataclass
import jax.numpy as jnp
__all__ = ["QuantizeMeta", "QuantizeMetaSet"] __all__ = ["QuantizeMeta", "QuantizeMetaSet"]
@dataclass
class QuantizeMeta: class QuantizeMeta:
"""Metadata for quantization parameters. """Metadata for quantization parameters.
Attributes: For Delayed Scaling recipe:
scale: The scaling factor for quantization scale: The scaling factor for quantization
amax_history: History of maximum absolute values amax_history: History of maximum absolute values
For NVFP4 recipe with Stochastic Rounding:
sr_rng_state: The state of the stochastic rounding RNG
""" """
scale: jnp.ndarray def __init__(self, **kwargs):
amax_history: jnp.ndarray self._kwargs = kwargs
def get_kwargs_dictionary(self):
"""Get the metadata as a dictionary."""
return self._kwargs
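# Sketch: QuantizeMeta now simply forwards arbitrary per-recipe keyword arguments,
# e.g. scale/amax_history for delayed scaling (placeholder values below).
import jax.numpy as jnp

meta = QuantizeMeta(scale=jnp.ones((1,)), amax_history=jnp.zeros((1024,)))
quantizer_kwargs = meta.get_kwargs_dictionary()  # {"scale": ..., "amax_history": ...}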
@dataclass @dataclass
......
...@@ -19,6 +19,7 @@ from transformer_engine_jax import QuantizeLayout ...@@ -19,6 +19,7 @@ from transformer_engine_jax import QuantizeLayout
from transformer_engine.common import recipe from transformer_engine.common import recipe
from .scaling_modes import ScalingMode from .scaling_modes import ScalingMode
from .hadamard import apply_rht, should_use_rht
from .tensor import ( from .tensor import (
ScaledTensor, ScaledTensor,
ScaledTensor1x, ScaledTensor1x,
...@@ -28,7 +29,7 @@ from .tensor import ( ...@@ -28,7 +29,7 @@ from .tensor import (
) )
from .helper import ( from .helper import (
get_quantize_config, get_quantize_config,
get_quantize_config_class, get_quantize_config_with_recipe,
AmaxComputeAlgo, AmaxComputeAlgo,
TensorSource, TensorSource,
) )
...@@ -66,6 +67,7 @@ def compute_scale_from_amax( ...@@ -66,6 +67,7 @@ def compute_scale_from_amax(
sf = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN) sf = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN)
sf = jnp.where(amax > 0.0, sf, scale) sf = jnp.where(amax > 0.0, sf, scale)
sf = jnp.where(jnp.isfinite(amax), sf, scale) sf = jnp.where(jnp.isfinite(amax), sf, scale)
assert sf.shape == (1,), f"Expected sf.shape == (1,), but got {sf.shape}"
return sf return sf
...@@ -155,7 +157,7 @@ class Quantizer(ABC): ...@@ -155,7 +157,7 @@ class Quantizer(ABC):
""" """
def quantize( def quantize(
self, x, is_rowwise=False, is_colwise=False, dq_dtype=None, flatten_axis=-1, **kwargs self, x, is_rowwise=None, is_colwise=None, dq_dtype=None, flatten_axis=-1, **kwargs
) -> ScaledTensor: ) -> ScaledTensor:
"""Quantize a tensor using the internal _quantize_func(). """Quantize a tensor using the internal _quantize_func().
...@@ -170,6 +172,18 @@ class Quantizer(ABC): ...@@ -170,6 +172,18 @@ class Quantizer(ABC):
A ScaledTensor1x or ScaledTensor2x containing the quantized data A ScaledTensor1x or ScaledTensor2x containing the quantized data
""" """
del kwargs del kwargs
is_rowwise = (
is_rowwise
if is_rowwise is not None
else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x())
)
is_colwise = (
is_colwise
if is_colwise is not None
else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x())
)
if (is_rowwise and is_colwise) or self.is_2x2x(): if (is_rowwise and is_colwise) or self.is_2x2x():
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
colwise_tensor = self._quantize_func( colwise_tensor = self._quantize_func(
...@@ -380,6 +394,7 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer): ...@@ -380,6 +394,7 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype) clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
scale_inv = 1.0 / self.scale scale_inv = 1.0 / self.scale
amax = x.amax or jnp.max(jnp.abs(x.data)).reshape((1,)) amax = x.amax or jnp.max(jnp.abs(x.data)).reshape((1,))
# Note: this amax update is only invoked once, because the "quantize" method implementation inherited from CurrentScaleQuantizer calls _quantize_func once and then transposes the result for colwise quantization, so update() is not called twice for 2x2x quantization.
self.update(amax) self.update(amax)
return ScaledTensorFactory.create_1x( return ScaledTensorFactory.create_1x(
data=clipped_scaled_x, data=clipped_scaled_x,
...@@ -494,7 +509,7 @@ class BlockScaleQuantizer(Quantizer): ...@@ -494,7 +509,7 @@ class BlockScaleQuantizer(Quantizer):
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
x_shape = x.shape x_shape = x.shape
scale_shape = self.scaling_mode.get_scale_shape( scale_shape = self.scaling_mode.get_scale_shape(
x_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis x_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis
) )
scale_dtype = self.scaling_mode.get_scale_dtype() scale_dtype = self.scaling_mode.get_scale_dtype()
x = x.reshape( x = x.reshape(
...@@ -563,6 +578,221 @@ class BlockScaleQuantizer(Quantizer): ...@@ -563,6 +578,221 @@ class BlockScaleQuantizer(Quantizer):
return new_x.astype(dtype) return new_x.astype(dtype)
@register_pytree_node_class
@dataclass
class NVFP4Quantizer(Quantizer):
"""Quantizer implementation using current scaling.
This quantizer uses current scaling mode with float32 scales
Attributes:
scaling_mode: Set to NVFP4_1D_SCALING or NVFP4_2D_SCALING
q_layout: Quantization layout (rowwise, colwise, or both)
data_layout: Data layout string (default: "NT")
stochastic_rounding_rng_state: RNG state for stochastic rounding, must be of shape (4,) and dtype uint32. If None, stochastic rounding is disabled.
"""
scaling_mode: ScalingMode = ScalingMode.NVFP4_1D_SCALING
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
data_layout: str = "NT"
stochastic_rounding_rng_state: Optional[jnp.ndarray] = None
def __post_init__(self):
assert (
self.q_dtype == jnp.float4_e2m1fn
), "NVFP4 quantization must use a q_dtype of float4_e2m1fn"
assert self.scaling_mode.is_nvfp4_scaling, "NVFP4Quantizer must use NVFP4 scaling modes"
def _apply_stochastic_rounding(self, x):
assert (
self.stochastic_rounding_rng_state is not None
), "Stochastic rounding RNG state is not initialized"
assert self.stochastic_rounding_rng_state.shape == (
4,
), "Stochastic rounding RNG state must be of shape (4,)"
assert (
self.stochastic_rounding_rng_state.dtype == jnp.uint32
), "Stochastic rounding RNG state must be of dtype uint32"
# JAX's default PRNG key is built from two 32-bit words: use the first 2 uint32s for the initial key, then fold in the remaining 2 uint32s
key_bits = jnp.array(
[
self.stochastic_rounding_rng_state[0],
self.stochastic_rounding_rng_state[1],
],
dtype=jnp.uint32,
)
key = jax.random.wrap_key_data(key_bits)
key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[2])
key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[3])
abs_x = jnp.abs(x)
sign_x = jnp.sign(x)
floor = (
(abs_x >= 0.5) * 0.5
+ (abs_x >= 1) * 0.5
+ (abs_x >= 2)
+ (abs_x >= 3)
+ (abs_x >= 4)
+ (abs_x >= 6) * 2
)
ceil = (
0.5
+ (abs_x > 0.5) * 0.5
+ (abs_x > 1) * 1
+ (abs_x > 2)
+ (abs_x > 3)
+ (abs_x > 4) * 2
)
frac = (abs_x - floor) / (ceil - floor)
rand = jax.random.uniform(key, abs_x.shape)
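# Unbiasedness: frac = (|x| - floor) / (ceil - floor) and rand ~ U[0, 1), so
# P(round up) = frac and E[result] = floor + frac * (ceil - floor) = |x|; the
# stochastic rounding preserves the value in expectation.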
return sign_x * jnp.where(frac >= rand, ceil, floor)
def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x:
"""Quantize function helper for block scaling FP8.
Args:
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns:
A ScaledTensor1x containing the quantized data
"""
# TODO(Phuong): use quantize_func from JAX
if flatten_axis < 0:
flatten_axis = x.ndim + flatten_axis
assert (
0 <= flatten_axis < x.ndim
), f"Invalid flatten_axis: {flatten_axis} for tensor of shape {x.shape}"
should_apply_rht = self.scaling_mode == ScalingMode.NVFP4_1D_SCALING and is_colwise
global_amax = None
if isinstance(x, NoScaleTensor):
global_amax = (
x.amax if not should_apply_rht else None
) # RHT changes the amax so don't use precalculated amax for colwise 1D nvfp4 quantization with RHT
x = x.data
# Transpose if required
rowwise_flatten_axis = flatten_axis
data_layout = self.data_layout[0]
if is_colwise:
x = jnp.transpose(x, (*range(flatten_axis, x.ndim), *range(flatten_axis)))
data_layout = self.data_layout[1]
# convert flatten_axis from N layout to T layout
flatten_axis = x.ndim - flatten_axis
x_shape = x.shape
if should_use_rht(self.scaling_mode, is_colwise=is_colwise):
# We only apply RHT for 1D colwise nvfp4
x = apply_rht(x)
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
scale_shape = self.scaling_mode.get_scale_shape(
x_shape,
data_layout=data_layout,
is_colwise=is_colwise,
is_padded=False,
flatten_axis=rowwise_flatten_axis,
)
scale_dtype = self.scaling_mode.get_scale_dtype()
x = x.reshape(
*x_shape[: flatten_axis - 1],
scale_shape[flatten_axis - 1],
int(x_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]),
*x_shape[flatten_axis:-1],
scale_shape[-1],
int(x_shape[-1] / scale_shape[-1]),
)
# Dtype max constants
DATA_DTYPE_MAX = jnp.finfo(self.q_dtype).max.astype(jnp.float32)
SCALE_DTYPE_MAX = jnp.finfo(scale_dtype).max.astype(jnp.float32)
# Level 1: Current Tensor Scaling
global_amax = (
global_amax
if global_amax is not None
else jnp.max(jnp.abs(x)).reshape((1,)).astype(jnp.float32)
)
tensor_scale = DATA_DTYPE_MAX * SCALE_DTYPE_MAX / global_amax
tensor_scale = jnp.minimum(
tensor_scale, jnp.array(jnp.finfo(jnp.float32).max, dtype=jnp.float32)
)
tensor_scale = jnp.where(
tensor_scale == jnp.array(0.0, dtype=jnp.float32),
jnp.array(1.0, dtype=jnp.float32),
tensor_scale,
)
tensor_scale_inv = 1.0 / tensor_scale
# Level 2: Block Scaling
block_amax = jnp.max(jnp.abs(x), axis=(flatten_axis + 2 - 2, -1), keepdims=True).astype(
jnp.float32
)
block_scale_inv = jnp.divide(block_amax, DATA_DTYPE_MAX)
block_scale_inv = block_scale_inv * tensor_scale
block_scale_inv = jnp.minimum(
block_scale_inv, jnp.array(jnp.finfo(jnp.float32).max, dtype=jnp.float32)
)
block_scale_inv = jnp.clip(block_scale_inv, -SCALE_DTYPE_MAX, SCALE_DTYPE_MAX)
# We cast block_scale_inv to scale_dtype here to account for any rounding during the cast. This will ensure the quantized data incorporates the rounded scale value into its computation so dequantization is accurate.
block_scale_inv = block_scale_inv.astype(scale_dtype)
# Note: under JIT, JAX removes this intermediate cast, leading to slightly incorrect results during DQ and worse convergence to the original tensor over many samples of Q+SR->DQ, so we use reduce_precision to simulate the cast to scale_dtype.
assert scale_dtype == jnp.float8_e4m3fn, "Only float8_e4m3fn is supported for scale_dtype"
block_scale_inv = jax.lax.reduce_precision(block_scale_inv, 4, 3)
block_scale = jnp.minimum(
jnp.divide(1.0, block_scale_inv.astype(jnp.float32) * tensor_scale_inv),
jnp.array(jnp.finfo(jnp.float32).max, dtype=jnp.float32),
)
# Apply scaling
scaled_x = x.astype(jnp.float32) * block_scale
if self.stochastic_rounding_rng_state is not None:
scaled_x = self._apply_stochastic_rounding(scaled_x)
clipped_x = jnp.clip(scaled_x, -DATA_DTYPE_MAX, DATA_DTYPE_MAX)
# Cast to the right dtype
quantized_data = clipped_x.reshape(x_shape).astype(self.q_dtype)
block_scale_inv = block_scale_inv.reshape(scale_shape).astype(scale_dtype)
# In the 2D scaling mode, the scale shape is 2D but it needs to be broadcasted to 1D for GEMM.
# TODO(Phuong): expose this broadcast_2d_scale_shape_to_1d option to the
# quantizer.quantize() API
broadcasted_1d_scale_shape = self.scaling_mode.get_scale_shape(
x_shape,
data_layout=data_layout,
is_colwise=is_colwise,
is_padded=False,
flatten_axis=rowwise_flatten_axis,
broadcast_2d_scale_shape_to_1d=True,
)
# Broadcast and tile x to match the target shape
def repeat_to_shape(x, target_shape):
x_shape = x.shape
reps = [int(t // s) for s, t in zip(x_shape, target_shape)]
return jnp.tile(x, reps)
block_scale_inv = repeat_to_shape(block_scale_inv, broadcasted_1d_scale_shape)
return ScaledTensorFactory.create_1x(
data=quantized_data,
data_layout=data_layout,
is_colwise=is_colwise,
scale_inv=block_scale_inv,
amax=global_amax,
scaling_mode=self.scaling_mode,
dq_dtype=dq_dtype,
flatten_axis=rowwise_flatten_axis,
)
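# Hand-worked trace of the two-level scaling in _quantize_func, with assumed values
# (DATA_DTYPE_MAX = 6.0 for E2M1, SCALE_DTYPE_MAX = 448.0 for E4M3):
#   global_amax      = 12.0
#   tensor_scale     = 6.0 * 448.0 / 12.0 = 224.0   ->  tensor_scale_inv = 1/224
#   block_amax       = 3.0 for one 16-element block
#   block_scale_inv  = (3.0 / 6.0) * 224.0 = 112.0  (exactly representable in E4M3)
#   block_scale      = 1 / (112.0 * (1/224)) = 2.0
# An element x = 3.0 is stored as 3.0 * 2.0 = 6.0 in FP4; dequantizing gives
# 6.0 * 112.0 * (1/224) = 3.0, matching the NVFP4Dequantizer path.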
@register_pytree_node_class @register_pytree_node_class
@dataclass @dataclass
class QuantizerSet: class QuantizerSet:
...@@ -801,6 +1031,8 @@ class QuantizerFactory: ...@@ -801,6 +1031,8 @@ class QuantizerFactory:
ScalingMode.DELAYED_TENSOR_SCALING: DelayedScaleQuantizer, ScalingMode.DELAYED_TENSOR_SCALING: DelayedScaleQuantizer,
ScalingMode.CURRENT_TENSOR_SCALING: CurrentScaleQuantizer, ScalingMode.CURRENT_TENSOR_SCALING: CurrentScaleQuantizer,
ScalingMode.MXFP8_1D_SCALING: BlockScaleQuantizer, ScalingMode.MXFP8_1D_SCALING: BlockScaleQuantizer,
ScalingMode.NVFP4_1D_SCALING: NVFP4Quantizer,
ScalingMode.NVFP4_2D_SCALING: NVFP4Quantizer,
} }
@staticmethod @staticmethod
...@@ -826,7 +1058,6 @@ class QuantizerFactory: ...@@ -826,7 +1058,6 @@ class QuantizerFactory:
Returns: Returns:
A single quantizer or tuple of quantizers A single quantizer or tuple of quantizers
""" """
# (Phuong): add this assert back when NVTE_NO_SCALING is fully implemented
assert isinstance(scaling_mode, ScalingMode), "Invalid scaling_mode type" assert isinstance(scaling_mode, ScalingMode), "Invalid scaling_mode type"
if n_groups: if n_groups:
if n_quantizers != 1: if n_quantizers != 1:
...@@ -887,18 +1118,9 @@ class QuantizerFactory: ...@@ -887,18 +1118,9 @@ class QuantizerFactory:
if "quantize_meta_set" in kwargs: if "quantize_meta_set" in kwargs:
quantize_meta_set = kwargs.get("quantize_meta_set") quantize_meta_set = kwargs.get("quantize_meta_set")
args_x = { args_x = quantize_meta_set.x.get_kwargs_dictionary()
"scale": quantize_meta_set.x.scale, args_kernel = quantize_meta_set.kernel.get_kwargs_dictionary()
"amax_history": quantize_meta_set.x.amax_history, args_grad = quantize_meta_set.grad.get_kwargs_dictionary()
}
args_kernel = {
"scale": quantize_meta_set.kernel.scale,
"amax_history": quantize_meta_set.kernel.amax_history,
}
args_grad = {
"scale": quantize_meta_set.grad.scale,
"amax_history": quantize_meta_set.grad.amax_history,
}
else: else:
args_x = args_kernel = args_grad = {} args_x = args_kernel = args_grad = {}
...@@ -919,6 +1141,7 @@ class QuantizerFactory: ...@@ -919,6 +1141,7 @@ class QuantizerFactory:
bwd_dtype: jnp.dtype = None, bwd_dtype: jnp.dtype = None,
is_2x2x: bool = None, is_2x2x: bool = None,
n_groups: int = None, n_groups: int = None,
# TODO(jberchtold): rename fp8_recipe to quantization_recipe
fp8_recipe: Optional[recipe.Recipe] = None, fp8_recipe: Optional[recipe.Recipe] = None,
**kwargs, **kwargs,
) -> tuple[Union[tuple[Quantizer], None]]: ) -> tuple[Union[tuple[Quantizer], None]]:
...@@ -946,11 +1169,14 @@ class QuantizerFactory: ...@@ -946,11 +1169,14 @@ class QuantizerFactory:
) )
if fp8_recipe is not None: if fp8_recipe is not None:
quantize_config = get_quantize_config_class(fp8_recipe)() quantize_config = get_quantize_config_with_recipe(fp8_recipe)
x_scaling_mode = quantize_config.get_scaling_mode(TensorSource.X) x_scaling_mode = quantize_config.get_scaling_mode(TensorSource.X)
kernel_scaling_mode = quantize_config.get_scaling_mode(TensorSource.KERNEL) kernel_scaling_mode = quantize_config.get_scaling_mode(TensorSource.KERNEL)
grad_scaling_mode = quantize_config.get_scaling_mode(TensorSource.DGRAD) grad_scaling_mode = quantize_config.get_scaling_mode(TensorSource.DGRAD)
elif scaling_mode is not None: fwd_dtype = quantize_config.FWD_DTYPE
bwd_dtype = quantize_config.BWD_DTYPE
else:
if scaling_mode is not None:
x_scaling_mode = scaling_mode x_scaling_mode = scaling_mode
kernel_scaling_mode = scaling_mode kernel_scaling_mode = scaling_mode
grad_scaling_mode = scaling_mode grad_scaling_mode = scaling_mode
......
...@@ -17,7 +17,7 @@ from functools import reduce, lru_cache ...@@ -17,7 +17,7 @@ from functools import reduce, lru_cache
import operator import operator
import numpy as np import numpy as np
from jax.experimental.custom_partitioning import BATCHING from jax.experimental.custom_partitioning import BATCHING, CompoundFactor
from jax.tree_util import register_pytree_node_class from jax.tree_util import register_pytree_node_class
import jax.numpy as jnp import jax.numpy as jnp
...@@ -100,10 +100,19 @@ class ScalingModeMetadataImpl(ABC): ...@@ -100,10 +100,19 @@ class ScalingModeMetadataImpl(ABC):
The data type used for scale tensors The data type used for scale tensors
""" """
@abstractmethod
def get_data_layout(self) -> str:
"""Get the data layout for rowwise and colwise scaling.
Returns:
The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout.
"""
@abstractmethod @abstractmethod
def get_scale_shape( def get_scale_shape(
self, self,
data_shape: Tuple[int, ...], data_shape: Tuple[int, ...],
data_layout: str = "N",
is_colwise: bool = False, is_colwise: bool = False,
is_padded: bool = True, is_padded: bool = True,
flatten_axis: int = -1, flatten_axis: int = -1,
...@@ -112,6 +121,7 @@ class ScalingModeMetadataImpl(ABC): ...@@ -112,6 +121,7 @@ class ScalingModeMetadataImpl(ABC):
Args: Args:
data_shape: The shape of the tensor being quantized data_shape: The shape of the tensor being quantized
data_layout: Layout of the data shape, either "N" (default) or "T" for transposed.
is_colwise: Whether the scaling is column-wise is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape is_padded: Whether to return padded shape
flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1) flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
...@@ -152,14 +162,19 @@ class ScalingModeMetadataImpl(ABC): ...@@ -152,14 +162,19 @@ class ScalingModeMetadataImpl(ABC):
@abstractmethod @abstractmethod
def get_shardy_sharding_rules( def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis self,
input_shape,
unique_var,
flatten_axis,
broadcast_2d_scale_shape_to_1d,
) -> QuantizeShardyRules: ) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors. """Sharding rules for the input and (row, col)wise scale tensors.
Args: Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor) input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization. flatten_axis: Axis along which data can be flattened to 2D for quantization
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.
Returns: Returns:
The Shardy rules for the scaling mode The Shardy rules for the scaling mode
...@@ -180,12 +195,22 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -180,12 +195,22 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
""" """
return jnp.float32 return jnp.float32
def get_data_layout(self) -> str:
"""Get the data layout for rowwise and colwise scaling.
Returns:
The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout.
"""
return "NN"
def get_scale_shape( def get_scale_shape(
self, self,
data_shape: Tuple[int, ...], data_shape: Tuple[int, ...],
data_layout: str = "N",
is_colwise: bool = False, is_colwise: bool = False,
is_padded: bool = True, is_padded: bool = True,
flatten_axis: int = -1, flatten_axis: int = -1,
broadcast_2d_scale_shape_to_1d: bool = True,
) -> Tuple[int, ...]: ) -> Tuple[int, ...]:
"""Get the shape for scale tensors. This always returns an empty shape because this mode applies no scaling. """Get the shape for scale tensors. This always returns an empty shape because this mode applies no scaling.
...@@ -198,7 +223,14 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -198,7 +223,14 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
Returns: Returns:
The shape for scale tensors - (0,) The shape for scale tensors - (0,)
""" """
del data_shape, is_colwise, is_padded, flatten_axis del (
data_shape,
data_layout,
is_colwise,
is_padded,
flatten_axis,
broadcast_2d_scale_shape_to_1d,
)
return (0,) return (0,)
@lru_cache(maxsize=4) @lru_cache(maxsize=4)
...@@ -232,20 +264,25 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -232,20 +264,25 @@ class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
return (n_groups,) return (n_groups,)
def get_shardy_sharding_rules( def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis self,
input_shape,
unique_var,
flatten_axis,
broadcast_2d_scale_shape_to_1d,
) -> QuantizeShardyRules: ) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors. """Sharding rules for the input and (row, col)wise scale tensors.
Args: Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor) input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization. flatten_axis: Axis along which data can be flattened to 2D for quantization
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.
Returns: Returns:
The Shardy rules for the scaling mode The Shardy rules for the scaling mode
""" """
del flatten_axis del flatten_axis, broadcast_2d_scale_shape_to_1d
input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank)) input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape)))
scale_var = BATCHING + unique_var + "_scale_inv" scale_var = BATCHING + unique_var + "_scale_inv"
return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {})
...@@ -264,25 +301,37 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -264,25 +301,37 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
""" """
return jnp.float32 return jnp.float32
def get_data_layout(self) -> str:
"""Get the data layout for rowwise and colwise scaling.
Returns:
The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout.
"""
return "NT"
def get_scale_shape( def get_scale_shape(
self, self,
data_shape: Tuple[int, ...], data_shape: Tuple[int, ...],
data_layout: str = "N",
is_colwise: bool = False, is_colwise: bool = False,
is_padded: bool = True, is_padded: bool = True,
flatten_axis: int = -1, flatten_axis: int = -1,
broadcast_2d_scale_shape_to_1d: bool = True,
) -> Tuple[int, ...]: ) -> Tuple[int, ...]:
"""Get the shape for scale tensors in delayed scaling. """Get the shape for scale tensors in delayed scaling.
Args: Args:
data_shape: The shape of the tensor being scaled data_shape: The shape of the tensor being scaled
data_layout: Layout of the data shape, either "N" (default) or "T" for transposed.
is_colwise: Whether the scaling is column-wise is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape is_padded: Whether to return padded shape
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to True.
Returns: Returns:
The shape for scale tensors - (1,) The shape for scale tensors - (1,)
""" """
del is_colwise del data_layout, is_colwise, broadcast_2d_scale_shape_to_1d
if np.prod(data_shape) == 0: if np.prod(data_shape) == 0:
return (0,) return (0,)
return (1,) return (1,)
...@@ -323,20 +372,25 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -323,20 +372,25 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
return (n_groups,) return (n_groups,)
def get_shardy_sharding_rules( def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis self,
input_shape,
unique_var,
flatten_axis,
broadcast_2d_scale_shape_to_1d,
) -> QuantizeShardyRules: ) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors. """Sharding rules for the input and (row, col)wise scale tensors.
Args: Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor) input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization. flatten_axis: Axis along which data can be flattened to 2D for quantization
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.
Returns: Returns:
The Shardy rules for the scaling mode The Shardy rules for the scaling mode
""" """
del flatten_axis del flatten_axis, broadcast_2d_scale_shape_to_1d
input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank)) input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape)))
scale_var = BATCHING + unique_var + "_scale_inv" scale_var = BATCHING + unique_var + "_scale_inv"
return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {})
...@@ -359,14 +413,18 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -359,14 +413,18 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
_block_alignment: Alignment requirements for blocks _block_alignment: Alignment requirements for blocks
""" """
def __init__(self, block_dims: Tuple[int]): def __init__(self, block_dims: Tuple[int], scale_dtype: jnp.dtype, data_layout: str):
"""Initialize block scaling mode implementation. """Initialize block scaling mode implementation.
Args: Args:
block_dims: Dimensions of the scaling blocks block_dims: Dimensions of the scaling blocks
scale_dtype: Data type of the scale tensor
data_layout: Layout for rowwise and colwise scaling, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout.
""" """
self._block_dims = block_dims self._block_dims = block_dims
self._scale_dtype = scale_dtype
self._block_alignment = (128, 4) self._block_alignment = (128, 4)
self._data_layout = data_layout
def get_scale_dtype(self) -> jnp.dtype: def get_scale_dtype(self) -> jnp.dtype:
"""Get the data type for scale tensors in block scaling. """Get the data type for scale tensors in block scaling.
...@@ -374,7 +432,15 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -374,7 +432,15 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
Returns: Returns:
The data type used for scale tensors (float8_e8m0fnu) The data type used for scale tensors (float8_e8m0fnu)
""" """
return jnp.float8_e8m0fnu return self._scale_dtype
def get_data_layout(self) -> str:
"""Get the data layout for rowwise and colwise scaling.
Returns:
The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout.
"""
return self._data_layout
def _apply_scale_shape_correction(self, data_shape, n_scale_blocks, scale_block_dim): def _apply_scale_shape_correction(self, data_shape, n_scale_blocks, scale_block_dim):
"""Remove excess padding from the scale shape and return the shape with respect to the original data shape.""" """Remove excess padding from the scale shape and return the shape with respect to the original data shape."""
...@@ -402,23 +468,51 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -402,23 +468,51 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
def get_scale_shape( def get_scale_shape(
self, self,
data_shape: Tuple[int, ...], data_shape: Tuple[int, ...],
data_layout: str = "N",
is_colwise: bool = False, is_colwise: bool = False,
is_padded: bool = True, is_padded: bool = True,
flatten_axis: int = -1, flatten_axis: int = -1,
broadcast_2d_scale_shape_to_1d: bool = False,
) -> Tuple[int, ...]: ) -> Tuple[int, ...]:
"""Get the shape for scale tensors in block scaling. """Get the shape for scale tensors in block scaling.
Args: Args:
data_shape: The shape of the tensor being quantized data_shape: The shape of the tensor being quantized
data_layout: Layout of the data shape, either "N" (default) or "T" for transposed.
is_colwise: Whether the scaling is column-wise is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape is_padded: Whether to return padded shape
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False.
Returns: Returns:
The shape for scale tensors The shape for scale tensors
""" """
flatten_axis = (len(data_shape) + flatten_axis) % len(data_shape)
assert (
0 < flatten_axis < len(data_shape)
), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
block_alignment = self._block_alignment if is_padded else (1, 1) block_alignment = self._block_alignment if is_padded else (1, 1)
if is_colwise:
assert data_layout == self._data_layout[1], (
f"Data layout must match colwise layout, received {data_layout} but expected"
f" {self._data_layout[1]}"
)
else:
assert data_layout == self._data_layout[0], (
f"Data layout must match rowwise layout, received {data_layout} but expected"
f" {self._data_layout[0]}"
)
if is_colwise and self._data_layout[1] == "T":
# TODO(Phuong): rework this hack so that we don't implicitly change is_colwise value
is_colwise = False # now rowwise in T is colwise in N
if flatten_axis < 0:
flatten_axis = len(data_shape) + flatten_axis
# flatten_axis is given wrt N layout, convert to T layout
flatten_axis = len(data_shape) - flatten_axis
if is_colwise: if is_colwise:
block_y, block_x = self._block_dims block_y, block_x = self._block_dims
alignment_y, alignment_x = block_alignment alignment_y, alignment_x = block_alignment
...@@ -426,12 +520,7 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -426,12 +520,7 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
block_x, block_y = self._block_dims block_x, block_y = self._block_dims
alignment_x, alignment_y = block_alignment alignment_x, alignment_y = block_alignment
if flatten_axis < 0: is_block_2d = block_x > 1 and block_y > 1
flatten_axis = len(data_shape) + flatten_axis
assert (
0 < flatten_axis < len(data_shape)
), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
assert data_shape[flatten_axis - 1] % block_x == 0, ( assert data_shape[flatten_axis - 1] % block_x == 0, (
f"Data shape {data_shape} should be divisible by block_x {block_x} in axis" f"Data shape {data_shape} should be divisible by block_x {block_x} in axis"
f" {flatten_axis - 1}" f" {flatten_axis - 1}"
...@@ -440,6 +529,9 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -440,6 +529,9 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
data_shape[-1] % block_y == 0 data_shape[-1] % block_y == 0
), f"Data shape {data_shape} should be divisible by block_y {block_y} in axis -1" ), f"Data shape {data_shape} should be divisible by block_y {block_y} in axis -1"
if broadcast_2d_scale_shape_to_1d and is_block_2d:
block_x = 1
flattened_first_dim = reduce(operator.mul, data_shape[:flatten_axis], 1) flattened_first_dim = reduce(operator.mul, data_shape[:flatten_axis], 1)
flattened_last_dim = reduce(operator.mul, data_shape[flatten_axis:], 1) flattened_last_dim = reduce(operator.mul, data_shape[flatten_axis:], 1)
...@@ -562,52 +654,67 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -562,52 +654,67 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
return (n_block_x * n_block_y,) return (n_block_x * n_block_y,)
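# Worked example (assumed numbers) of the broadcast_2d_scale_shape_to_1d handling in
# get_scale_shape above: with a 2D (16, 16) block on an unpadded (128, 64) tensor, the
# true 2D scale shape is (128 // 16, 64 // 16) = (8, 4); broadcasting to 1D keeps
# block_x = 1, giving (128, 4), the same shape a (1, 16) block produces and the layout
# the GEMM consumes.
data_shape = (128, 64)
block_x, block_y = 16, 16
scale_shape_2d = (data_shape[0] // block_x, data_shape[1] // block_y)  # (8, 4)
scale_shape_1d_broadcast = (data_shape[0] // 1, data_shape[1] // block_y)  # (128, 4)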
def get_shardy_sharding_rules( def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis self,
input_shape,
unique_var,
flatten_axis,
broadcast_2d_scale_shape_to_1d,
) -> QuantizeShardyRules: ) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors. """Sharding rules for the input and (row, col)wise scale tensors.
Args: Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor) input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.
Returns: Returns:
The Shardy rules for the scaling mode The Shardy rules for the scaling mode
""" """
del flatten_axis # TODO(Phuong): rework the Shardy rule to handle transposes after NVFP4 is upstreamed
input_spec = [f"{unique_var}{i}" for i in range(input_rank)] input_rank = len(input_shape)
rowwise = [f"{unique_var}scale_inv_rowwise{i}" for i in range(input_rank)] input_spec = [f"{unique_var}_{i}" for i in range(input_rank)]
colwise = [f"{unique_var}scale_inv_colwise{i}" for i in range(input_rank)] flatten_axis = (flatten_axis + input_rank) % input_rank
# NOTE (Alp): Padding the scales breaks the size relationship in CompoundFactors.
# Unfortunately, because Shardy rules are applied to the inner primitive, the
# only way to preserve the relationship is to lower unpadded scales to the
# underlying custom call and pad them in C++. Until that's implemented, the
# Shardy rules for block scales have to be completely disconnected from the
# Shardy rules for the tensor they belong to.
# # We have to use two different factors in the two CompoundFactors because of Shardy
# # verifier requirements, even though they are the same.
# rowwise_var = unique_var
# colwise_var = f"{unique_var}_"
# input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise")
# input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise")
# # The rowwise and colwise scale tensors should be sharded the same way as the input.
# # However, we need to adjust the dimensions where the block scaling factor applies.
# rowwise = input_spec.copy()
# rowwise[-1] = rowwise_var
# colwise = input_spec.copy() assert (
# colwise[flatten_axis - 1] = colwise_var self._block_dims[1] != 1
), f"Expect 1D rowwise or 2D block. Got _block_dims={self._block_dims}"
# 2D block scaling is only supported with broadcast_2d_scale_shape_to_1d
if self._block_dims[0] != 1:
assert self._block_dims[0] == self._block_dims[1] and broadcast_2d_scale_shape_to_1d, (
f"Got broadcast_2d_scale_shape_to_1d={broadcast_2d_scale_shape_to_1d},"
f" _block_dims={self._block_dims}"
)
# # This implementation needs to be updated for different block dims. block_size_1d = self._block_dims[1]
# assert self._block_dims == (1, 32)
# We have to use two different factors in the two CompoundFactors because of Shardy
# verifier requirements, even though they are the same.
blocksizes = {}
colwise_var = f"{unique_var}_None"
rowwise_var = f"{unique_var}_None"
if input_shape[-1] != block_size_1d:
rowwise_var = input_spec[-1] + "_compound"
input_spec[-1] = CompoundFactor(rowwise_var, "blocksize_x")
blocksizes["blocksize_x"] = block_size_1d
if input_shape[flatten_axis - 1] != block_size_1d:
colwise_var = input_spec[flatten_axis - 1] + "_compound"
input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "blocksize_y")
blocksizes["blocksize_y"] = block_size_1d
# The rowwise and colwise scale tensors should be sharded the same way as the input.
# However, we need to adjust the dimensions where the block scaling factor applies.
rowwise = input_spec.copy()
rowwise[-1] = rowwise_var
colwise = input_spec.copy()
colwise[flatten_axis - 1] = colwise_var
return QuantizeShardyRules( return QuantizeShardyRules(
tuple(input_spec), tuple(input_spec),
tuple(rowwise), tuple(rowwise),
tuple(colwise), tuple(colwise),
{}, # {"block_size_rowwise": 32, "block_size_colwise": 32}, blocksizes,
) )
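# Hand-worked illustration (assumed shape and names) of the compound-factor construction
# above for a rank-2 input of shape (256, 128), flatten_axis = 1 and a 1D block size of
# 16. Pure Python: a tuple stands in for CompoundFactor purely to show the structure.
input_shape = (256, 128)
flatten_axis = 1
block = 16
unique_var = "x"
spec = [f"{unique_var}_{i}" for i in range(len(input_shape))]
blocksizes = {}
rowwise_var = colwise_var = f"{unique_var}_None"
if input_shape[-1] != block:  # 128 != 16, so the last dim becomes a compound factor
    rowwise_var = spec[-1] + "_compound"
    spec[-1] = (rowwise_var, "blocksize_x")
    blocksizes["blocksize_x"] = block
if input_shape[flatten_axis - 1] != block:  # 256 != 16
    colwise_var = spec[flatten_axis - 1] + "_compound"
    spec[flatten_axis - 1] = (colwise_var, "blocksize_y")
    blocksizes["blocksize_y"] = block
rowwise = spec.copy()
rowwise[-1] = rowwise_var
colwise = spec.copy()
colwise[flatten_axis - 1] = colwise_var
# spec       -> [("x_0_compound", "blocksize_y"), ("x_1_compound", "blocksize_x")]
# rowwise    -> [("x_0_compound", "blocksize_y"), "x_1_compound"]
# colwise    -> ["x_0_compound", ("x_1_compound", "blocksize_x")]
# blocksizes -> {"blocksize_x": 16, "blocksize_y": 16}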
...@@ -620,6 +727,8 @@ class ScalingMode(Enum): ...@@ -620,6 +727,8 @@ class ScalingMode(Enum):
- DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales - DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales
- MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales - MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales
- CURRENT_TENSOR_SCALING: Uses current scaling with FP8 data type and float32 scales - CURRENT_TENSOR_SCALING: Uses current scaling with FP8 data type and float32 scales
- NVFP4_1D_SCALING: Uses block-based scaling with FP4 data type and E4M3 scales
- NVFP4_2D_SCALING: Uses block-based scaling with FP4 data type and E4M3 scales
- NO_SCALING: No scaling applied - NO_SCALING: No scaling applied
""" """
...@@ -627,6 +736,8 @@ class ScalingMode(Enum): ...@@ -627,6 +736,8 @@ class ScalingMode(Enum):
DELAYED_TENSOR_SCALING = JAXX_Scaling_Mode.DELAYED_TENSOR_SCALING DELAYED_TENSOR_SCALING = JAXX_Scaling_Mode.DELAYED_TENSOR_SCALING
MXFP8_1D_SCALING = JAXX_Scaling_Mode.MXFP8_1D_SCALING MXFP8_1D_SCALING = JAXX_Scaling_Mode.MXFP8_1D_SCALING
CURRENT_TENSOR_SCALING = JAXX_Scaling_Mode.CURRENT_TENSOR_SCALING CURRENT_TENSOR_SCALING = JAXX_Scaling_Mode.CURRENT_TENSOR_SCALING
NVFP4_1D_SCALING = JAXX_Scaling_Mode.NVFP4_1D_SCALING
NVFP4_2D_SCALING = JAXX_Scaling_Mode.NVFP4_2D_SCALING
def _get_impl(self) -> ScalingModeMetadataImpl: def _get_impl(self) -> ScalingModeMetadataImpl:
"""Get the implementation for this scaling mode. """Get the implementation for this scaling mode.
...@@ -650,40 +761,79 @@ class ScalingMode(Enum): ...@@ -650,40 +761,79 @@ class ScalingMode(Enum):
""" """
return self._get_impl().get_scale_dtype() return self._get_impl().get_scale_dtype()
def get_scale_shape_2x(self, data_shape, is_padded=True, flatten_axis=-1) -> Tuple[Tuple[int]]: def get_scale_shape_2x(
self, data_shape, is_padded=True, flatten_axis=-1, broadcast_2d_scale_shape_to_1d=False
) -> Tuple[Tuple[int]]:
"""Get shapes for both row-wise and column-wise scaling. """Get shapes for both row-wise and column-wise scaling.
Args: Args:
data_shape: Shape of the data tensor data_shape: Shape of the data tensor
is_padded: Whether to use padded shapes is_padded: Whether to use padded shapes
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False.
Returns: Returns:
Tuple of (rowwise_scale_shape, colwise_scale_shape) Tuple of (rowwise_scale_shape, colwise_scale_shape)
""" """
data_layout = self._get_impl().get_data_layout()
rowwise_layout = data_layout[0]
assert (
rowwise_layout == "N"
), f"For rowwise layout only 'N' is supported, received {rowwise_layout}"
colwise_layout = data_layout[1]
rowwise_scale_shape = self.get_scale_shape( rowwise_scale_shape = self.get_scale_shape(
data_shape, is_colwise=False, is_padded=is_padded, flatten_axis=flatten_axis data_shape,
data_layout=rowwise_layout,
is_colwise=False,
is_padded=is_padded,
flatten_axis=flatten_axis,
broadcast_2d_scale_shape_to_1d=broadcast_2d_scale_shape_to_1d,
) )
colwise_data_shape = data_shape
if colwise_layout == "T":
colwise_data_shape = data_shape[flatten_axis:] + data_shape[:flatten_axis]
colwise_scale_shape = self.get_scale_shape( colwise_scale_shape = self.get_scale_shape(
data_shape, is_colwise=True, is_padded=is_padded, flatten_axis=flatten_axis colwise_data_shape,
data_layout=colwise_layout,
is_colwise=True,
is_padded=is_padded,
flatten_axis=flatten_axis,
broadcast_2d_scale_shape_to_1d=broadcast_2d_scale_shape_to_1d,
) )
return (rowwise_scale_shape, colwise_scale_shape) return (rowwise_scale_shape, colwise_scale_shape)
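# Worked example (assumed shape) of the "T" colwise layout handling above: for a logical
# tensor of shape (B, S, H) = (2, 512, 1024) with flatten_axis = 2, the colwise data is
# the flattened transpose, so the shape passed to get_scale_shape is rotated about
# flatten_axis.
data_shape = (2, 512, 1024)
flatten_axis = 2
colwise_data_shape = data_shape[flatten_axis:] + data_shape[:flatten_axis]  # (1024, 2, 512)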
def get_scale_shape( def get_scale_shape(
self, data_shape, is_colwise, is_padded=True, flatten_axis=-1 self,
data_shape,
data_layout="N",
is_colwise=False,
is_padded=True,
flatten_axis=-1,
broadcast_2d_scale_shape_to_1d=False,
) -> Tuple[int]: ) -> Tuple[int]:
"""Get the shape for scale tensors in this mode. """Get the shape for scale tensors in this mode.
Args: Args:
data_shape: Shape of the data tensor data_shape: Shape of the data tensor
data_layout: Layout of the data shape, either "N" (default) or "T" for transposed.
is_colwise: Whether to use column-wise scaling is_colwise: Whether to use column-wise scaling
is_padded: Whether to use padded shapes is_padded: Whether to use padded shapes
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False.
Returns: Returns:
The shape for scale tensors The shape for scale tensors
""" """
return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded, flatten_axis) return self._get_impl().get_scale_shape(
data_shape,
data_layout=data_layout,
is_colwise=is_colwise,
is_padded=is_padded,
flatten_axis=flatten_axis,
broadcast_2d_scale_shape_to_1d=broadcast_2d_scale_shape_to_1d,
)
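# Minimal usage sketch of get_scale_shape, assuming ScalingMode is exported from
# transformer_engine.jax.quantize. MXFP8 uses (1, 32) blocks, so an unpadded (128, 64)
# tensor carries one rowwise scale per 32 contiguous elements of the last axis and one
# colwise scale per 32 elements of the first axis (expected shapes noted below).
from transformer_engine.jax.quantize import ScalingMode

rowwise_scales = ScalingMode.MXFP8_1D_SCALING.get_scale_shape(
    (128, 64), is_colwise=False, is_padded=False
)  # expected (128, 2)
colwise_scales = ScalingMode.MXFP8_1D_SCALING.get_scale_shape(
    (128, 64), is_colwise=True, is_padded=False
)  # expected (4, 64)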
def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
"""Get the quantize layout for the tensor usage. """Get the quantize layout for the tensor usage.
...@@ -697,18 +847,26 @@ class ScalingMode(Enum): ...@@ -697,18 +847,26 @@ class ScalingMode(Enum):
return self._get_impl().get_quantize_layout(usage) return self._get_impl().get_quantize_layout(usage)
def get_shardy_sharding_rules( def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis=-1 self,
input_shape,
unique_var,
flatten_axis=-1,
broadcast_2d_scale_shape_to_1d=False,
) -> Tuple[Tuple[str]]: ) -> Tuple[Tuple[str]]:
"""Sharding rules for the input and (row, col)wise scale tensors. """Sharding rules for the input and (row, col)wise scale tensors.
Args: Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor) input_shape: The shape of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False.
Returns: Returns:
The Shardy rules for the scaling mode The Shardy rules for the scaling mode
""" """
return self._get_impl().get_shardy_sharding_rules(input_rank, unique_var, flatten_axis) return self._get_impl().get_shardy_sharding_rules(
input_shape, unique_var, flatten_axis, broadcast_2d_scale_shape_to_1d
)
def get_grouped_scale_shape_2x( def get_grouped_scale_shape_2x(
self, data_shape, n_groups, group_axis, is_padded=True, flatten_axis=-1 self, data_shape, n_groups, group_axis, is_padded=True, flatten_axis=-1
...@@ -782,8 +940,64 @@ class ScalingMode(Enum): ...@@ -782,8 +940,64 @@ class ScalingMode(Enum):
Returns: Returns:
True if the scaling mode is 1D block scaling, False otherwise True if the scaling mode is 1D block scaling, False otherwise
""" """
# Both 1D and 2D NVFP4 scaling are treated as 1D block scaling, since the 2D scales are broadcast to 1D as required by the GEMM.
return self == ScalingMode.MXFP8_1D_SCALING or self.is_nvfp4_scaling
@property
def is_block_scaling(self) -> bool:
"""Check if this scaling mode is block scaling.
Returns:
True if the scaling mode is block scaling, False otherwise
"""
# Currently we only have 1D block scaling modes
return self.is_1d_block_scaling()
def get_compatible_q_dtypes(self) -> set[jnp.dtype]:
"""Returns a set of compatible quantized data types for this scaling mode.
Returns:
A set of compatible quantized data types
"""
if self in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
ScalingMode.MXFP8_1D_SCALING,
):
return {jnp.float8_e5m2, jnp.float8_e4m3fn}
if self in (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_2D_SCALING):
return {jnp.float4_e2m1fn}
if self == ScalingMode.NO_SCALING:
return {jnp.float16, jnp.bfloat16, jnp.float32}
raise ValueError(f"Invalid scaling mode: {self}")
@property
def is_nvfp4_scaling(self) -> bool:
"""Check if this scaling mode is NVFP4 scaling.
Returns:
True if the scaling mode is NVFP4 scaling, False otherwise
"""
return self in (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_2D_SCALING)
@property
def is_mxfp8_scaling(self) -> bool:
"""Check if this scaling mode is NVFP4 scaling.
Returns:
True if the scaling mode is NVFP4 scaling, False otherwise
"""
return self == ScalingMode.MXFP8_1D_SCALING return self == ScalingMode.MXFP8_1D_SCALING
@property
def is_colwise_transposed(self) -> bool:
"""Check if this scaling mode uses transposed layout for column-wise scaling.
Returns:
True if the scaling mode uses transposed layout for column-wise scaling, False otherwise
"""
return self.is_tensor_scaling() or self.is_nvfp4_scaling
def __eq__(self, other): def __eq__(self, other):
"""Compare this scaling mode with another. """Compare this scaling mode with another.
...@@ -820,9 +1034,20 @@ class ScalingMode(Enum): ...@@ -820,9 +1034,20 @@ class ScalingMode(Enum):
SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = { SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = {
ScalingMode.NO_SCALING: NoScalingModeMetadataImpl(),
ScalingMode.DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(), ScalingMode.DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(),
ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)), ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(
# WAR block_dims=(1, 32),
scale_dtype=jnp.float8_e8m0fnu,
data_layout="NN",
),
ScalingMode.CURRENT_TENSOR_SCALING: CurrentScalingModeMetadataImpl(), ScalingMode.CURRENT_TENSOR_SCALING: CurrentScalingModeMetadataImpl(),
ScalingMode.NO_SCALING: NoScalingModeMetadataImpl(), ScalingMode.NVFP4_1D_SCALING: BlockScalingModeMetadataImpl(
block_dims=(1, 16),
scale_dtype=jnp.float8_e4m3fn,
data_layout="NT",
),
ScalingMode.NVFP4_2D_SCALING: BlockScalingModeMetadataImpl(
block_dims=(16, 16), scale_dtype=jnp.float8_e4m3fn, data_layout="NT"
),
} }
...@@ -201,13 +201,32 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor): ...@@ -201,13 +201,32 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
else: else:
unpadded_scale_shape = self.scaling_mode.get_scale_shape( unpadded_scale_shape = self.scaling_mode.get_scale_shape(
self.data.shape, self.data.shape,
data_layout=self.data_layout,
is_colwise=self.is_colwise, is_colwise=self.is_colwise,
is_padded=False, is_padded=False,
flatten_axis=self.flatten_axis, # expect the flatten_axis wrt the N layout
flatten_axis=(
self.flatten_axis
if self.data_layout == "N"
else self.data.ndim - self.flatten_axis
),
) )
assert self.scale_inv.shape == unpadded_scale_shape, ( unpadded_scale_shape_broadcast = self.scaling_mode.get_scale_shape(
"Unpadded inverse scale factor has wrong shape, expected" self.data.shape,
f" {unpadded_scale_shape} but got {self.scale_inv.shape}." data_layout=self.data_layout,
is_colwise=self.is_colwise,
is_padded=False,
# expect the flatten_axis wrt the N layout
flatten_axis=(
self.flatten_axis
if self.data_layout == "N"
else self.data.ndim - self.flatten_axis
),
broadcast_2d_scale_shape_to_1d=True,
)
assert self.scale_inv.shape in (unpadded_scale_shape, unpadded_scale_shape_broadcast), (
f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} or"
f" {unpadded_scale_shape_broadcast} but got {self.scale_inv.shape}."
) )
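# Worked example (assumed numbers) for the check above: a rowwise NVFP4_2D tensor with
# data shape (128, 64) and (16, 16) blocks may legitimately carry scale_inv either in
# the exact 2D shape (8, 4) or in the broadcast-to-1D shape (128, 4), so the assert
# accepts both.
unpadded_scale_shape = (8, 4)
unpadded_scale_shape_broadcast = (128, 4)
accepted_shapes = (unpadded_scale_shape, unpadded_scale_shape_broadcast)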
def tree_flatten(self): def tree_flatten(self):
...@@ -583,6 +602,7 @@ class ScaledTensorFactory: ...@@ -583,6 +602,7 @@ class ScaledTensorFactory:
colwise_data, colwise_data,
colwise_scale_inv, colwise_scale_inv,
amax=None, amax=None,
colwise_amax=None,
scaling_mode=ScalingMode.NO_SCALING, scaling_mode=ScalingMode.NO_SCALING,
dq_dtype=jnp.bfloat16, dq_dtype=jnp.bfloat16,
data_layout="NN", data_layout="NN",
...@@ -612,6 +632,8 @@ class ScaledTensorFactory: ...@@ -612,6 +632,8 @@ class ScaledTensorFactory:
""" """
if amax is None: if amax is None:
amax = jnp.empty((1,), dtype=jnp.float32) amax = jnp.empty((1,), dtype=jnp.float32)
if colwise_amax is None:
colwise_amax = amax
assert len(data_layout) == 2, f"Expect 2 layouts, got {data_layout}" assert len(data_layout) == 2, f"Expect 2 layouts, got {data_layout}"
rowwise_tensor = ScaledTensorFactory.create_1x( rowwise_tensor = ScaledTensorFactory.create_1x(
...@@ -630,10 +652,10 @@ class ScaledTensorFactory: ...@@ -630,10 +652,10 @@ class ScaledTensorFactory:
colwise_tensor = ScaledTensorFactory.create_1x( colwise_tensor = ScaledTensorFactory.create_1x(
colwise_data, colwise_data,
colwise_scale_inv, colwise_scale_inv,
amax, colwise_amax,
scaling_mode, scaling_mode,
dq_dtype, dq_dtype,
is_colwise=True, is_colwise=True, # TODO(Phuong): set this correctly
data_layout=data_layout[1], data_layout=data_layout[1],
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
group_sizes=group_sizes, group_sizes=group_sizes,
...@@ -649,6 +671,7 @@ class ScaledTensorFactory: ...@@ -649,6 +671,7 @@ class ScaledTensorFactory:
colwise_data: jnp.ndarray, colwise_data: jnp.ndarray,
colwise_scale_inv: jnp.ndarray, colwise_scale_inv: jnp.ndarray,
amax=None, amax=None,
colwise_amax=None,
scaling_mode: ScalingMode = ScalingMode.NO_SCALING, scaling_mode: ScalingMode = ScalingMode.NO_SCALING,
dq_dtype: jnp.dtype = jnp.bfloat16, dq_dtype: jnp.dtype = jnp.bfloat16,
data_layout: str = "NN", data_layout: str = "NN",
...@@ -684,6 +707,7 @@ class ScaledTensorFactory: ...@@ -684,6 +707,7 @@ class ScaledTensorFactory:
colwise_data, colwise_data,
colwise_scale_inv, colwise_scale_inv,
amax, amax,
colwise_amax,
scaling_mode, scaling_mode,
dq_dtype, dq_dtype,
data_layout=data_layout, data_layout=data_layout,
...@@ -698,7 +722,7 @@ class ScaledTensorFactory: ...@@ -698,7 +722,7 @@ class ScaledTensorFactory:
return ScaledTensorFactory.create_1x( return ScaledTensorFactory.create_1x(
colwise_data, colwise_data,
colwise_scale_inv, colwise_scale_inv,
amax, colwise_amax if colwise_amax is not None else amax,
scaling_mode, scaling_mode,
dq_dtype, dq_dtype,
is_colwise=is_colwise, is_colwise=is_colwise,
......
...@@ -44,7 +44,7 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_ ...@@ -44,7 +44,7 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_
from build_tools.build_ext import get_build_ext from build_tools.build_ext import get_build_ext
from build_tools.utils import copy_common_headers from build_tools.utils import copy_common_headers, min_python_version_str
from build_tools.te_version import te_version from build_tools.te_version import te_version
from build_tools.jax import setup_jax_extension, install_requirements, test_requirements from build_tools.jax import setup_jax_extension, install_requirements, test_requirements
...@@ -100,6 +100,7 @@ if __name__ == "__main__": ...@@ -100,6 +100,7 @@ if __name__ == "__main__":
description="Transformer acceleration library - Jax Lib", description="Transformer acceleration library - Jax Lib",
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension}, cmdclass={"build_ext": CMakeBuildExtension},
python_requires=f">={min_python_version_str()}",
install_requires=install_requirements(), install_requires=install_requirements(),
tests_require=test_requirements(), tests_require=test_requirements(),
) )
......