Commit a207db1d authored by yuguo
parents fbee8990 69365f88
...@@ -81,5 +81,30 @@ inline size_t product(const xla::ffi::Span<const int64_t>& data, size_t start_id
std::multiplies<size_t>());
}
inline static size_t te_dtype_bytes(const DType& type) {
switch (type) {
case DType::kByte:
return 1;
case DType::kInt32:
return 4;
case DType::kInt64:
return 8;
case DType::kFloat32:
return 4;
case DType::kFloat16:
return 2;
case DType::kBFloat16:
return 2;
case DType::kFloat8E5M2:
return 1;
case DType::kFloat8E4M3:
return 1;
case DType::kFloat8E8M0:
return 1;
default:
NVTE_ERROR("Unsupported DType: ", static_cast<int>(type));
}
}
} // namespace jax
} // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "transformer_engine/gemm.h"
#include <memory>
#include "common/util/cuda_runtime.h"
#include "common/util/system.h"
#include "extensions.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
namespace jax {
constexpr static size_t MXFP8_BLOCK_SIZE = 32;
// Note: we only support TN-GEMM for now (TN in cuBLASLt == NT in JAX)
Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lhs_sinv_ptr,
const DType &lhs_sinv_dtype, uint8_t *rhs_ptr, const DType &rhs_dtype,
uint8_t *rhs_sinv_ptr, const DType &rhs_sinv_dtype, uint8_t *bias_ptr,
const DType &bias_dtype, uint8_t *out_ptr, const DType &out_dtype,
uint8_t *workspace_ptr, const size_t workspace_size, size_t num_gemms,
int32_t *dim_list_ptr, const int64_t &scaling_mode,
cudaStream_t stream) {
size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype);
size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype);
size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype);
size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype);
size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype);
size_t out_dtype_bytes = te_dtype_bytes(out_dtype);
NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)");
NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes,
"sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)");
size_t dim_list_bytes = sizeof(int32_t) * 3 * num_gemms;
std::unique_ptr<int32_t[]> dim_list_host = std::make_unique<int32_t[]>(3 * num_gemms);
cudaMemcpyAsync(dim_list_host.get(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost,
stream);
// Note: This may break cudaGraph.
cudaStreamSynchronize(stream);
// Notes on matrix layouts and transposes:
// JAX uses row-major layout. On entering this function, each input matrix pair is:
//   A: row-major with size [m, k],
//   B: row-major with size [n, k], needs transpose.
// On exiting this function, JAX expects:
//   C: row-major with size [m, n].
// cuBLAS uses column-major layout. In this view, each input matrix pair is:
//   A: column-major with size [k, m], needs transpose,
//   B: column-major with size [k, n].
// If we call cuBLAS GEMM for A * B, the output will be:
//   C: column-major with size [m, n] --> row-major with size [n, m].
// To make the output compatible with JAX, we swap A and B in the cuBLAS GEMM call.
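// Equivalently: the row-major C[m, n] that JAX wants is, in cuBLAS's column-major
// view, C^T of size [n, m], and C^T = (A * B^T)^T = B * A^T; computing the product
// with the operands swapped therefore writes the output buffer in exactly the
// layout JAX expects.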
bool trans_lhs = true;
bool trans_rhs = false;
auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
bool grad = false;
bool accumulate = false;
bool use_split_accumulator = false;
// These lists are to keep the TensorWrapper objects alive
std::vector<TensorWrapper> lhs_wrapper_list;
std::vector<TensorWrapper> rhs_wrapper_list;
std::vector<TensorWrapper> bias_wrapper_list;
std::vector<TensorWrapper> pre_gelu_wrapper_list;
std::vector<TensorWrapper> out_wrapper_list;
std::vector<TensorWrapper> workspace_wrapper_list;
// These lists are the actual NVTETensor (void *) lists for multi-stream GEMM
std::vector<NVTETensor> lhs_list;
std::vector<NVTETensor> rhs_list;
std::vector<NVTETensor> bias_list;
std::vector<NVTETensor> pre_gelu_list;
std::vector<NVTETensor> out_list;
std::vector<NVTETensor> workspace_list;
for (int i = 0; i < num_gemms; i++) {
size_t m = dim_list_host[i * 3];
size_t n = dim_list_host[i * 3 + 1];
size_t k = dim_list_host[i * 3 + 2];
auto lhs_shape = std::vector<size_t>{m, k};
auto rhs_shape = std::vector<size_t>{n, k};
auto out_shape = std::vector<size_t>{n, m};
auto lhs_sinv_shape = std::vector<size_t>{1, 1};
auto rhs_sinv_shape = std::vector<size_t>{1, 1};
if (scaling_mode == NVTE_NO_SCALING || scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
auto lhs_i = TensorWrapper(static_cast<void *>(lhs_ptr), lhs_shape, lhs_dtype, nullptr,
nullptr, reinterpret_cast<float *>(lhs_sinv_ptr));
auto rhs_i = TensorWrapper(static_cast<void *>(rhs_ptr), rhs_shape, rhs_dtype, nullptr,
nullptr, reinterpret_cast<float *>(rhs_sinv_ptr));
lhs_wrapper_list.push_back(std::move(lhs_i));
rhs_wrapper_list.push_back(std::move(rhs_i));
} else if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim must be divisible by %d (got %d)",
MXFP8_BLOCK_SIZE, k);
size_t sinv_k = k / MXFP8_BLOCK_SIZE;
lhs_sinv_shape[0] = m;
lhs_sinv_shape[1] = sinv_k;
rhs_sinv_shape[0] = n;
rhs_sinv_shape[1] = sinv_k;
// Note: the scale_inv array should have been swizzled in Python before lowering
TensorWrapper lhs_i(NVTE_MXFP8_1D_SCALING);
TensorWrapper rhs_i(NVTE_MXFP8_1D_SCALING);
lhs_i.set_rowwise_data(static_cast<void *>(lhs_ptr), lhs_dtype, lhs_shape);
rhs_i.set_rowwise_data(static_cast<void *>(rhs_ptr), rhs_dtype, rhs_shape);
lhs_i.set_rowwise_scale_inv(static_cast<void *>(lhs_sinv_ptr), DType::kFloat8E8M0,
lhs_sinv_shape);
rhs_i.set_rowwise_scale_inv(static_cast<void *>(rhs_sinv_ptr), DType::kFloat8E8M0,
rhs_sinv_shape);
lhs_wrapper_list.push_back(std::move(lhs_i));
rhs_wrapper_list.push_back(std::move(rhs_i));
} else {
NVTE_ERROR("Unsupported scaling mode: ", scaling_mode);
}
auto out_i = TensorWrapper(static_cast<void *>(out_ptr), out_shape, out_dtype);
lhs_ptr += m * k * lhs_dtype_bytes;
rhs_ptr += n * k * rhs_dtype_bytes;
out_ptr += m * n * out_dtype_bytes;
lhs_sinv_ptr += lhs_sinv_shape[0] * lhs_sinv_shape[1] * lhs_sinv_dtype_bytes;
rhs_sinv_ptr += rhs_sinv_shape[0] * rhs_sinv_shape[1] * rhs_sinv_dtype_bytes;
void *pre_gelu_ptr = nullptr;
auto bias_shape = std::vector<size_t>{0};
auto pre_gelu_shape = std::vector<size_t>{0};
if (bias_ptr != nullptr) bias_shape[0] = n;
auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype);
if (bias_ptr != nullptr) bias_ptr += n * bias_dtype_bytes;
auto pre_gelu_i = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, out_dtype);
out_wrapper_list.push_back(std::move(out_i));
bias_wrapper_list.push_back(std::move(bias_i));
pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i));
lhs_list.push_back(lhs_wrapper_list.back().data());
rhs_list.push_back(rhs_wrapper_list.back().data());
bias_list.push_back(bias_wrapper_list.back().data());
pre_gelu_list.push_back(pre_gelu_wrapper_list.back().data());
out_list.push_back(out_wrapper_list.back().data());
}
auto workspace_shape = std::vector<size_t>{workspace_size};
for (int i = 0; i < num_streams; i++) {
auto workspace_i =
TensorWrapper(static_cast<void *>(workspace_ptr), workspace_shape, DType::kByte);
workspace_wrapper_list.push_back(std::move(workspace_i));
workspace_list.push_back(workspace_wrapper_list.back().data());
workspace_ptr += workspace_size;
}
nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(),
pre_gelu_list.data(), num_gemms, trans_lhs, trans_rhs, grad,
workspace_list.data(), accumulate, use_split_accumulator,
num_math_sm, stream);
return ffi_with_cuda_error_check();
}
Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_flatten,
Buffer_Type lhs_sinv_flatten, Buffer_Type rhs_flatten,
Buffer_Type rhs_sinv_flatten, Buffer_Type bias_flatten,
Buffer_Type dim_list, Result_Type out_flatten,
Result_Type workspace_flatten, int64_t num_gemms, int64_t scaling_mode) {
// Inputs
auto lhs_ptr = reinterpret_cast<uint8_t *>(lhs_flatten.untyped_data());
auto rhs_ptr = reinterpret_cast<uint8_t *>(rhs_flatten.untyped_data());
auto lhs_sinv_ptr = reinterpret_cast<uint8_t *>(lhs_sinv_flatten.untyped_data());
auto rhs_sinv_ptr = reinterpret_cast<uint8_t *>(rhs_sinv_flatten.untyped_data());
auto bias_ptr = reinterpret_cast<uint8_t *>(bias_flatten.untyped_data());
auto dim_list_ptr = reinterpret_cast<int32_t *>(dim_list.untyped_data());
auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_flatten.element_type());
auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_flatten.element_type());
auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv_flatten.element_type());
auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv_flatten.element_type());
auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias_flatten.element_type());
// Outputs
auto out_ptr = reinterpret_cast<uint8_t *>(out_flatten->untyped_data());
auto out_dtype = convert_ffi_datatype_to_te_dtype(out_flatten->element_type());
auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace_flatten->untyped_data());
auto workspace_size = workspace_flatten->dimensions().back() / num_streams;
return GroupedGemmImpl(lhs_ptr, lhs_dtype, lhs_sinv_ptr, lhs_sinv_dtype, rhs_ptr, rhs_dtype,
rhs_sinv_ptr, rhs_sinv_dtype, bias_ptr, bias_dtype, out_ptr, out_dtype,
workspace_ptr, workspace_size, num_gemms, dim_list_ptr, scaling_mode,
stream);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // lhs_flatten
.Arg<Buffer_Type>() // lhs_sinv_flatten
.Arg<Buffer_Type>() // rhs_flatten
.Arg<Buffer_Type>() // rhs_sinv_flatten
.Arg<Buffer_Type>() // bias_flatten
.Arg<Buffer_Type>() // dim_list
.Ret<Buffer_Type>() // out_flatten
.Ret<Buffer_Type>() // workspace_flatten
.Attr<int64_t>("num_gemms")
.Attr<int64_t>("scaling_mode"),
FFI_CudaGraph_Traits);
} // namespace jax
} // namespace transformer_engine
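For reference, a minimal sketch (struct and helper names are hypothetical, not part of this commit) of the host-side packing that GroupedGemmImpl assumes: dim_list carries three int32 values (m, n, k) per GEMM, and the lhs/rhs/out buffers are flat concatenations of the per-GEMM matrices in the same order.

#include <cstdint>
#include <vector>

// Hypothetical illustration of the layout consumed above: dim_list is
// [m0, n0, k0, m1, n1, k1, ...] and each operand buffer packs the per-GEMM
// matrices back to back (lhs_i is m_i x k_i, rhs_i is n_i x k_i, out_i is m_i x n_i).
struct GroupedGemmDims {
  std::vector<int32_t> dim_list;  // 3 * num_gemms entries
  size_t lhs_elems = 0, rhs_elems = 0, out_elems = 0;

  void add_gemm(int32_t m, int32_t n, int32_t k) {
    dim_list.insert(dim_list.end(), {m, n, k});
    lhs_elems += static_cast<size_t>(m) * k;
    rhs_elems += static_cast<size_t>(n) * k;
    out_elems += static_cast<size_t>(m) * n;
  }
};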
...@@ -34,5 +34,11 @@ inline size_t product(const std::vector<size_t> &shape) {
return ret;
}
enum class QuantizeAxis {
ROWWISE,
COLWISE,
ROWWISE_COLWISE,
};
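// A minimal sketch (helper names hypothetical, not part of this header) of how a
// QuantizeAxis value could be interpreted when deciding which outputs to produce:
inline bool quantize_axis_needs_rowwise(QuantizeAxis axis) {
  return axis == QuantizeAxis::ROWWISE || axis == QuantizeAxis::ROWWISE_COLWISE;
}
inline bool quantize_axis_needs_colwise(QuantizeAxis axis) {
  return axis == QuantizeAxis::COLWISE || axis == QuantizeAxis::ROWWISE_COLWISE;
}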
} // namespace jax
} // namespace transformer_engine
...@@ -5,15 +5,18 @@
************************************************************************/
#include "transformer_engine/normalization.h"
#include <cuda_runtime.h>
#include "extensions.h" #include "extensions.h"
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
DType w_dtype, DType out_dtype,
NVTE_Norm_Type norm_type, int scaling_mode,
bool zero_centered_gamma, float epsilon, int sm_margin,
bool is_training) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto weight_shape = std::vector<size_t>{hidden_size};
auto intermediates_shape = std::vector<size_t>{batch_size};
...@@ -21,23 +24,32 @@ pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidd
// empty tensor wrappers are okay just to get workspace size
auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
auto gamma_tensor = TensorWrapper(nullptr, weight_shape, in_dtype);
auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32);
auto _scaling_mode = static_cast<NVTEScalingMode>(scaling_mode);
auto output_tensor = TensorWrapper(_scaling_mode);
output_tensor.set_rowwise_data(nullptr, out_dtype, input_shape);
// WAR: NVTE Norms infer is_training from whether columnwise data is allocated
if (is_training && _scaling_mode == NVTE_MXFP8_1D_SCALING) {
int temp = 1;
output_tensor.set_columnwise_data(static_cast<void *>(&temp), out_dtype, input_shape);
}
// dummy tensor wrappers that will carry workspace size info later
TensorWrapper dummy_work_tensor;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin;
if (norm_type == NVTE_Norm_Type::LayerNorm) {
auto beta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32);
nvte_layernorm_fwd(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), epsilon,
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(),
dummy_work_tensor.data(), num_sm, zero_centered_gamma, nullptr);
} else {
NVTE_CHECK(scaling_mode != NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING || !zero_centered_gamma,
"rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), epsilon, output_tensor.data(),
rsigma_tensor.data(), dummy_work_tensor.data(), num_sm, zero_centered_gamma,
nullptr);
}
...@@ -46,232 +58,125 @@ pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidd
return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()));
}
Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf,
Buffer_Type gamma_buf, Buffer_Type beta_buf, Result_Type output_buf,
Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type amax_buf,
Result_Type mu_buf, Result_Type rsigma_buf, Result_Type wkspace_buf,
int norm_type, bool zero_centered_gamma, double epsilon,
int64_t sm_margin, int scaling_mode, bool is_2x) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type());
auto wkspace_dtype = convert_ffi_datatype_to_te_dtype(wkspace_buf->element_type());
auto *input = x_buf.untyped_data();
auto *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
auto *gamma = gamma_buf.untyped_data();
auto *beta = beta_buf.untyped_data();
auto *output = output_buf->untyped_data();
auto *rsigma = rsigma_buf->untyped_data();
auto *mu = mu_buf->untyped_data();
auto *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
auto *workspace = wkspace_buf->untyped_data();
auto _scaling_mode = static_cast<NVTEScalingMode>(scaling_mode);
auto _norm_type = static_cast<NVTE_Norm_Type>(norm_type);
auto _is_2x = static_cast<bool>(is_2x);
auto x_size = product(x_buf.dimensions());
auto gamma_size = product(gamma_buf.dimensions());
auto workspace_size = product(wkspace_buf->dimensions());
auto hidden_size = gamma_size;
auto batch_size = x_size / gamma_size;
float _epsilon = static_cast<float>(epsilon);
int _sm_margin = static_cast<int>(sm_margin);
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto gamma_shape = std::vector<size_t>{hidden_size};
auto intermediates_shape = std::vector<size_t>{batch_size};
auto workspace_shape = std::vector<size_t>{workspace_size};
auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto gamma_tensor = TensorWrapper(gamma, gamma_shape, in_dtype);
auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, DType::kFloat32);
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - _sm_margin;
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype);
auto output_tensor = TensorWrapper(_scaling_mode);
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), input_shape);
if (is_fp8_dtype(out_dtype)) {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{
product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1),
scale_inv_buf->dimensions().back()});
}
if (_scaling_mode == NVTE_DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) {
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
cudaMemsetAsync(amax, 0, sizeof(float), stream);
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
}
if (_is_2x) {
output_tensor.set_columnwise_data(colwise_output_buf->untyped_data(),
static_cast<DType>(out_dtype), input_shape);
output_tensor.set_columnwise_scale_inv(
colwise_scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0,
colwise_scale_inv_buf->dimensions().size() - 1),
colwise_scale_inv_buf->dimensions().back()});
}
if (_norm_type == NVTE_Norm_Type::LayerNorm) {
auto beta_tensor = TensorWrapper(beta, gamma_shape, w_dtype);
auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32);
nvte_layernorm_fwd(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), _epsilon,
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(),
workspace_tensor.data(), num_sm, zero_centered_gamma, stream);
} else {
NVTE_CHECK(scaling_mode != NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING || !zero_centered_gamma,
"rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), _epsilon, output_tensor.data(),
rsigma_tensor.data(), workspace_tensor.data(), num_sm, zero_centered_gamma,
stream);
}
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // gamma
.Arg<Buffer_Type>() // beta
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise_output
.Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // colwise_scale_inv
.Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // mu
.Ret<Buffer_Type>() // rsigma
.Ret<Buffer_Type>() // wkspace
.Attr<int64_t>("norm_type")
.Attr<bool>("zero_centered_gamma")
.Attr<double>("epsilon")
.Attr<int64_t>("sm_margin")
.Attr<int64_t>("scaling_mode")
.Attr<bool>("is_2x"),
FFI_CudaGraph_Traits);
pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
DType w_dtype, NVTE_Norm_Type norm_type,
bool zero_centered_gamma, int sm_margin) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto weight_shape = std::vector<size_t>{hidden_size};
auto intermediates_shape = std::vector<size_t>{batch_size};
...@@ -289,7 +194,7 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hid
TensorWrapper dummy_work_tensor;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin;
if (norm_type == NVTE_Norm_Type::LayerNorm) {
auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, intermediates_dtype);
auto dbeta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
...@@ -309,16 +214,37 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hid
return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()));
}
Error_Type NormBackwardFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf,
Buffer_Type mu_buf, Buffer_Type rsigma_buf, Buffer_Type gamma_buf,
Result_Type xgrad_buf, Result_Type wgrad_buf, Result_Type dbeta_buf,
Result_Type wkspace_buf, int64_t norm_type, bool zero_centered_gamma,
int64_t sm_margin) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type());
auto wkspace_dtype = convert_ffi_datatype_to_te_dtype(wkspace_buf->element_type());
auto *ograd = dz_buf.untyped_data();
auto *input = x_buf.untyped_data();
void *mu = mu_buf.untyped_data();
auto *rsigma = rsigma_buf.untyped_data();
auto *gamma = gamma_buf.untyped_data();
auto *xgrad = xgrad_buf->untyped_data();
auto *wgrad = wgrad_buf->untyped_data();
void *dbeta = dbeta_buf->untyped_data();
auto *workspace = wkspace_buf->untyped_data();
auto x_size = product(x_buf.dimensions());
auto gamma_size = product(gamma_buf.dimensions());
auto wkspace_size = product(wkspace_buf->dimensions());
auto hidden_size = gamma_size;
auto batch_size = x_size / gamma_size;
int _sm_margin = static_cast<int>(sm_margin);
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto weight_shape = std::vector<size_t>{hidden_size};
auto intermediates_shape = std::vector<size_t>{batch_size};
auto intermediates_dtype = DType::kFloat32;
// assume input type = output type
auto *grad_output = ograd;
...@@ -327,19 +253,18 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace
auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, intermediates_dtype);
auto x_tensor = TensorWrapper(input, input_shape, x_dtype);
auto gamma_tensor = TensorWrapper(gamma, weight_shape, w_dtype);
auto xgrad_tensor = TensorWrapper(xgrad, input_shape, x_dtype);
auto wgrad_tensor = TensorWrapper(wgrad, weight_shape, w_dtype);
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - _sm_margin;
auto workspace_shape = std::vector<size_t>{wkspace_size};
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype);
if (static_cast<NVTE_Norm_Type>(norm_type) == NVTE_Norm_Type::LayerNorm) {
auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype);
auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype);
...@@ -353,61 +278,11 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace
xgrad_tensor.data(), wgrad_tensor.data(), workspace_tensor.data(), num_sm,
zero_centered_gamma, stream);
}
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(NormBackwardHandler, NormBackwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // dz
...@@ -419,220 +294,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormBackwardHandler, LayerNormBackwardFFI,
.Ret<Buffer_Type>() // wgrad
.Ret<Buffer_Type>() // dbeta
.Ret<Buffer_Type>() // wkspace
.Attr<int64_t>("norm_type")
.Attr<bool>("zero_centered_gamma")
.Attr<int64_t>("sm_margin"),
FFI_CudaGraph_Traits);
} // namespace jax
} // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
namespace transformer_engine {
namespace jax {
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
DType out_dtype, size_t act_enum) {
CustomCallCommonDescriptor desc{};
desc.shape.from_vector(shape);
desc.in_dtype = in_dtype;
desc.out_dtype = out_dtype;
desc.act_enum = act_enum;
return PackOpaque(desc);
}
pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
const std::vector<size_t> &wkshape, DType in_dtype,
DType out_dtype, DType wk_dtype, size_t act_enum) {
CustomCallCommonWkDescriptor desc{};
desc.shape.from_vector(shape);
desc.wkshape.from_vector(wkshape);
desc.in_dtype = in_dtype;
desc.out_dtype = out_dtype;
desc.wk_dtype = wk_dtype;
desc.act_enum = act_enum;
return PackOpaque(desc);
}
pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
size_t wkspace_size, DType x_dtype, DType w_dtype,
DType wkspace_dtype, bool zero_centered_gamma,
float eps, int sm_margin) {
CustomCallNormDescriptor desc{};
desc.batch_size = batch_size;
desc.hidden_size = hidden_size;
desc.wkspace_size = wkspace_size;
desc.x_dtype = x_dtype;
desc.w_dtype = w_dtype;
desc.wkspace_dtype = wkspace_dtype;
desc.zero_centered_gamma = zero_centered_gamma;
desc.eps = eps;
desc.sm_margin = sm_margin;
return PackOpaque(desc);
}
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
size_t head_dim, size_t q_seqlen, size_t k_seqlen,
DType dtype, float scale_factor) {
return PackOpaque(SoftmaxDescriptor{batch_size, padding_size, head_dim, q_seqlen, k_seqlen, dtype,
scale_factor});
}
pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
bool deterministic, int64_t window_size_left, int64_t window_size_right) {
return PackOpaque(
CustomCallFusedAttnDescriptor{input_batch, bias_batch, q_max_seqlen,
kv_max_seqlen, attn_heads, num_gqa_groups,
bias_heads, head_dim, max_segments_per_seq,
wkspace_size, scaling_factor, dropout_probability,
bias_type, mask_type, qkv_layout,
dtype, wkspace_dtype, is_training,
deterministic, window_size_left, window_size_right});
}
} // namespace jax
} // namespace transformer_engine
...@@ -9,11 +9,6 @@
namespace transformer_engine {
namespace jax {
template <typename T>
pybind11::capsule EncapsulateFFI(T *fn) {
static_assert(std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>,
...@@ -23,49 +18,13 @@ pybind11::capsule EncapsulateFFI(T *fn) {
pybind11::dict Registrations() {
pybind11::dict dict;
dict["te_transpose"] = EncapsulateFunction(Transpose);
dict["te_cast_transpose"] = EncapsulateFunction(CastTranspose);
dict["te_act_lu"] = EncapsulateFunction(ActLu);
dict["te_act_lu_fp8"] = EncapsulateFunction(ActLuFP8);
dict["te_dact_lu"] = EncapsulateFunction(DActLu);
dict["te_dbias_cast_transpose"] = EncapsulateFunction(DBiasCastTranspose);
dict["te_dact_lu_dbias_cast_transpose"] = EncapsulateFunction(DActLuDBiasCastTranspose);
dict["te_dgated_act_lu_cast_transpose"] = EncapsulateFunction(DGatedActLuCastTranspose);
dict["te_layernorm_forward"] = EncapsulateFunction(LayerNormForward);
dict["te_layernorm_forward_fp8"] = EncapsulateFunction(LayerNormForwardFP8);
dict["te_layernorm_backward"] = EncapsulateFunction(LayerNormBackward);
dict["te_rmsnorm_forward"] = EncapsulateFunction(RMSNormForward);
dict["te_rmsnorm_forward_fp8"] = EncapsulateFunction(RMSNormForwardFP8);
dict["te_rmsnorm_backward"] = EncapsulateFunction(RMSNormBackward);
dict["te_quantize"] = EncapsulateFunction(Quantize);
dict["te_dequantize"] = EncapsulateFunction(Dequantize);
dict["te_scaled_softmax_forward"] = EncapsulateFunction(ScaledSoftmaxForward);
dict["te_scaled_softmax_backward"] = EncapsulateFunction(ScaledSoftmaxBackward);
dict["te_scaled_masked_softmax_forward"] = EncapsulateFunction(ScaledMaskedSoftmaxForward);
dict["te_scaled_masked_softmax_backward"] = EncapsulateFunction(ScaledMaskedSoftmaxBackward);
dict["te_scaled_upper_triang_masked_softmax_forward"] =
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForward);
dict["te_scaled_upper_triang_masked_softmax_backward"] =
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward);
dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward);
dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward);
// Transpose
dict["te_transpose_ffi"] = EncapsulateFFI(TransposeHandler);
dict["te_cast_transpose_ffi"] = EncapsulateFFI(CastTransposeHandler);
dict["te_dbias_cast_transpose_ffi"] = EncapsulateFFI(DBiasCastTransposeHandler);
// Activation
dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler);
dict["te_dact_dbias_quantize_ffi"] = EncapsulateFFI(DActLuDBiasQuantizeHandler);
// Quantization
dict["te_dbias_quantize_ffi"] = EncapsulateFFI(DBiasQuantizeHandler);
dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler);
// Softmax
...@@ -80,58 +39,40 @@ pybind11::dict Registrations() {
EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
// Normalization
dict["te_norm_forward_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(NormForwardHandler));
dict["te_norm_backward_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(NormBackwardHandler));
// Attention
dict["te_fused_attn_forward_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(FusedAttnForwardHandler));
dict["te_fused_attn_backward_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(FusedAttnBackwardHandler));
// Grouped GEMM
dict["te_grouped_gemm_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler));
return dict;
}
PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("registrations", &Registrations);
m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor, pybind11::arg(), pybind11::arg(),
pybind11::arg(), pybind11::arg("act_num") = 0);
m.def("pack_common_wk_descriptor", &PackCustomCallCommonWkDescriptor, pybind11::arg(),
pybind11::arg(), pybind11::arg(), pybind11::arg(), pybind11::arg(),
pybind11::arg("act_num") = 0);
m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor);
m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor);
m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor);
m.def("get_fused_attn_backend", &GetFusedAttnBackend); m.def("get_fused_attn_backend", &GetFusedAttnBackend);
m.def("get_cuda_version", &GetCudaRuntimeVersion); m.def("get_cuda_version", &GetCudaRuntimeVersion);
m.def("get_cudnn_version", &GetCudnnRuntimeVersion); m.def("get_cudnn_version", &GetCudnnRuntimeVersion);
m.def("get_device_compute_capability", &GetDeviceComputeCapability); m.def("get_device_compute_capability", &GetDeviceComputeCapability);
m.def("get_cublasLt_version", &cublasLtGetVersion); m.def("get_cublasLt_version", &cublasLtGetVersion);
m.def("get_dact_dbias_ct_workspace_sizes", &GetDActDBiasCastTransposeWorkspaceSizes); m.def("get_dact_dbias_quantize_workspace_sizes", &GetDActDBiasQuantizeWorkspaceSizes);
m.def("get_dbias_ct_workspace_sizes", &GetDBiasCastTransposeWorkspaceSizes); m.def("get_dbias_quantize_workspace_sizes", &GetDBiasQuantizeWorkspaceSizes);
m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes); m.def("get_norm_fwd_workspace_sizes", &GetNormForwardWorkspaceSizes);
m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes); m.def("get_norm_bwd_workspace_sizes", &GetNormBackwardWorkspaceSizes);
m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes);
m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
m.def("nvte_get_qkv_format", &nvte_get_qkv_format); m.def("nvte_get_qkv_format", &nvte_get_qkv_format);
...@@ -191,6 +132,24 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)
.value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)
.value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8);
pybind11::enum_<NVTE_Norm_Type>(m, "NVTE_Norm_Type", pybind11::module_local())
.value("LayerNorm", NVTE_Norm_Type::LayerNorm)
.value("RMSNorm", NVTE_Norm_Type::RMSNorm)
.export_values();
pybind11::enum_<NVTEScalingMode>(m, "NVTE_Scaling_Mode", pybind11::module_local())
.value("NVTE_DELAYED_TENSOR_SCALING", NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING)
.value("NVTE_MXFP8_1D_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING)
.value("NVTE_INVALID_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING)
.export_values();
pybind11::enum_<transformer_engine::jax::QuantizeAxis>(m, "QuantizeAxis",
pybind11::module_local())
.value("ROWWISE", transformer_engine::jax::QuantizeAxis::ROWWISE)
.value("COLWISE", transformer_engine::jax::QuantizeAxis::COLWISE)
.value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeAxis::ROWWISE_COLWISE)
.export_values();
}

}  // namespace jax
...
...@@ -3,6 +3,7 @@
 *
 * See LICENSE for license information.
 ************************************************************************/
#include <cuda_runtime.h>
#include "extensions.h" #include "extensions.h"
#include "transformer_engine/cast.h" #include "transformer_engine/cast.h"
...@@ -11,74 +12,131 @@ ...@@ -11,74 +12,131 @@
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                               DType in_dtype, DType out_dtype) {
  auto input_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
  auto dbias_shape = std::vector<size_t>{hidden_size};

  // Evil hack to specify TE impl
  // Note: nvte_quantize_dbias chooses its internal impl based on what
  // pointers are allocated, e.g. whether to output with column-wise
  // data. However, we don't have access to any allocated buffers in
  // this function. We pass a dummy pointer as a workaround.
  int temp = 0;

  auto input_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), input_shape, in_dtype);
  auto output_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), output_shape, out_dtype);
  output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_trans_shape);
  auto dbias_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), dbias_shape, in_dtype);

  TensorWrapper dummy_workspace;

  nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(),
                      dummy_workspace.data(), nullptr);

  auto work_shape = MakeShapeVector(dummy_workspace.shape());
  return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
}
Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
                            Result_Type output_buf, Result_Type output_trans_buf,
                            Result_Type scale_inv_buf, Result_Type trans_scale_inv_buf,
                            Result_Type amax_out_buf, Result_Type dbias_buf,
                            Result_Type workspace_buf, int64_t scaling_mode_enum,
                            int64_t quantize_axis_enum, bool is_dbias) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
  auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
  NVTE_CHECK(is_fp8_dtype(out_dtype), "Output datatype must be FP8 for quantization.");

  auto *input = input_buf.untyped_data();
  auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
  auto const quantize_axis = static_cast<QuantizeAxis>(quantize_axis_enum);
  auto *output = output_buf->untyped_data();
  auto *output_trans = output_trans_buf->untyped_data();
  auto *dbias = dbias_buf->untyped_data();
  void *workspace = workspace_buf->untyped_data();

  auto input_dims = input_buf.dimensions();
  auto workspace_dims = workspace_buf->dimensions();
  auto m = product(input_dims, 0, input_dims.size() - 1);
  auto n = input_dims.back();
  auto input_shape = std::vector<size_t>{m, n};
  auto output_shape = std::vector<size_t>{m, n};
  auto output_trans_shape = std::vector<size_t>{n, m};
  auto dbias_shape = std::vector<size_t>{n};
  std::vector<size_t> workspace_shape{workspace_dims.begin(), workspace_dims.end()};

  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  auto output_tensor = TensorWrapper(scaling_mode);
if (quantize_axis == QuantizeAxis::ROWWISE || quantize_axis == QuantizeAxis::ROWWISE_COLWISE) {
output_tensor.set_rowwise_data(output, out_dtype, output_shape);
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{
product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1),
scale_inv_buf->dimensions().back()});
}
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax_out != nullptr, "amax must be provided for delayed tensor scaling");
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
cudaMemsetAsync(amax_out, 0, sizeof(float), stream);
output_tensor.set_amax(amax_out, DType::kFloat32, std::vector<size_t>{1});
}
if (quantize_axis == QuantizeAxis::COLWISE || quantize_axis == QuantizeAxis::ROWWISE_COLWISE) {
output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape);
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &colwise_scale_inv_buf =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : trans_scale_inv_buf;
output_tensor.set_columnwise_scale_inv(
colwise_scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0,
colwise_scale_inv_buf->dimensions().size() - 1),
colwise_scale_inv_buf->dimensions().back()});
}
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);
if (is_dbias) {
nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(),
workspace_tensor.data(), stream);
} else {
nvte_quantize(input_tensor.data(), output_tensor.data(), stream);
}
  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // scale
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
                                  .Ret<Buffer_Type>()      // amax
                                  .Ret<Buffer_Type>()      // dbias
                                  .Ret<Buffer_Type>()      // wkspace
                                  .Attr<int64_t>("scaling_mode")
                                  .Attr<int64_t>("q_axis")
                                  .Attr<bool>("is_dbias"),
                              FFI_CudaGraph_Traits);
void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *amax = reinterpret_cast<float *>(buffers[1]);
auto *scale = reinterpret_cast<float *>(buffers[2]);
auto *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto shape = desc.shape.to_vector();
auto input_tensor = TensorWrapper(input, shape, desc.in_dtype, amax, scale, scale_inv);
auto output_tensor = TensorWrapper(output, shape, desc.out_dtype);
nvte_dequantize(input_tensor.data(), output_tensor.data(), stream);
}
Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
                         Buffer_Type scale_buf, Buffer_Type scale_inv_buf, Result_Type output_buf) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
...
...@@ -12,103 +12,6 @@
namespace transformer_engine {
namespace jax {
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
auto *output = buffers[1];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto shape = std::vector<size_t>{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
auto input_tensor = TensorWrapper(input, shape, dtype);
auto output_tensor = TensorWrapper(output, shape, dtype);
nvte_scaled_softmax_forward(input_tensor.data(), output_tensor.data(), desc.scale_factor, stream);
}
void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *grad_output = buffers[0];
auto *softmax_output = buffers[1];
auto *dgrad = buffers[2];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto shape = std::vector<size_t>{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype);
auto softmax_output_tensor = TensorWrapper(softmax_output, shape, dtype);
auto dgrad_tensor = TensorWrapper(dgrad, shape, dtype);
nvte_scaled_softmax_backward(grad_output_tensor.data(), softmax_output_tensor.data(),
dgrad_tensor.data(), desc.scale_factor, stream);
}
void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
auto *mask = buffers[1];
auto *output = buffers[2];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto io_shape = std::vector<size_t>{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen};
auto mask_shape = std::vector<size_t>{desc.padding_size, 1, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
auto input_tensor = TensorWrapper(input, io_shape, dtype);
  // Mask will be cast to uint8_t
auto mask_tensor = TensorWrapper(mask, mask_shape, DType::kByte);
auto output_tensor = TensorWrapper(output, io_shape, dtype);
nvte_scaled_masked_softmax_forward(input_tensor.data(), mask_tensor.data(), output_tensor.data(),
desc.scale_factor, stream);
}
void ScaledMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
// The backward of ScaledMaskedSoftmax is equivalent to ScaledSoftmax.
ScaledSoftmaxBackward(stream, buffers, opaque, opaque_len);
}
void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
auto *output = buffers[1];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto attn_batch = desc.batch_size * desc.head_dim;
auto shape = std::vector<size_t>{attn_batch, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
auto input_tensor = TensorWrapper(input, shape, dtype);
auto output_tensor = TensorWrapper(output, shape, dtype);
nvte_scaled_upper_triang_masked_softmax_forward(input_tensor.data(), output_tensor.data(),
desc.scale_factor, stream);
}
void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *grad_output = buffers[0];
auto *softmax_output = buffers[1];
auto *dgrad = buffers[2];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto attn_batch = desc.batch_size * desc.head_dim;
auto shape = std::vector<size_t>{attn_batch, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype);
auto softmax_output_tensor = TensorWrapper(softmax_output, shape, dtype);
auto dgrad_tensor = TensorWrapper(dgrad, shape, dtype);
nvte_scaled_upper_triang_masked_softmax_backward(grad_output_tensor.data(),
softmax_output_tensor.data(),
dgrad_tensor.data(), desc.scale_factor, stream);
}
#define SOFTMAX_COMMON_BLOCK(tensor_buf)                                       \
  auto dtype = convert_ffi_datatype_to_te_dtype((tensor_buf).element_type());  \
  auto tensor_dims = (tensor_buf).dimensions();                                \
...
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "transformer_engine/transpose.h"
#include "extensions.h"
#include "transformer_engine/cast.h"
#include "xla/ffi/api/ffi.h"
namespace transformer_engine {
namespace jax {
void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream,
void *output) {
auto input_shape = std::vector<size_t>{rows, cols};
auto output_shape = std::vector<size_t>{cols, rows};
auto input_tensor = TensorWrapper(input, input_shape, dtype);
auto transposed_tensor = TensorWrapper(output, output_shape, dtype);
nvte_transpose(input_tensor.data(), transposed_tensor.data(), stream);
}
void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
void *input = buffers[0];
void *output = buffers[1];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto rows = desc.shape.dims[0];
auto cols = desc.shape.dims[1];
assert(desc.in_dtype == desc.out_dtype);
auto dtype = desc.out_dtype;
TransposeImpl(input, rows, cols, dtype, stream, output);
}
Error_Type TransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf,
int64_t transpose_axis) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
void *input = input_buf.untyped_data();
void *output = output_buf->untyped_data();
auto input_dims = input_buf.dimensions();
if (transpose_axis < 0) transpose_axis += input_dims.size();
auto m = product(input_dims, 0, transpose_axis);
auto n = product(input_dims, transpose_axis, input_dims.size());
auto input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{n, m};
auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto output_tensor = TensorWrapper(output, output_shape, out_dtype);
nvte_transpose(input_tensor.data(), output_tensor.data(), stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(TransposeHandler, TransposeFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Ret<Buffer_Type>() // output
.Attr<int64_t>("transpose_axis"),
FFI_CudaGraph_Traits);
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
float *amax = reinterpret_cast<float *>(buffers[1]);
float *scale = reinterpret_cast<float *>(buffers[2]);
float *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *input_cast = buffers[4];
auto *input_cast_trans = buffers[5];
float *amax_out = reinterpret_cast<float *>(buffers[6]);
NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX CastTranspose primitive.");
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto input_shape = std::vector<size_t>{m, n};
auto input_trans_shape = std::vector<size_t>{n, m};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto output_tensor =
TensorWrapper(input_cast, input_shape, desc.out_dtype, amax_out, scale, scale_inv);
output_tensor.set_columnwise_data(input_cast_trans, desc.out_dtype, input_trans_shape);
output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector<size_t>{1});
nvte_quantize(input_tensor.data(), output_tensor.data(), stream);
}
Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
Buffer_Type scale_buf, Buffer_Type scale_inv_buf,
Result_Type output_buf, Result_Type output_trans_buf,
Result_Type amax_out_buf, int64_t transpose_axis) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto *input = input_buf.untyped_data();
float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
auto *output = output_buf->untyped_data();
auto *output_trans = output_trans_buf->untyped_data();
float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX CastTranspose primitive.");
if (!use_fp8(out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto input_dims = input_buf.dimensions();
if (transpose_axis < 0) transpose_axis += input_dims.size();
auto m = product(input_dims, 0, transpose_axis);
auto n = product(input_dims, transpose_axis, input_dims.size());
auto input_shape = std::vector<size_t>{m, n};
auto output_shape = input_shape;
auto output_trans_shape = std::vector<size_t>{n, m};
auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv);
output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape);
output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector<size_t>{1});
nvte_quantize(input_tensor.data(), output_tensor.data(), stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(CastTransposeHandler, CastTransposeFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // output_trans
.Ret<Buffer_Type>() // amax_out
.Attr<int64_t>("transpose_axis"),
FFI_CudaGraph_Traits);
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
auto dbias_shape = std::vector<size_t>{hidden_size};
// Evil hack to specify TE impl
// Note: nvte_quantize_dbias chooses its internal impl based on what
// pointers are allocated, e.g. whether to output with column-wise
// data. However, we don't have access to any allocated buffers in
// this function. We pass a dummy pointer as a workaround.
int temp = 0;
auto input_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), input_shape, in_dtype);
auto output_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), output_shape, out_dtype);
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_trans_shape);
auto dbias_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), dbias_shape, in_dtype);
TensorWrapper dummy_workspace;
nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(),
dummy_workspace.data(), nullptr);
auto work_shape = MakeShapeVector(dummy_workspace.shape());
return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
}
void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
float *amax = reinterpret_cast<float *>(buffers[1]);
float *scale = reinterpret_cast<float *>(buffers[2]);
float *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
auto *output_trans = buffers[5];
auto *dbias = buffers[6];
float *amax_out = reinterpret_cast<float *>(buffers[7]);
void *workspace_ptr = buffers[8];
const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
NVTE_CHECK(amax == amax_out,
"amax not bound to amax_out in TE/JAX DBiasCastTranspose primitive.");
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m};
auto dbias_shape = std::vector<size_t>{n};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto output_tensor =
TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
output_tensor.set_columnwise_data(output_trans, desc.out_dtype, output_trans_shape);
output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector<size_t>{1});
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype);
auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);
nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(),
workspace.data(), stream);
}
Error_Type DBiasCastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
Buffer_Type scale_buf, Buffer_Type scale_inv_buf,
Result_Type output_buf, Result_Type output_trans_buf,
Result_Type dbias_buf, Result_Type amax_out_buf,
Result_Type workspace_buf, int64_t transpose_axis) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
auto *input = input_buf.untyped_data();
float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
auto *output = output_buf->untyped_data();
auto *output_trans = output_trans_buf->untyped_data();
auto *dbias = dbias_buf->untyped_data();
float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
void *workspace = workspace_buf->untyped_data();
NVTE_CHECK(amax == amax_out,
"amax not bound to amax_out in TE/JAX DBiasCastTranspose primitive.");
if (!use_fp8(out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto input_dims = input_buf.dimensions();
auto workspace_dims = workspace_buf->dimensions();
if (transpose_axis < 0) transpose_axis += input_dims.size();
auto m = product(input_dims, 0, transpose_axis);
auto n = product(input_dims, transpose_axis, input_dims.size());
auto input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m};
auto dbias_shape = std::vector<size_t>{n};
std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end());
auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv);
output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape);
output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector<size_t>{1});
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);
nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(),
workspace_tensor.data(), stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasCastTransposeHandler, DBiasCastTransposeFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // output_trans
.Ret<Buffer_Type>() // dbias
.Ret<Buffer_Type>() // amax_out
.Ret<Buffer_Type>() // workspace
.Attr<int64_t>("transpose_axis"),
FFI_CudaGraph_Traits);
} // namespace jax
} // namespace transformer_engine
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Dense layer transformation operations for Transformer Engine in JAX.
This module provides optimized dense layer transformation operations for transformer
architectures, including support for quantization and automatic differentiation.
It implements matrix multiplication with optional bias addition and supports
customizable contracting dimensions for flexible tensor operations.
"""
from typing import Tuple, Sequence
from functools import partial
import jax
import jax.numpy as jnp
from . import cpp_extensions as tex
from .quantize import QuantizerSet, noop_quantizer_set
def dense(
x: jnp.ndarray,
kernel: jnp.ndarray,
bias: jnp.ndarray = None,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
quantizer_set: QuantizerSet = noop_quantizer_set,
):
"""Perform dense layer transformation with optional quantization.
This function implements matrix multiplication with optional bias addition,
supporting quantization and custom contracting dimensions. It's optimized
for transformer architectures and supports automatic differentiation.
Args:
x: Input tensor
kernel: Weight matrix for the dense layer transformation
bias: Optional bias tensor to add after the transformation
contracting_dims: Tuple of sequences specifying which dimensions to contract
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
Transformed output tensor
"""
# Remove when tex.quantize() can handle quantizer=None
if quantizer_set == noop_quantizer_set:
output = tex.gemm(x, kernel, contracting_dims)
if bias is not None:
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape)
else:
output = _dense(x, kernel, bias, contracting_dims, quantizer_set)
return output
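

# A minimal usage sketch for `dense`, assuming bfloat16 operands and the default
# no-op quantizer set; the shapes and values below are illustrative only and are
# not part of this module's API.
def _dense_usage_sketch():
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (8, 16), jnp.bfloat16)         # [batch, hidden_in]
    kernel = jax.random.normal(key, (16, 32), jnp.bfloat16)   # [hidden_in, hidden_out]
    bias = jnp.zeros((32,), jnp.bfloat16)
    # Contract x's axis 1 against kernel's axis 0 (the default contracting_dims).
    y = dense(x, kernel, bias=bias, contracting_dims=((1,), (0,)))
    return y  # shape (8, 32), same dtype as the inputs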
@partial(jax.custom_vjp, nondiff_argnums=(3,))
def _dense(x, kernel, bias, contracting_dims, quantizer_set):
"""Internal implementation of dense layer transformation with custom VJP.
This function implements the core dense layer transformation logic with support
for custom vector-Jacobian product (VJP) for automatic differentiation.
Args:
x: Input tensor
kernel: Weight matrix
bias: Optional bias tensor
contracting_dims: Contracting dimensions specification
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
Transformed output tensor
"""
output, _ = _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set)
return output
def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set):
"""Forward pass rule for dense layer transformation.
Args:
x: Input tensor
kernel: Weight matrix
bias: Optional bias tensor
contracting_dims: Contracting dimensions specification
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
Tuple of (output, context) for backward pass
"""
x_contracting_dims, k_contracting_dims = contracting_dims
casted_x = tex.quantize(x, quantizer_set.x)
casted_kernel = tex.quantize(kernel, quantizer_set.kernel)
# GEMM NN
output = tex.gemm(
casted_x.get_rowwise_tensor(),
casted_kernel.get_colwise_tensor(),
(x_contracting_dims, k_contracting_dims),
)
use_bias = bias is not None
if use_bias:
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape)
ctx = (
casted_x.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None,
casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None,
x.shape,
kernel.shape,
use_bias,
quantizer_set,
)
return output, ctx
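

# A small sketch of the bias broadcast used in the forward rule above: when the
# output has more leading dims than the bias, the bias is reshaped with leading
# 1s so that jnp broadcasting applies it along the trailing dimension(s). The
# shapes here are assumptions for illustration.
def _bias_broadcast_sketch():
    output = jnp.zeros((4, 8, 32))   # e.g. [batch, seq, features]
    bias = jnp.ones((32,))
    bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
    assert bias_new_shape == (1, 1, 32)
    return output + jnp.reshape(bias, bias_new_shape)  # broadcasts over (4, 8)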
def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argument
"""Backward pass rule for dense layer transformation.
Args:
contracting_dims: Contracting dimensions specification
ctx: Context from forward pass
grad: Gradient from upstream
Returns:
Tuple of gradients with respect to inputs
"""
fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims
(
colwise_casted_x,
rowwise_casted_kernel,
x_shape,
kernel_shape,
use_bias,
quantizer_set,
) = ctx
casted_grad, dbias = tex.quantize_dbias(grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad)
    # GEMM NT
    # grad contracting dims: aligned with kernel's non-contracting dims via the
    # shape difference between grad.ndim and kernel.ndim
    g_contracting_dim = tuple(
        range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
    )
    # kernel non-contracting dims
    k_contracting_dim = tuple(
        dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
    )
    dgrad = tex.gemm(
        casted_grad.get_rowwise_tensor(),
        rowwise_casted_kernel,
        (g_contracting_dim, k_contracting_dim),
    )

    # GEMM TN
    # x non-contracting dims
    g_contracting_dim = x_contracting_dim = tuple(
        range(0, len(x_shape) - len(fwd_x_contracting_dims))
    )
    wgrad = tex.gemm(
        colwise_casted_x, casted_grad.get_colwise_tensor(), (x_contracting_dim, g_contracting_dim)
    )
return dgrad, wgrad, dbias, quantizer_set
_dense.defvjp(_dense_fwd_rule, _dense_bwd_rule)
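

# A worked sketch of how the backward contracting dims above come out for a
# typical case: x of shape (B, S, H), kernel of shape (H, F), and forward
# contracting dims ((2,), (0,)), so grad has shape (B, S, F). All values are
# assumed for illustration.
def _dense_bwd_dims_sketch():
    x_shape, kernel_shape = (2, 3, 4), (4, 5)      # (B, S, H), (H, F)
    fwd_x_contracting_dims, fwd_k_contracting_dims = (2,), (0,)
    grad_ndim = (
        len(x_shape) - len(fwd_x_contracting_dims)
        + len(kernel_shape) - len(fwd_k_contracting_dims)
    )                                               # grad is (B, S, F) -> ndim 3
    # dgrad (GEMM NT): contract grad's trailing dims with kernel's non-contracting dims
    g_dims = tuple(range(grad_ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad_ndim))
    k_dims = tuple(d for d in range(len(kernel_shape)) if d not in fwd_k_contracting_dims)
    assert (g_dims, k_dims) == ((2,), (1,))
    # wgrad (GEMM TN): contract the shared batch/sequence dims of x and grad
    x_dims = tuple(range(0, len(x_shape) - len(fwd_x_contracting_dims)))
    assert x_dims == (0, 1)
    return g_dims, k_dims, x_dims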
def grouped_dense(
x_list,
kernel_list,
bias_list,
contracting_dims_list,
quantizer_set_list=None,
):
"""
Perform grouped_dense layer transformation with optional quantization.
"""
output_list = _grouped_dense(
x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
)
return output_list
@partial(jax.custom_vjp, nondiff_argnums=(3,))
def _grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list):
output_list, _ = _grouped_dense_fwd_rule(
x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
)
return output_list
def _grouped_dense_fwd_rule(
x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
):
use_bias = bias_list is not None
output_list = []
x_rowwise_list = []
x_colwise_list = []
kernel_colwise_list = []
kernel_rowwise_list = []
x_shape_list = []
kernel_shape_list = []
if quantizer_set_list is None:
x_rowwise_list = x_list
x_colwise_list = x_list
kernel_colwise_list = kernel_list
kernel_rowwise_list = kernel_list
x_shape_list = [x.shape for x in x_list]
kernel_shape_list = [kernel.shape for kernel in kernel_list]
else:
for i in range(len(x_list)): # pylint: disable=consider-using-enumerate
q_x = tex.quantize(x_list[i], quantizer_set_list[i].x)
q_kernel = tex.quantize(kernel_list[i], quantizer_set_list[i].kernel)
x_rowwise_list.append(q_x.get_rowwise_tensor())
x_colwise_list.append(q_x.get_colwise_tensor())
kernel_colwise_list.append(q_kernel.get_colwise_tensor())
kernel_rowwise_list.append(q_kernel.get_rowwise_tensor())
x_shape_list.append(x_rowwise_list[-1].data.shape)
kernel_shape_list.append(kernel_rowwise_list[-1].data.shape)
output_list = tex.grouped_gemm(
x_rowwise_list, kernel_colwise_list, contracting_dims_list, bias_list
)
ctx = (
x_colwise_list,
kernel_rowwise_list,
x_shape_list,
kernel_shape_list,
use_bias,
quantizer_set_list,
)
return output_list, ctx
def _grouped_dense_bwd_rule(contracting_dims_list, ctx, grad_list):
(
colwise_x_list,
rowwise_kernel_list,
x_shape_list,
kernel_shape_list,
use_bias,
quantizer_set_list,
) = ctx
group_size = len(grad_list)
dbias_list = []
grad_rowwise_list = []
grad_colwise_list = []
dgrad_contracting_dims_list = []
wgrad_contracting_dims_list = []
for i in range(group_size):
grad = grad_list[i]
x_shape = x_shape_list[i]
kernel_shape = kernel_shape_list[i]
fwd_contracting_dims = contracting_dims_list[i]
if quantizer_set_list is None:
casted_grad = grad
dbias = tex.quantization._jax_dbias(grad)
grad_rowwise_list.append(grad)
grad_colwise_list.append(grad)
else:
quantizer_set = quantizer_set_list[i]
casted_grad, dbias = tex.quantize_dbias(
grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad
)
grad_rowwise_list.append(casted_grad.get_rowwise_tensor())
grad_colwise_list.append(casted_grad.get_colwise_tensor())
dbias_list.append(dbias)
# GEMM NT
fwd_x_contracting_dims, fwd_k_contracting_dims = fwd_contracting_dims
g_contracting_dim = tuple(
range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
)
k_contracting_dim = tuple(
dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
)
dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
dgrad_contracting_dims_list.append(dgrad_contracting_dims)
# GEMM TN
g_contracting_dim = x_contracting_dim = tuple(
range(0, len(x_shape) - len(fwd_x_contracting_dims))
)
wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
wgrad_contracting_dims_list.append(wgrad_contracting_dims)
dgrad_list = tex.grouped_gemm(
grad_rowwise_list, rowwise_kernel_list, dgrad_contracting_dims_list
)
wgrad_list = tex.grouped_gemm(colwise_x_list, grad_colwise_list, wgrad_contracting_dims_list)
return dgrad_list, wgrad_list, dbias_list, quantizer_set_list
_grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule)
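

# A minimal usage sketch for `grouped_dense` with two independent GEMMs and no
# quantization (quantizer_set_list=None). Group size and shapes are assumptions
# for illustration only.
def _grouped_dense_usage_sketch():
    key = jax.random.PRNGKey(0)
    x_list = [jax.random.normal(key, (8, 16)), jax.random.normal(key, (4, 16))]
    kernel_list = [jax.random.normal(key, (16, 32)), jax.random.normal(key, (16, 32))]
    contracting_dims_list = [((1,), (0,)), ((1,), (0,))]
    out_list = grouped_dense(x_list, kernel_list, None, contracting_dims_list)
    return out_list  # shapes [(8, 32), (4, 32)]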
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX te modules"""
from typing import List, Tuple, Sequence
from functools import partial
import jax
import jax.numpy as jnp
from . import cpp_extensions as tex
from .fp8 import FP8Helper, FP8MetaPackage
Precision = jax.lax.Precision
def type_safe_dot_general(
x,
kernel,
fp8_meta_pkg: FP8MetaPackage = None,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
) -> jnp.ndarray:
"""
Type safe dot_general, including FP8.
"""
if fp8_meta_pkg is None:
assert x.dtype == kernel.dtype, f"lhs dtype = {x.dtype}, rhs dtype = {kernel.dtype}"
return jax.lax.dot_general(x, kernel, (contracting_dims, ((), ())))
amax_list = fp8_meta_pkg.amax_list
scale_list = fp8_meta_pkg.scale_list
fwd_dtype = FP8Helper.FWD_DTYPE
bwd_dtype = FP8Helper.BWD_DTYPE
return _fp8_dot(x, kernel, amax_list, scale_list, fwd_dtype, bwd_dtype, contracting_dims)
def quantize(x, q_dtype, scale):
"""
Quantize with scale.
"""
updated_amax = jnp.max(jnp.abs(x)).astype(scale.dtype)
dtype_max = (jnp.finfo(q_dtype).max).astype(x.dtype)
scale = scale.astype(x.dtype)
clipped_scaled_x = jnp.clip((x * scale), -dtype_max, dtype_max)
return clipped_scaled_x.astype(q_dtype), updated_amax
def dequantize(x, dq_dtype, scale_inv):
"""
Dequantize with scale_inv.
"""
return x.astype(dq_dtype) * scale_inv.astype(dq_dtype)
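

# A small numeric sketch of the scale-based quantize/dequantize pair above.
# float8_e4m3fn is an assumed target dtype (available in recent JAX releases);
# an identity scale is used purely for illustration.
def _quantize_roundtrip_sketch():
    x = jnp.array([[0.5, -1.5], [2.0, -3.0]], jnp.float32)
    scale = jnp.array([1.0], jnp.float32)
    q, amax = quantize(x, jnp.float8_e4m3fn, scale)       # amax == 3.0
    x_approx = dequantize(q, jnp.float32, 1.0 / scale)    # x_approx ~ x
    return q, amax, x_approx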
# Apply jit to guarantee correctness of FP8 GEMM.
@partial(jax.jit, static_argnums=(4, 5, 6))
def fp8_dot_impl(
q_lhs: jnp.ndarray,
q_rhs: jnp.ndarray,
lhs_scale_inv: jnp.ndarray,
rhs_scale_inv: jnp.ndarray,
ctype: jnp.dtype, # computing type
contracting_dims: Tuple[Sequence[int], Sequence[int]],
precision: Precision = None,
):
"""
FP8 GEMM for XLA pattern match
"""
dim_nums = (contracting_dims, ((), ()))
lhs = dequantize(q_lhs, ctype, lhs_scale_inv)
rhs = dequantize(q_rhs, ctype, rhs_scale_inv)
return jax.lax.dot_general(lhs, rhs, dim_nums, precision=precision)
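

# A sketch of how `fp8_dot_impl` is expected to be fed: both operands arrive
# already quantized (e.g. via `quantize` above) together with their scale_inv,
# and the dequantize + dot_general pair is left for XLA to pattern-match into an
# FP8 GEMM. Dtypes and shapes here are assumptions for illustration.
def _fp8_dot_impl_sketch():
    x = jnp.ones((8, 16), jnp.float32)
    w = jnp.ones((16, 32), jnp.float32)
    scale = jnp.array([1.0], jnp.float32)
    q_x, _ = quantize(x, jnp.float8_e4m3fn, scale)
    q_w, _ = quantize(w, jnp.float8_e4m3fn, scale)
    # Contract x's axis 1 with w's axis 0; compute in float32.
    return fp8_dot_impl(q_x, q_w, 1.0 / scale, 1.0 / scale, jnp.float32, ((1,), (0,)))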
def get_precision_of_fp8_dot(enable_2xACC: bool):
"""
Get Precision of FP8 DOT.
"""
return jax.lax.Precision.HIGHEST if enable_2xACC else jax.lax.Precision.DEFAULT
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6))
def _fp8_dot(
x: jnp.ndarray,
kernel: jnp.ndarray,
amax_list: List[jnp.ndarray],
scale_list: List[jnp.ndarray],
fwd_dtype: jnp.dtype,
bwd_dtype: jnp.dtype,
contracting_dims: Tuple[Sequence[int], Sequence[int]],
):
output, _ = _fp8_dot_fwd_rule(
x, kernel, amax_list, scale_list, fwd_dtype, bwd_dtype, contracting_dims
)
return output
def _fp8_dot_fwd_rule(
x,
kernel,
amax_list,
scale_list,
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
contracting_dims,
):
maybe_fm32_to_fp32, maybe_fp32_to_fm32 = FP8Helper.generate_fp8_meta_dtype_converter_pair(
*amax_list, *scale_list
)
amax_list = maybe_fm32_to_fp32(*amax_list)
scale_list = maybe_fm32_to_fp32(*scale_list)
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
x_shape_suf = x.shape[min(lhs_contracting_dims) :]
kernel_shape_pre = kernel.shape[: max(rhs_contracting_dims) + 1]
assert x_shape_suf == kernel_shape_pre
fp8_dtype_list = [fwd_dtype, fwd_dtype, bwd_dtype]
scale_list, scale_inv_list = FP8MetaPackage.update_fp8_scale(
amax_list, scale_list, fp8_dtype_list
)
amax_list = FP8MetaPackage.update_amax_list(amax_list)
x_scale = scale_list[FP8MetaPackage.INPUT_IDX]
x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX]
    # Note (Ming Huang): Use the native cast so that XLA handles the transpose,
    # avoiding an unnecessary copy that would break FP8 GEMM pattern matching.
casted_x, updated_x_amax = quantize(x, fwd_dtype, x_scale)
kernel_scale = scale_list[FP8MetaPackage.WEIGHT_IDX]
kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX]
    # Note (Ming Huang): Use the native cast so that XLA handles the transpose,
    # avoiding an unnecessary copy that would break FP8 GEMM pattern matching.
casted_kernel, updated_kernel_amax = quantize(kernel, fwd_dtype, kernel_scale)
output = fp8_dot_impl(
casted_x,
casted_kernel,
x_scale_inv,
kernel_scale_inv,
x.dtype,
(lhs_contracting_dims, rhs_contracting_dims),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP),
)
ctx = (
casted_x,
casted_kernel,
amax_list,
scale_list,
scale_inv_list,
updated_x_amax,
updated_kernel_amax,
x.shape,
kernel.shape,
maybe_fp32_to_fm32,
)
return output, ctx
def _fp8_dot_bwd_rule(
fwd_dtype, bwd_dtype, contracting_dims, ctx, grad
): # pylint: disable=unused-argument
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
(
casted_x,
casted_kernel,
amax_list,
scale_list,
scale_inv_list,
updated_x_amax,
updated_kernel_amax,
x_shape,
kernel_shape,
maybe_fp32_to_fm32,
) = ctx
grad_amax = amax_list[FP8MetaPackage.GRAD_IDX][0:1]
grad_scale = scale_list[FP8MetaPackage.GRAD_IDX]
grad_scale_inv = scale_inv_list[FP8MetaPackage.GRAD_IDX]
casted_grad, casted_grad_t, updated_grad_amax = tex.cast_transpose(
grad,
grad_amax,
grad_scale,
grad_scale_inv,
bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=min(lhs_contracting_dims),
)
    x_contracting_dim = tuple(range(0, len(x_shape) - len(lhs_contracting_dims)))
    gt_contracting_dim = tuple(range(grad.ndim - len(x_contracting_dim), grad.ndim))

    x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX]
    wgrad = fp8_dot_impl(
        casted_x,
        casted_grad_t,
        x_scale_inv,
        grad_scale_inv,
        grad.dtype,
        (x_contracting_dim, gt_contracting_dim),
        get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD),
    )

    g_contracting_dim = tuple(
        range(grad.ndim - len(kernel_shape) + len(rhs_contracting_dims), grad.ndim)
    )
    k_contracting_dim = tuple(range(len(rhs_contracting_dims), len(kernel_shape)))

    kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX]
    dgrad = fp8_dot_impl(
        casted_grad,
        casted_kernel,
        grad_scale_inv,
        kernel_scale_inv,
        grad.dtype,
        (g_contracting_dim, k_contracting_dim),
        get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD),
    )
amax_list[FP8MetaPackage.INPUT_IDX] = (
amax_list[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax)
)
amax_list[FP8MetaPackage.WEIGHT_IDX] = (
amax_list[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_amax)
)
amax_list[FP8MetaPackage.GRAD_IDX] = (
amax_list[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0])
)
amax_list = maybe_fp32_to_fm32(*amax_list)
scale_list = maybe_fp32_to_fm32(*scale_list)
return dgrad, wgrad, amax_list, scale_list
_fp8_dot.defvjp(_fp8_dot_fwd_rule, _fp8_dot_bwd_rule)
...@@ -3,7 +3,7 @@
# See LICENSE for license information.
"""Transformer Engine bindings for JAX"""
from .module import DenseGeneral, LayerNorm
from .module import LayerNormDenseGeneral, LayerNormMLP
from .transformer import extend_logical_axis_rules
from .transformer import DotProductAttention, MultiHeadAttention, RelativePositionBiases
from .transformer import TransformerLayer, TransformerLayerType
...@@ -13,7 +13,6 @@ __all__ = [
    "LayerNorm",
    "LayerNormDenseGeneral",
    "LayerNormMLP",
    "extend_logical_axis_rules",
    "DotProductAttention",
    "MultiHeadAttention",
...
...@@ -4,7 +4,7 @@
"""
Wrapper module for Transformer related layers with FP8 support.
"""
from functools import reduce
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union
...@@ -17,14 +17,17 @@ from jax import nn as jax_nn
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name

from ..dense import dense
from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense
from ..layernorm_mlp import layernorm_mlp
from ..activation import activation
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import is_softmax_kernel_available
from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode

PRNGKey = Any
Shape = Tuple[int, ...]
...@@ -57,17 +60,24 @@ def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_ga
def _create_layernorm_parameters(
    norm_type,
    shape,
    scale_init,
    scale_axes,
    bias_init,
    bias_axes,
    input_dtype,
    dtype,
):
    scale = nn_partitioning.param_with_axes("scale", scale_init, shape, dtype, axes=scale_axes)
    scale = scale.astype(input_dtype)

    norm_type = canonicalize_norm_type(norm_type)
    if norm_type == "layernorm":
        bias = nn_partitioning.param_with_axes("ln_bias", bias_init, shape, dtype, axes=bias_axes)
        bias = jnp.asarray(bias, input_dtype)
    else:
        assert norm_type == "rmsnorm"
        bias = None

    return scale, bias
...@@ -315,7 +325,7 @@ class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
            x,
            scale,
            ln_bias,
            norm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            epsilon=self.epsilon,
        )
...@@ -328,49 +338,44 @@ class TransformerEngineBase(nn.Module):  # pylint: disable=too-few-public-methods
    Base class of transformer engine
    """
    def generate_quantizer_set(self, postfix: str = ""):
        """
        Generate a set of FP8 meta for a GEMM.
        """

        def generate_quantize_meta(quantizer_name: str):
            scale = self.variable(
                QuantizeConfig.COLLECTION_NAME,
                f"{quantizer_name}{postfix}_scale",
                jnp.ones,
                (1,),
                jnp.float32,
            ).value
            amax_history = self.variable(
                QuantizeConfig.COLLECTION_NAME,
                f"{quantizer_name}{postfix}_amax_history",
                jnp.zeros,
                (QuantizeConfig.AMAX_HISTORY_LEN,),
                jnp.float32,
            ).value
            return QuantizeMeta(scale=scale, amax_history=amax_history)

        if QuantizeConfig.SCALING_MODE == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
            x_meta = generate_quantize_meta("x")
            kernel_meta = generate_quantize_meta("kernel")
            grad_meta = generate_quantize_meta("grad")
            quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)
            kwargs = {"quantize_meta_set": quantize_meta_set}
        else:
            kwargs = {}

        quantizer_set = QuantizerFactory.create_set(**kwargs)
        return quantizer_set

class DenseGeneral(TransformerEngineBase):
    r"""
    Applies a dense layer transformation to the incoming data :math:`y = xA^T + b`.

    Parameters
    ----------
...@@ -392,7 +397,7 @@ class DenseGeneral(TransformerEngineBase):
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
...@@ -435,7 +440,7 @@ class DenseGeneral(TransformerEngineBase):
    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply the dense layer transformation to the input.

        Parameters
        ----------
...@@ -455,28 +460,29 @@ class DenseGeneral(TransformerEngineBase):
        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
        kernel = nn_partitioning.param_with_axes(
            "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
        )
        if not QuantizeConfig.is_fp8_enabled():
            kernel = kernel.astype(input_dtype)
        kernel_compute_shape = (
            reduce(operator.mul, [inputs.shape[ax] for ax in axis], 1),
            reduce(operator.mul, features, 1),
        )
        kernel = jnp.reshape(kernel, kernel_compute_shape)

        if self.use_bias:
            bias = nn_partitioning.param_with_axes(
                "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
            )
            bias = bias.reshape(kernel_compute_shape[-1]).astype(input_dtype)
        else:
            bias = None

        quantizer_set = self.generate_quantizer_set()
        contract_ind = tuple(range(0, len(axis)))
        y = dense(
            inputs, kernel, contracting_dims=(axis, contract_ind), quantizer_set=quantizer_set
        )
        if self.enable_low_rank_adaptation:
...@@ -486,7 +492,7 @@ class DenseGeneral(TransformerEngineBase):
                self.low_rank_adaptation_dim,
            )
            lora_a_kernel_init_shape = (
                kernel_compute_shape[0],
                *features[:-1],
                self.low_rank_adaptation_dim,
            )
...@@ -521,19 +527,20 @@ class DenseGeneral(TransformerEngineBase):
            y += jnp.reshape(bias, bias_shape)

        assert y.dtype == input_dtype
        y = y.reshape(*inputs.shape[: self.axis], *features)
        return y


class LayerNormDenseGeneral(TransformerEngineBase):
    r"""
    Applies layer normalization followed by dense layer transformation to the incoming data.

    Parameters
    ----------
    features : Union[Iterable[int], int]
        The hidden size of each output sample.
    enable_layernorm: bool, default = True
        Indicate whether to enable layer normalization before dense layer transformation.
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    epsilon : float, default = 1e-6
...@@ -582,7 +589,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
        Indicate whether to return the output of layer normalization.
        If set False, return None as the second tensor in outputs.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
...@@ -650,12 +657,13 @@ class LayerNormDenseGeneral(TransformerEngineBase):
            self.scale_init,
            self.zero_centered_gamma,
        )
        self.quantizer_set = QuantizerFactory.create_set()
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Apply layer normalization to the input followed by a dense layer transformation.

        Parameters
        ----------
...@@ -674,8 +682,10 @@ class LayerNormDenseGeneral(TransformerEngineBase):
        input_dtype = inputs.dtype
        ln_output = None

        quantizer_set = self.generate_quantizer_set()

        fuse_layernorm = (
            QuantizeConfig.is_fp8_enabled()
            and not self.return_layernorm_output
            and self.enable_layernorm
        )
...@@ -702,7 +712,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
                inputs,
                scale,
                ln_bias,
                norm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
            )
...@@ -722,37 +732,35 @@ class LayerNormDenseGeneral(TransformerEngineBase):
            axis = _normalize_axes(axis, y.ndim)

        kernel_shape = tuple(y.shape[ax] for ax in axis) + features
        kernel = nn_partitioning.param_with_axes(
            "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
        )
        if not QuantizeConfig.is_fp8_enabled():
            kernel = kernel.astype(input_dtype)
        kernel_compute_shape = (
            reduce(operator.mul, [inputs.shape[ax] for ax in axis], 1),
            reduce(operator.mul, features, 1),
        )
        kernel = jnp.reshape(kernel, kernel_compute_shape)

        contract_ind = tuple(range(0, len(axis)))

        if fuse_layernorm:
            z = layernorm_dense(
                y,
                kernel,
                scale,
                ln_bias,
                norm_type=self.layernorm_type,
                zero_centered_gamma=self.zero_centered_gamma,
                epsilon=self.epsilon,
                layernorm_input_axes=self.layernorm_input_axes,
                dot_input_axes=self.dot_input_axes,
                quantizer_set=quantizer_set,
            )
        else:
            y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
            z = dense(y, kernel, contracting_dims=(axis, contract_ind), quantizer_set=quantizer_set)
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
lora_a_kernel_shape = ( lora_a_kernel_shape = (
...@@ -761,7 +769,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -761,7 +769,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
self.low_rank_adaptation_dim, self.low_rank_adaptation_dim,
) )
lora_a_kernel_init_shape = ( lora_a_kernel_init_shape = (
kernel_param_shape[0], kernel_compute_shape[0],
*features[:-1], *features[:-1],
self.low_rank_adaptation_dim, self.low_rank_adaptation_dim,
) )
...@@ -796,7 +804,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -796,7 +804,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
bias = nn_partitioning.param_with_axes( bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, features, self.dtype, axes=self.bias_axes "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
) )
bias = bias.astype(input_dtype) bias = bias.reshape(kernel_compute_shape[-1]).astype(input_dtype)
if bias is not None: if bias is not None:
bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
...@@ -805,21 +813,22 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -805,21 +813,22 @@ class LayerNormDenseGeneral(TransformerEngineBase):
if self.depth_scaling is not None: if self.depth_scaling is not None:
z = z / self.depth_scaling z = z / self.depth_scaling
assert z.dtype == input_dtype assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
z = z.reshape(*inputs.shape[: self.axis], *features)
return z, ln_output # dense_output, layer_norm_output return z, ln_output # dense_output, layer_norm_output
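A minimal usage sketch of the module above. The `features` value, input shape, and the import path `transformer_engine.jax.flax` are illustrative assumptions, not taken from this diff:
import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import LayerNormDenseGeneral  # assumed import path
layer = LayerNormDenseGeneral(features=512, return_layernorm_output=False)
x = jnp.ones((4, 16, 256), dtype=jnp.bfloat16)
variables = layer.init(jax.random.PRNGKey(0), x)   # creates the kernel (and axes metadata)
out, ln_out = layer.apply(variables, x)            # ln_out is None since return_layernorm_output=False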
class LayerNormMLP(TransformerEngineBase): class LayerNormMLP(TransformerEngineBase):
r""" r"""
Applies layer normalization on the input followed by the MLP module, Applies layer normalization on the input followed by the MLP module,
consisting of 2 successive linear transformations, separated by given activations. consisting of 2 successive dense layer transformations, separated by given activations.
Parameters Parameters
---------- ----------
intermediate_dim: int, default = 2048 intermediate_dim: int, default = 2048
Intermediate size to which input samples are projected. Intermediate size to which input samples are projected.
enable_layernorm: bool, default = True enable_layernorm: bool, default = True
Indicate whether to enable layer normalization before linear transformation. Indicate whether to enable layer normalization before dense layer transformation.
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm' layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
Indicate the type of layer normalization. Indicate the type of layer normalization.
epsilon : float, default = 1e-6 epsilon : float, default = 1e-6
...@@ -851,14 +860,14 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -851,14 +860,14 @@ class LayerNormMLP(TransformerEngineBase):
Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`. Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
kernel_init : Initializer, default = kernel_init : Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
Used for initializing the weights of both linear transformations. Used for initializing the weights of both dense layer transformations.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp') kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
The name of axes used to shard the weights with a corresponding mesh for The name of axes used to shard the weights with a corresponding mesh for
the weight of the first linear transformations. the weight of the first dense layer transformation.
kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed') kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
The name of axes used to shard the weights with a corresponding mesh for The name of axes used to shard the weights with a corresponding mesh for
the weight of the second linear transformations. the weight of the second dense layer transformation.
use_bias: bool, default = False use_bias: bool, default = False
Indicate whether to enable bias shifting. Indicate whether to enable bias shifting.
If set to False, the layer will not learn an additive bias. If set to False, the layer will not learn an additive bias.
...@@ -867,17 +876,17 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -867,17 +876,17 @@ class LayerNormMLP(TransformerEngineBase):
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
bias_axes_1: Tuple[str, ...], default = ('mlp',) bias_axes_1: Tuple[str, ...], default = ('mlp',)
The name of axes used to shard bias with a corresponding mesh for The name of axes used to shard bias with a corresponding mesh for
the weight of the first linear transformations. the weight of the first dense layer transformation.
Only used when :attr:`use_bias=True`. Only used when :attr:`use_bias=True`.
bias_axes_2: Tuple[str, ...], default = ('embed',) bias_axes_2: Tuple[str, ...], default = ('embed',)
The name of axes used to shard bias with a corresponding mesh for The name of axes used to shard bias with a corresponding mesh for
the weight of the second linear transformations. the weight of the second dense layer transformation.
Only used when :attr:`use_bias=True`. Only used when :attr:`use_bias=True`.
return_layernorm_output: bool, default = True return_layernorm_output: bool, default = True
Indicate whether to return the output of layer normalization. Indicate whether to return the output of layer normalization.
If set False, return None as the second tensor in outputs. If set False, return None as the second tensor in outputs.
activations: Sequence[Union[str, Callable]], default = ('relu',) activations: Sequence[Union[str, Callable]], default = ('relu',)
The sequence of activation functions to apply after the first linear transformation. The sequence of activation functions to apply after the first dense layer transformation.
Each activation has its own transformation layer. Each activation has its own transformation layer.
intermediate_dropout_rng_name: str, default = 'dropout' intermediate_dropout_rng_name: str, default = 'dropout'
The key in the RNGs given via flax.linen.Module.apply that is used for generating Dropout masks. The key in the RNGs given via flax.linen.Module.apply that is used for generating Dropout masks.
...@@ -886,7 +895,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -886,7 +895,7 @@ class LayerNormMLP(TransformerEngineBase):
intermediate_hidden_dropout_dims: Sequence[int], default = () intermediate_hidden_dropout_dims: Sequence[int], default = ()
Dimensions that will share the same dropout mask for hidden Dimensions that will share the same dropout mask for hidden
enable_low_rank_adaptation: bool, default = False enable_low_rank_adaptation: bool, default = False
Indicate whether to enable low rank adaptation for each linear layer. Indicate whether to enable low rank adaptation for each dense layer.
low_rank_adaptation_dim: int, default = 32 low_rank_adaptation_dim: int, default = 32
The dimension for low rank adaptation, only used when The dimension for low rank adaptation, only used when
:attr:`enable_low_rank_adaptation=True`. :attr:`enable_low_rank_adaptation=True`.
...@@ -980,12 +989,16 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -980,12 +989,16 @@ class LayerNormMLP(TransformerEngineBase):
The output tensors of layer normalization. The output tensors of layer normalization.
If :attr:`return_layernorm_output=False`, then this would be None. If :attr:`return_layernorm_output=False`, then this would be None.
""" """
ffn1_quantizer_set = self.generate_quantizer_set("_0")
ffn2_quantizer_set = self.generate_quantizer_set("_1")
input_dtype = inputs.dtype input_dtype = inputs.dtype
ln_output = None ln_output = None
# TODO(Phuong): use fuse_layernorm for high-precision
# when NoOpQuantizer and Tensor are implemented
fuse_layernorm = ( fuse_layernorm = (
FP8Helper.is_fp8_enabled() QuantizeConfig.is_fp8_enabled()
and not self.return_layernorm_output and not self.return_layernorm_output
and self.enable_layernorm and self.enable_layernorm
) )
...@@ -1012,7 +1025,6 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1012,7 +1025,6 @@ class LayerNormMLP(TransformerEngineBase):
use_fused_layernorm_mlp = ( use_fused_layernorm_mlp = (
fuse_layernorm and is_act_implemented and self.intermediate_dropout_rate < 1e-3 fuse_layernorm and is_act_implemented and self.intermediate_dropout_rate < 1e-3
) )
# LayerNorm # LayerNorm
if self.enable_layernorm: if self.enable_layernorm:
assert self.axis == -1 # Only support axis == -1 at this moment assert self.axis == -1 # Only support axis == -1 at this moment
...@@ -1036,7 +1048,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1036,7 +1048,7 @@ class LayerNormMLP(TransformerEngineBase):
inputs, inputs,
scale, scale,
ln_bias, ln_bias,
layernorm_type=self.layernorm_type, norm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon, epsilon=self.epsilon,
) )
...@@ -1056,18 +1068,9 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1056,18 +1068,9 @@ class LayerNormMLP(TransformerEngineBase):
kernels.append(self.kernel_init(init_key, *init_args)) kernels.append(self.kernel_init(init_key, *init_args))
return jnp.stack(kernels, axis=stack_axis, dtype=self.dtype) return jnp.stack(kernels, axis=stack_axis, dtype=self.dtype)
wi_fp8_meta_pkg = None
wo_fp8_meta_pkg = None
if FP8Helper.is_fp8_enabled():
wi_fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("0")
wo_fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("1")
num_activations = len(normalized_acts) num_activations = len(normalized_acts)
axis = _canonicalize_tuple(self.axis) axis = _canonicalize_tuple(self.axis)
axis = _normalize_axes(axis, y.ndim) axis = _normalize_axes(axis, y.ndim)
intermediate_dim = _canonicalize_tuple((num_activations, self.intermediate_dim))
kernel_1_shape = tuple(y.shape[ax] for ax in axis) + intermediate_dim
kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim) kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim)
kernel_1 = nn_partitioning.param_with_axes( kernel_1 = nn_partitioning.param_with_axes(
"wi_kernel", "wi_kernel",
...@@ -1078,98 +1081,109 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1078,98 +1081,109 @@ class LayerNormMLP(TransformerEngineBase):
self.dtype, self.dtype,
axes=self.kernel_axes_1, axes=self.kernel_axes_1,
) )
kernel_1 = jnp.reshape(kernel_1, kernel_1_shape) kernel_1_compute_shape = (
if not FP8Helper.is_fp8_enabled(): reduce(operator.mul, [y.shape[ax] for ax in axis], 1),
num_activations * self.intermediate_dim,
)
kernel_1 = jnp.reshape(kernel_1, kernel_1_compute_shape)
if not QuantizeConfig.is_fp8_enabled():
kernel_1 = kernel_1.astype(input_dtype) kernel_1 = kernel_1.astype(input_dtype)
hidden_size = inputs.shape[-1] hidden_size = inputs.shape[-1]
hidden_size_tuple = _canonicalize_tuple(hidden_size) hidden_size_tuple = _canonicalize_tuple(hidden_size)
kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
kernel_2_param_shape = (self.intermediate_dim, np.prod(hidden_size_tuple))
kernel_2 = nn_partitioning.param_with_axes( kernel_2 = nn_partitioning.param_with_axes(
"wo_kernel", "wo_kernel",
self.kernel_init, self.kernel_init,
kernel_2_param_shape, kernel_2_shape,
self.dtype, self.dtype,
axes=self.kernel_axes_2, axes=self.kernel_axes_2,
) )
kernel_2 = jnp.reshape(kernel_2, kernel_2_shape) kernel_2_compute_shape = (
if not FP8Helper.is_fp8_enabled(): self.intermediate_dim,
reduce(operator.mul, hidden_size_tuple, 1),
)
kernel_2 = jnp.reshape(kernel_2, kernel_2_compute_shape)
if not QuantizeConfig.is_fp8_enabled():
kernel_2 = kernel_2.astype(input_dtype) kernel_2 = kernel_2.astype(input_dtype)
contract_ind = tuple(range(0, len(axis))) contract_ind = tuple(range(0, len(axis)))
if self.use_bias:
bias_1_shape = num_activations * self.intermediate_dim
bias_1 = nn_partitioning.param_with_axes(
"wi_bias",
self.bias_init,
bias_1_shape,
self.dtype,
axes=self.bias_axes_1,
)
bias_1 = bias_1.reshape(kernel_1_compute_shape[-1]).astype(input_dtype)
bias_2_shape = (hidden_size,)
bias_2 = nn_partitioning.param_with_axes(
"wo_bias",
self.bias_init,
bias_2_shape,
self.dtype,
axes=self.bias_axes_2,
)
bias_2 = bias_2.reshape(kernel_2_compute_shape[-1]).astype(input_dtype)
else:
bias_1 = None
bias_2 = None
ffn1_ckpt_name = "ffn1" ffn1_ckpt_name = "ffn1"
ffn2_ckpt_name = "ffn2" ffn2_ckpt_name = "ffn2"
if use_fused_layernorm_mlp: if use_fused_layernorm_mlp:
assert self.axis == -1  # Only support axis == -1 at this moment assert self.axis == -1  # Only support axis == -1 at this moment
if self.use_bias: out = layernorm_mlp(
bias_1_shape = intermediate_dim
bias_1 = nn_partitioning.param_with_axes(
"wi_bias",
self.bias_init,
bias_1_shape,
self.dtype,
axes=self.bias_axes_1,
)
bias_1 = bias_1.astype(input_dtype)
bias_2_shape = (hidden_size,)
bias_2 = nn_partitioning.param_with_axes(
"wo_bias",
self.bias_init,
bias_2_shape,
self.dtype,
axes=self.bias_axes_2,
)
bias_2 = bias_2.astype(input_dtype)
else:
bias_1 = None
bias_2 = None
out = fused_layernorm_fp8_mlp(
y, y,
scale, scale,
ln_bias, ln_bias,
[kernel_1, kernel_2], [kernel_1, kernel_2],
[bias_1, bias_2], [bias_1, bias_2],
[wi_fp8_meta_pkg, wo_fp8_meta_pkg],
self.layernorm_type, self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon, epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes, norm_input_axes=self.layernorm_input_axes,
dot_1_input_axes=self.dot_1_input_axes, dot_1_input_axes=self.dot_1_input_axes,
dot_2_input_axes=self.dot_2_input_axes, dot_2_input_axes=self.dot_2_input_axes,
ffn1_ckpt_name=ffn1_ckpt_name, ffn1_ckpt_name=ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name, ffn2_ckpt_name=ffn2_ckpt_name,
activation_type=normalized_acts, activation_type=normalized_acts,
use_bias=self.use_bias, quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
) )
out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)
else: # not use_fused_ln_geglu_mlp else: # not use_fused_ln_geglu_mlp
# DenseGeneral 1 # DenseGeneral 1
if fuse_layernorm: if fuse_layernorm:
x = layernorm_fp8_dot( x = layernorm_dense(
y, y,
kernel_1, kernel_1,
scale, scale,
ln_bias, ln_bias,
wi_fp8_meta_pkg, norm_type=self.layernorm_type,
self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon, epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes, layernorm_input_axes=self.layernorm_input_axes,
dot_input_axes=self.dot_1_input_axes, dot_input_axes=self.dot_1_input_axes,
quantizer_set=ffn1_quantizer_set,
) )
else: else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes) y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
x = type_safe_dot_general( x = dense(
y, kernel_1, fp8_meta_pkg=wi_fp8_meta_pkg, contracting_dims=(axis, contract_ind) y,
kernel_1,
contracting_dims=(axis, contract_ind),
quantizer_set=ffn1_quantizer_set,
) )
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
wi_lora_a_kernel_shape = ( wi_lora_a_kernel_shape = (
*kernel_1_shape[: len(axis)], kernel_1_compute_shape[0],
num_activations, num_activations,
self.low_rank_adaptation_dim, self.low_rank_adaptation_dim,
) )
...@@ -1187,7 +1201,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1187,7 +1201,7 @@ class LayerNormMLP(TransformerEngineBase):
"wi_lora_a_kernel", "wi_lora_a_kernel",
kernel_1_init, kernel_1_init,
num_activations, num_activations,
-2, -1,
wi_lora_a_kernel_init_each_shape, wi_lora_a_kernel_init_each_shape,
self.dtype, self.dtype,
axes=wi_lora_a_kernel_axes, axes=wi_lora_a_kernel_axes,
...@@ -1213,37 +1227,25 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1213,37 +1227,25 @@ class LayerNormMLP(TransformerEngineBase):
x += _apply_low_rank_adaptation( x += _apply_low_rank_adaptation(
y, y,
axis, axis,
intermediate_dim, num_activations * self.intermediate_dim,
wi_lora_a_kernel, wi_lora_a_kernel,
wi_lora_b_kernel, wi_lora_b_kernel,
self.low_rank_adaptation_alpha, self.low_rank_adaptation_alpha,
) )
bias_1 = None
if self.use_bias: if self.use_bias:
bias_1 = nn_partitioning.param_with_axes(
"wi_bias",
self.bias_init,
intermediate_dim,
self.dtype,
axes=self.bias_axes_1,
)
bias_1_shape = (1,) * (x.ndim - bias_1.ndim) + bias_1.shape
bias_1 = bias_1.astype(input_dtype)
x += jnp.reshape(bias_1, bias_1_shape) x += jnp.reshape(bias_1, bias_1_shape)
x = checkpoint_name(x, ffn1_ckpt_name) x = checkpoint_name(x, ffn1_ckpt_name)
if is_act_implemented: if is_act_implemented:
z = activation_lu(x, normalized_acts) z = activation(x, normalized_acts)
else: else:
activations = [] activations = []
x = jnp.split(x, num_activations, axis=-2) x = jnp.split(x, num_activations, axis=-1)
for idx, act_fn in enumerate(normalized_acts): for idx, act_fn in enumerate(normalized_acts):
x_i = _convert_to_activation_function(act_fn)(x[idx]) x_i = _convert_to_activation_function(act_fn)(x[idx])
activations.append(x_i) activations.append(x_i)
z = functools.reduce(operator.mul, activations) z = reduce(operator.mul, activations)
# Remove act axis
z = jnp.reshape(z, (*z.shape[:-2], -1))
z = z.astype(input_dtype) z = z.astype(input_dtype)
z = nn.Dropout( z = nn.Dropout(
...@@ -1256,8 +1258,8 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1256,8 +1258,8 @@ class LayerNormMLP(TransformerEngineBase):
z = z.astype(input_dtype) z = z.astype(input_dtype)
# DenseGeneral 2 # DenseGeneral 2
out = type_safe_dot_general( out = dense(
z, kernel_2, fp8_meta_pkg=wo_fp8_meta_pkg, contracting_dims=(axis, contract_ind) z, kernel_2, contracting_dims=(axis, contract_ind), quantizer_set=ffn2_quantizer_set
) )
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
...@@ -1292,16 +1294,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1292,16 +1294,7 @@ class LayerNormMLP(TransformerEngineBase):
self.low_rank_adaptation_alpha, self.low_rank_adaptation_alpha,
) )
bias_2 = None
if self.use_bias: if self.use_bias:
bias_2 = nn_partitioning.param_with_axes(
"wo_bias",
self.bias_init,
(hidden_size,),
self.dtype,
axes=self.bias_axes_2,
)
bias_2 = bias_2.astype(input_dtype)
out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,)) out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))
out = checkpoint_name(out, ffn2_ckpt_name) out = checkpoint_name(out, ffn2_ckpt_name)
......
...@@ -638,7 +638,9 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -638,7 +638,9 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
else: else:
assert qkv_layout.is_separate() assert qkv_layout.is_separate()
assert sequence_descriptor is None or isinstance(sequence_descriptor, jnp.ndarray) assert sequence_descriptor is None or isinstance(
sequence_descriptor, (jnp.ndarray, np.ndarray)
)
x = _UnfusedDotProductAttention( x = _UnfusedDotProductAttention(
attention_dropout=self.attention_dropout, attention_dropout=self.attention_dropout,
...@@ -928,7 +930,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -928,7 +930,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
Optimization parameters Optimization parameters
----------------------- -----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32 dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used for computation. The data type used to allocate the initial parameters.
fuse_qkv_params: bool, default = True fuse_qkv_params: bool, default = True
If set to True, this module exposes a single fused If set to True, this module exposes a single fused
parameter for query-key-value for self-attention and key-value for parameter for query-key-value for self-attention and key-value for
...@@ -1788,6 +1790,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1788,6 +1790,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
outputs: jax.numpy.ndarray outputs: jax.numpy.ndarray
Output tensors. Output tensors.
""" """
input_dtype = inputs.dtype input_dtype = inputs.dtype
assert ( assert (
self.layer_type in TransformerLayerType self.layer_type in TransformerLayerType
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Helper module for fp8 meta management
"""
from contextlib import contextmanager
from enum import Enum
from functools import partial
from typing import Dict, List, Optional, Tuple, Union
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.linen import fp8_ops
from transformer_engine_jax import DType
from transformer_engine_jax import get_cublasLt_version
from transformer_engine_jax import (
get_cuda_version,
get_device_compute_capability,
)
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax.sharding import global_shard_guard
from transformer_engine.jax.sharding import MeshResource
_is_fp8_available = None
_reason_for_no_fp8 = ""
Collection = Union[Dict, FrozenDict]
def _check_fp8_support(gpu_id) -> Tuple[bool, str]:
"""Return if fp8 support is available"""
gpu_arch = get_device_compute_capability(gpu_id)
if gpu_arch >= 90: # hopper and above
return True, ""
if gpu_arch < 89: # pre-ada
return False, "Device compute capability 8.9 or higher required for FP8 execution."
if get_cublasLt_version() < 120103:
return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
if get_cuda_version() < 12010:
return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
return True, ""
def is_fp8_available(gpu_id=None) -> Tuple[bool, str]:
"""Return if fp8 support is available"""
if gpu_id is not None:
return _check_fp8_support(gpu_id)
global _is_fp8_available, _reason_for_no_fp8
if _is_fp8_available is None:
_is_fp8_available = True
# JAX doesn't provide the local GPU id.
for local_gpu_id in range(len(jax.local_devices())):
ret, msg = _check_fp8_support(local_gpu_id)
if ret is False:
_is_fp8_available = ret
_reason_for_no_fp8 = msg
break
return _is_fp8_available, _reason_for_no_fp8
def _format2dtypes(format_: Format):
if format_ == Format.E4M3:
return jnp.float8_e4m3fn, jnp.float8_e4m3fn
if format_ == Format.E5M2:
return jnp.float8_e5m2, jnp.float8_e5m2
if format_ == Format.HYBRID:
return jnp.float8_e4m3fn, jnp.float8_e5m2
return jnp.bfloat16, jnp.bfloat16
# fm32 is a custom dtype that redefines the gradient "add" rule as a max operation.
# This is typically used in Pipeline Parallelism with "MicroBatching > 1",
# which is implemented via nn.scan. Without this custom dtype, nn.scan
# would sum gradients from all micro-batches, which is not the expected
# behavior for FP8 meta. Instead, the reduction of FP8 meta gradients should
# be "MAX".
FlaxFloatMeta32 = fp8_ops.fm32
class FP8MetaPackage:
"""
A container holding all required metadata for FP8
"""
NUM_OF_META: int = 3
INPUT_IDX: int = 0
WEIGHT_IDX: int = 1
GRAD_IDX: int = 2
def __init__(
self,
input_amax: jnp.ndarray,
input_scale: jnp.ndarray,
weight_amax: jnp.ndarray,
weight_scale: jnp.ndarray,
grad_amax: jnp.ndarray,
grad_scale: jnp.ndarray,
) -> None:
self._amax_list = [None] * FP8MetaPackage.NUM_OF_META
self._scale_list = [None] * FP8MetaPackage.NUM_OF_META
self._amax_list[FP8MetaPackage.INPUT_IDX] = input_amax
self._scale_list[FP8MetaPackage.INPUT_IDX] = input_scale
self._amax_list[FP8MetaPackage.WEIGHT_IDX] = weight_amax
self._scale_list[FP8MetaPackage.WEIGHT_IDX] = weight_scale
self._amax_list[FP8MetaPackage.GRAD_IDX] = grad_amax
self._scale_list[FP8MetaPackage.GRAD_IDX] = grad_scale
@property
def amax_list(self) -> List[jnp.ndarray]:
"""
Get the amax list of this package.
"""
return self._amax_list
@property
def scale_list(self) -> List[jnp.ndarray]:
"""
Get the scale list of this package.
"""
return self._scale_list
@staticmethod
def update_amax_list(amax_list: List[jnp.ndarray]) -> jnp.ndarray:
"""
Update the amax history list
"""
updated_amax_list = [FP8Helper.update_amax_history(amax) for amax in amax_list]
return updated_amax_list
@staticmethod
def update_fp8_scale(
amax_list: List[jnp.ndarray], scale_list: List[jnp.ndarray], fp8_dtype_list: List[DType]
) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
"""
Get the updated scale and scale_inv lists
"""
update_scale_list = []
update_scale_inv_list = []
for amax, scale, fp8_dtype in zip(amax_list, scale_list, fp8_dtype_list):
updated_scale, updated_scale_inv = FP8Helper.update_fp8_scale(amax, scale, fp8_dtype)
update_scale_list.append(updated_scale)
update_scale_inv_list.append(updated_scale_inv)
return update_scale_list, update_scale_inv_list
class AmaxComputeAlgo(Enum):
"""AmaxComputeAlgo."""
MAX = "max"
MOST_RECENT = "most_recent"
NVTE_FP8_COLLECTION_NAME = "fp8_metas"
class FP8Helper:
"""
FP8 helper to manage the FP8 meta
"""
INITIALIZED = False
MARGIN: float = 0.0
FP8_FORMAT: Format = Format.HYBRID
FWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[0]
BWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[1]
AMAX_HISTORY_LEN: int = 1024
AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX
FP8_COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME
FP8_AMAX_NAME: str = "amax"
FP8_SCALE_NAME: str = "scale"
FP8_2X_ACC_FPROP: bool = False
FP8_2X_ACC_DGRAD: bool = True
FP8_2X_ACC_WGRAD: bool = True
@staticmethod
def is_fp8_enabled():
"""
Indicate if fp8 training is enabled or not.
"""
return FP8Helper.INITIALIZED
@staticmethod
def initialize(
margin: float = 0.0,
fp8_format: Format = Format.HYBRID,
amax_history_len: int = 1,
amax_compute_algo: AmaxComputeAlgo = AmaxComputeAlgo.MAX,
) -> None:
"""
Initialize the FP8 meta
"""
FP8Helper.INITIALIZED = True
FP8Helper.MARGIN = margin
FP8Helper.FP8_FORMAT = fp8_format
FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = _format2dtypes(FP8Helper.FP8_FORMAT)
FP8Helper.AMAX_HISTORY_LEN = amax_history_len
FP8Helper.AMAX_COMPUTE_ALGO = amax_compute_algo
FP8Helper.FP8_2X_ACC_FPROP = False
FP8Helper.FP8_2X_ACC_DGRAD = True
FP8Helper.FP8_2X_ACC_WGRAD = True
@staticmethod
def finalize() -> None:
"""
FP8 helper finalize
"""
FP8Helper.INITIALIZED = False
FP8Helper.MARGIN = 0.0
FP8Helper.FP8_FORMAT = Format.HYBRID
FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = _format2dtypes(FP8Helper.FP8_FORMAT)
FP8Helper.AMAX_HISTORY_LEN = 1024
FP8Helper.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX
@staticmethod
def update_collections(new: Collection, original: Collection) -> Collection:
"""
Update the collections
"""
assert isinstance(original, (dict, FrozenDict))
assert isinstance(new, (dict, FrozenDict))
frozen_original = FrozenDict(original) if not isinstance(original, FrozenDict) else original
for key in new:
if key in frozen_original:
frozen_original, _ = frozen_original.pop(key)
new_coll = FrozenDict({**new, **frozen_original})
if not isinstance(original, FrozenDict):
new_coll = new_coll.unfreeze()
return new_coll
@staticmethod
def generate_fp8_meta_dtype_converter_pair(*args):
"""
Generate a pair of conversion functions between fm32 and fp32.
"""
def identical_fun(*metas):
return list(metas)
def fm32_to_fp32_fun(*metas):
for meta in metas:
assert meta.dtype == FlaxFloatMeta32
return [jax.lax.convert_element_type(meta, jnp.float32) for meta in metas]
def fp32_to_fm32_fun(*metas):
for meta in metas:
assert meta.dtype == jnp.float32
return [jax.lax.convert_element_type(meta, FlaxFloatMeta32) for meta in metas]
# Make the functions a valid JAX type
partial_identical_fun = jax.tree_util.Partial(identical_fun)
partial_fm32_to_fp32_fun = jax.tree_util.Partial(fm32_to_fp32_fun)
partial_fp32_to_fm32_fun = jax.tree_util.Partial(fp32_to_fm32_fun)
if len(args) < 1:
return partial_identical_fun, partial_identical_fun
original_dtype = args[0].dtype
for arg in args:
assert arg.dtype == original_dtype
if original_dtype == FlaxFloatMeta32:
return partial_fm32_to_fp32_fun, partial_fp32_to_fm32_fun
return partial_identical_fun, partial_identical_fun
@staticmethod
@jax.jit
def update_amax_history(amax: jnp.ndarray) -> jnp.ndarray:
"""
Update the amax history
"""
updated_amax = jnp.roll(amax, -1, -1)
updated_amax = updated_amax.at[0].set(0)
return updated_amax
@staticmethod
@partial(jax.jit, static_argnums=(2,))
def update_fp8_scale(amax: jnp.ndarray, scale: jnp.ndarray, fp8_dtype: DType) -> jnp.ndarray:
"""
Calculate fp8 scale and scale_inv based on given amax.
"""
fp8_max = jnp.astype(jnp.finfo(fp8_dtype).max, jnp.float32)
if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
amax = jnp.max(amax, axis=-1, keepdims=True)
else:
amax = amax[0:1]
sf = (fp8_max / amax) / (2**FP8Helper.MARGIN)
sf = jnp.where(amax > 0.0, sf, scale)
sf = jnp.where(jnp.isfinite(amax), sf, scale)
scale = sf
scale_inv = 1 / sf
return scale, scale_inv
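A small worked sketch of the update above, with illustrative numbers (E4M3 format and margin = 0 are assumptions):
import jax.numpy as jnp
fp8_max = jnp.float32(jnp.finfo(jnp.float8_e4m3fn).max)   # 448.0 for E4M3
amax = jnp.float32(2.0)                                    # observed absolute maximum
prev_scale = jnp.float32(1.0)                              # scale carried over from the last step
sf = (fp8_max / amax) / (2.0 ** 0.0)                       # margin = 0
sf = jnp.where((amax > 0.0) & jnp.isfinite(amax), sf, prev_scale)
print(sf, 1.0 / sf)                                        # 224.0 and ~0.004464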
@contextmanager
def fp8_autocast(
enabled: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
mesh_resource: Optional[MeshResource] = None,
) -> None:
r"""
Context manager for FP8 usage.
.. code-block:: python
mesh_shape = (4, 2)
dp_mesh_axis_name = 'data_parallel'
tp_mesh_axis_name = 'tensor_parallel'
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)):
mesh_resource=MeshResource(dp_mesh_axis_name, tp_mesh_axis_name)
with fp8_autocast(enabled=True, mesh_resource=mesh_resource):
rules = extend_logical_axis_rules(tuple())
transformer = TransformerLayer()
with partitioning.axis_rules(rules):
pjit(transformer.init, ...)(...)
.. note::
We only support :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len`,
and :attr:`amax_compute_algo` (with value 'max' and 'most_recent') in
recipe.DelayedScaling currently. Other parameters in recipe.DelayedScaling
will trigger an assertion.
Parameters
----------
enabled: bool, default = False
Whether or not to enable fp8
fp8_recipe: recipe.DelayedScaling, default = None
Recipe used for FP8 training.
mesh_resource: MeshResource, default = None
Specify the mesh axes for data and tensor parallelism to shard along.
If set to None, then no data or tensor parallelism will be used.
"""
if fp8_recipe is None:
fp8_recipe = DelayedScaling()
assert fp8_recipe.amax_compute_algo in [
"max",
"most_recent",
], "DelayedScaling amax_compute_algo only supports max and most_recent with TE/JAX."
assert (
fp8_recipe.scaling_factor_compute_algo is None
), "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX."
assert fp8_recipe.reduce_amax, "DelayedScaling reduce_amax should be enabled for TE/JAX."
if mesh_resource is None:
mesh_resource = MeshResource()
try:
with global_shard_guard(mesh_resource):
if enabled:
fp8_available, reason_for_no_fp8 = is_fp8_available()
assert fp8_available, reason_for_no_fp8
amax_compute_algo = AmaxComputeAlgo.MOST_RECENT
if fp8_recipe.amax_compute_algo == "max":
amax_compute_algo = AmaxComputeAlgo.MAX
FP8Helper.initialize(
margin=fp8_recipe.margin,
fp8_format=fp8_recipe.fp8_format,
amax_history_len=fp8_recipe.amax_history_len,
amax_compute_algo=amax_compute_algo,
)
yield
finally:
FP8Helper.finalize()
# Function Wrappers
def update_collections(new: Collection, original: Collection) -> FrozenDict:
r"""
A helper to update Flax's Collection.
Collection = [dict, flax.core.frozen_dict.FrozenDict]
Parameters
----------
new: Collection
A collection that includes new data.
original: Collection
The base collection.
Returns
-------
outputs : Collection
The updated collection.
"""
return FP8Helper.update_collections(new, original)
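A small sketch of the merge semantics above, using illustrative dictionaries and the wrapper defined in this file: keys present in the new collection replace those in the base collection.
from flax.core.frozen_dict import FrozenDict
original = FrozenDict({"scale": 1.0, "amax": 0.5})
new = {"scale": 2.0}
merged = update_collections(new, original)
# merged == FrozenDict({"scale": 2.0, "amax": 0.5}); a FrozenDict input yields a FrozenDict output.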
def get_delayed_scaling():
r"""
Obtain an instance of DelayedScaling which is set via fp8_autocast.
.. note::
We only store :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len`
, and :attr:`amax_compute_algo` via fp8_autocast. Other parameters in
recipe.DelayedScaling would be returned as the default values.
Returns
-------
delay_scaling : DelayedScaling
an instance of DelayedScaling which is set via fp8_autocast.
"""
amax_compute_algo = (
"max" if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX else "most_recent"
)
return DelayedScaling(
margin=int(FP8Helper.MARGIN),
fp8_format=FP8Helper.FP8_FORMAT,
amax_history_len=FP8Helper.AMAX_HISTORY_LEN,
amax_compute_algo=amax_compute_algo,
)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
"""JAX layernorm modules""" """Layer normalization operations for Transformer Engine in JAX.
This module provides optimized layer normalization operations for transformer
architectures, including support for different normalization types and quantization.
It implements various normalization strategies like LayerNorm and RMSNorm, with
optional zero-centered gamma and epsilon parameters.
"""
from functools import partial from functools import partial
from typing import List, Tuple
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from . import cpp_extensions as tex from . import cpp_extensions as tex
from .dot import fp8_dot_impl, get_precision_of_fp8_dot
from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import with_sharding_constraint_by_logical_axes
from .quantize import (
ScaledTensor,
Quantizer,
)
def canonicalize_layernorm_type(x):
""" def canonicalize_norm_type(x):
Canonicalize the layernorm type """Convert normalization type string to canonical form.
Args:
x: Input normalization type string
Returns:
Canonicalized normalization type string
""" """
canonicalized = x.lower().strip().replace("-", "").replace("_", "") canonicalized = x.lower().strip().replace("-", "").replace("_", "")
assert canonicalized in ["layernorm", "rmsnorm"] assert canonicalized in ["layernorm", "rmsnorm"]
...@@ -25,365 +37,106 @@ def canonicalize_layernorm_type(x): ...@@ -25,365 +37,106 @@ def canonicalize_layernorm_type(x):
def layernorm( def layernorm(
inputs: jnp.ndarray, x: jnp.ndarray,
gamma: jnp.ndarray, gamma: jnp.ndarray,
beta: jnp.ndarray, beta: jnp.ndarray,
layernorm_type: str, norm_type: str,
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
epsilon: float = 1e-6, epsilon: float = 1e-6,
quantizer: Quantizer = None,
): ):
"""Apply layer normalization with optional quantization.
This function implements layer normalization with support for different
normalization types and optional quantization. It normalizes the input
tensor using the provided gamma and beta parameters.
Args:
x: Input tensor to normalize
gamma: Scale parameter for normalization
beta: Shift parameter for normalization
norm_type: Type of normalization to apply
zero_centered_gamma: Whether to use zero-centered gamma
epsilon: Small constant for numerical stability
quantizer: Optional quantizer for quantizing the output
Returns:
Normalized output tensor
""" """
LN/RMSNorm wrapper output = _layernorm(x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer)
Only support layernorm_type in ['layernorm', 'rmsnorm']
"""
output = _layernorm(
inputs,
gamma,
beta,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
)
return output return output
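A plain-JAX reference sketch of what the wrapper above computes for unquantized inputs; it is illustrative only and does not call the TE kernels. With zero_centered_gamma=True the kernels apply (gamma + 1) instead of gamma.
import jax.numpy as jnp
def layernorm_reference(x, gamma, beta, epsilon=1e-6):
    # norm_type="layernorm": normalize over the last axis, then scale and shift.
    mu = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mu) / jnp.sqrt(var + epsilon) * gamma + beta
def rmsnorm_reference(x, gamma, epsilon=1e-6):
    # norm_type="rmsnorm": no mean subtraction and no beta.
    rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + epsilon)
    return x / rms * gamma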
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5)) @partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
def _layernorm( def _layernorm(x, gamma, beta, norm_type: str, zero_centered_gamma, epsilon, quantizer):
x, gamma, beta, layernorm_type: str, zero_centered_gamma: bool = False, epsilon: float = 1e-6 """Internal implementation of layer normalization with custom VJP.
):
output, _ = _layernorm_fwd_rule(x, gamma, beta, layernorm_type, zero_centered_gamma, epsilon) This function implements the core layer normalization logic with support
return output for custom vector-Jacobian product (VJP) for automatic differentiation.
Args:
def _layernorm_fwd_rule( x: Input tensor
x, gamma, beta, layernorm_type: str, zero_centered_gamma: bool = False, epsilon: float = 1e-6 gamma: Scale parameter
): beta: Shift parameter
layernorm_type = canonicalize_layernorm_type(layernorm_type) norm_type: Type of normalization
if layernorm_type == "layernorm": zero_centered_gamma: Whether to use zero-centered gamma
output, mu, rsigma = tex.layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon) epsilon: Small constant for numerical stability
elif layernorm_type == "rmsnorm": quantizer: Optional quantizer
assert (
not zero_centered_gamma Returns:
), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'" Normalized tensor
output, rsigma = tex.rmsnorm_fwd(x, gamma, epsilon)
mu = None
else:
raise ValueError(f"{layernorm_type=} is not supported.")
return output, (x, mu, rsigma, gamma, beta)
def _layernorm_bwd_rule(layernorm_type, zero_centered_gamma, epsilon, ctx, dz):
x, mu, rsigma, gamma, beta = ctx
if layernorm_type == "layernorm":
dx, dgamma, dbeta = tex.layernorm_bwd(
dz, x, mu, rsigma, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
)
elif layernorm_type == "rmsnorm":
assert (
not zero_centered_gamma
), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"
dx, dgamma = tex.rmsnorm_bwd(dz, x, rsigma, gamma, epsilon=epsilon)
dbeta = None
else:
raise ValueError(f"{layernorm_type=} is not supported.")
return dx, dgamma, dbeta
_layernorm.defvjp(_layernorm_fwd_rule, _layernorm_bwd_rule)
def layernorm_fp8_dot(
x: jnp.ndarray,
kernel: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
fp8_meta_pkg: FP8MetaPackage,
layernorm_type: str,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
layernorm_input_axes: Tuple[
str, ...
] = None, # The logic axes of sharding constraint to the layernorm input.
dot_input_axes: Tuple[
str, ...
] = None, # The logic axes of sharding constraint to the dot input.
) -> jnp.ndarray:
""" """
Layernorm + FP8 GEMM output, _ = _layernorm_fwd_rule(
""" x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer
amax_list = fp8_meta_pkg.amax_list
scale_list = fp8_meta_pkg.scale_list
fwd_dtype = FP8Helper.FWD_DTYPE
bwd_dtype = FP8Helper.BWD_DTYPE
output = _layernorm_fp8_dot(
x,
kernel,
gamma,
beta,
amax_list,
scale_list,
layernorm_type,
fwd_dtype,
bwd_dtype,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_input_axes,
) )
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12)) def _layernorm_fwd_rule(x, gamma, beta, norm_type: str, zero_centered_gamma, epsilon, quantizer):
def _layernorm_fp8_dot( """Forward pass rule for layer normalization.
x: jnp.ndarray,
kernel: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
amax_list: List[jnp.ndarray],
scale_list: List[jnp.ndarray],
layernorm_type: str,
fwd_dtype: jnp.dtype,
bwd_dtype: jnp.dtype,
zero_centered_gamma: bool,
epsilon: float,
layernorm_input_axes: Tuple[str, ...],
dot_input_axes: Tuple[str, ...],
):
output, _ = _layernorm_fp8_dot_fwd_rule(
x,
kernel,
gamma,
beta,
amax_list,
scale_list,
layernorm_type,
fwd_dtype,
bwd_dtype,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_input_axes,
)
return output
def _layernorm_fp8_dot_fwd_rule(
x,
kernel,
gamma,
beta,
amax_list,
scale_list,
layernorm_type,
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_input_axes,
):
x_contracting_dims = (len(x.shape) - 1,) Args:
k_contracting_dims = (0,) x: Input tensor
assert x.shape[-1] == kernel.shape[0] gamma: Scale parameter
beta: Shift parameter
norm_type: Type of normalization
zero_centered_gamma: Whether to use zero-centered gamma
epsilon: Small constant for numerical stability
quantizer: Optional quantizer
maybe_fm32_to_fp32, maybe_fp32_to_fm32 = FP8Helper.generate_fp8_meta_dtype_converter_pair( Returns:
*amax_list, *scale_list Tuple of (output, context) for backward pass
) """
amax_list = maybe_fm32_to_fp32(*amax_list)
scale_list = maybe_fm32_to_fp32(*scale_list)
fp8_dtype_list = [fwd_dtype, fwd_dtype, bwd_dtype]
scale_list, scale_inv_list = FP8MetaPackage.update_fp8_scale(
amax_list, scale_list, fp8_dtype_list
)
amax_list = FP8MetaPackage.update_amax_list(amax_list)
x_amax = amax_list[FP8MetaPackage.INPUT_IDX][0:1]
x_scale = scale_list[FP8MetaPackage.INPUT_IDX]
x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX]
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
if layernorm_type == "layernorm":
ln_out, mu, rsigma, updated_x_amax = tex.layernorm_fwd_fp8(
x,
gamma,
beta,
x_amax,
x_scale,
x_scale_inv,
out_dtype=fwd_dtype,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
)
else:
assert (
not zero_centered_gamma
), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"
ln_out, rsigma, updated_x_amax = tex.rmsnorm_fwd_fp8(
x, gamma, x_amax, x_scale, x_scale_inv, out_dtype=fwd_dtype, epsilon=epsilon
)
mu = None
assert x.shape == ln_out.shape
kernel_amax = amax_list[FP8MetaPackage.WEIGHT_IDX][0:1]
kernel_scale = scale_list[FP8MetaPackage.WEIGHT_IDX]
kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX]
# Kernel in (hidden_in, hidden_out...)
# Note (Ming Huang): Use cast only to let XLA handle the transpose, avoiding an
# unnecessary copy that would break FP8 GEMM pattern matching.
casted_kernel, updated_kernel_amax = tex.cast_fp8(
kernel, kernel_amax, kernel_scale, kernel_scale_inv, fwd_dtype
)
ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_input_axes)
# (batch..., hidden_in) x (hidden_in, hidden_out...)
output = fp8_dot_impl(
ln_out,
casted_kernel,
x_scale_inv,
kernel_scale_inv,
x.dtype,
(x_contracting_dims, k_contracting_dims),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP),
)
ctx = ( norm_type = canonicalize_norm_type(norm_type)
ln_out, output, mu, rsigma = tex.normalization_fwd(
casted_kernel, x, gamma, beta, zero_centered_gamma, epsilon, norm_type, quantizer
amax_list,
scale_list,
scale_inv_list,
updated_x_amax,
updated_kernel_amax,
x.shape,
kernel.shape,
mu,
rsigma,
x,
gamma,
beta,
x_contracting_dims,
k_contracting_dims,
maybe_fp32_to_fm32,
) )
if isinstance(output, ScaledTensor):
output = output.dequantize()
return output, ctx return output, (x, mu, rsigma, gamma, beta, quantizer)
def _layernorm_fp8_dot_bwd_rule( def _layernorm_bwd_rule(norm_type, zero_centered_gamma, epsilon, ctx, dz):
layernorm_type, """Backward pass rule for layer normalization.
fwd_dtype, # pylint: disable=unused-argument
bwd_dtype,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_input_axes, # pylint: disable=unused-argument
ctx,
grad,
):
(
ln_out_,
casted_kernel,
amax_list,
scale_list,
scale_inv_list,
updated_x_amax,
updated_kernel_amax,
x_shape,
kernel_shape,
mu,
rsigma,
x,
gamma,
beta,
x_contracting_dims,
k_contracting_dims,
maybe_fp32_to_fm32,
) = ctx
ln_out_t = tex.transpose(ln_out_, static_axis_boundary=-1, transpose_axis_boundary=-1)
grad_amax = amax_list[FP8MetaPackage.GRAD_IDX][0:1]
grad_scale = scale_list[FP8MetaPackage.GRAD_IDX]
grad_scale_inv = scale_inv_list[FP8MetaPackage.GRAD_IDX]
casted_grad, casted_grad_t, updated_grad_amax = tex.cast_transpose(
grad,
grad_amax,
grad_scale,
grad_scale_inv,
bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=min(x_contracting_dims),
)
xt_constracting_dim = tuple(range(len(x_contracting_dims), len(x_shape))) Args:
gt_constracting_dim = tuple(range(grad.ndim - len(xt_constracting_dim), grad.ndim)) norm_type: Type of normalization
x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX] zero_centered_gamma: Whether to use zero-centered gamma
wgrad = fp8_dot_impl( epsilon: Small constant for numerical stability
ln_out_t, ctx: Context from forward pass
casted_grad_t, dz: Gradient from upstream
x_scale_inv,
grad_scale_inv,
grad.dtype,
(xt_constracting_dim, gt_constracting_dim),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD),
)
g_for_dgrad_constracting_dim = tuple( Returns:
range(grad.ndim - len(kernel_shape) + len(k_contracting_dims), grad.ndim) Tuple of gradients with respect to inputs
) """
k_constracting_dim = tuple(range(len(k_contracting_dims), len(kernel_shape))) x, mu, rsigma, gamma, beta, quantizer = ctx
kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX]
dgrad = fp8_dot_impl(
casted_grad,
casted_kernel,
grad_scale_inv,
kernel_scale_inv,
grad.dtype,
(g_for_dgrad_constracting_dim, k_constracting_dim),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD),
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) dx, dgamma, dbeta = tex.normalization_bwd(
if layernorm_type == "layernorm": dz, x, mu, rsigma, gamma, beta, zero_centered_gamma, epsilon, norm_type
dx, dgamma, dbeta = tex.layernorm_bwd(
dgrad,
x,
mu,
rsigma,
gamma,
beta,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
)
else:
assert (
not zero_centered_gamma
), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"
dx, dgamma = tex.rmsnorm_bwd(dgrad, x, rsigma, gamma, epsilon=epsilon)
dbeta = None
amax_list[FP8MetaPackage.INPUT_IDX] = (
amax_list[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax[0])
)
amax_list[FP8MetaPackage.WEIGHT_IDX] = (
amax_list[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_amax[0])
)
amax_list[FP8MetaPackage.GRAD_IDX] = (
amax_list[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0])
) )
return dx, dgamma, dbeta, quantizer
amax_list = maybe_fp32_to_fm32(*amax_list)
scale_list = maybe_fp32_to_fm32(*scale_list)
return dx, wgrad, dgamma, dbeta, amax_list, scale_list _layernorm.defvjp(_layernorm_fwd_rule, _layernorm_bwd_rule)
_layernorm_fp8_dot.defvjp(_layernorm_fp8_dot_fwd_rule, _layernorm_fp8_dot_bwd_rule)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused Layer normalization and dense layer transformation operations for Transformer Engine in JAX.
This module provides optimized implementations of layer normalization followed by
dense layer transformation (GEMM) operations, which are commonly used in transformer
architectures. It supports various normalization types, quantization, and
distributed training through sharding constraints.
"""
from functools import partial
from typing import Tuple
import jax
import jax.numpy as jnp
from . import cpp_extensions as tex
from .quantize import (
QuantizerSet,
noop_quantizer_set,
with_sharding_constraint_by_logical_axes,
)
def layernorm_dense(
x: jnp.ndarray,
kernel: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
bias: jnp.ndarray = None,
norm_type: str = "layernorm",
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
# The logic axes of sharding constraint to the layernorm input.
layernorm_input_axes: Tuple[str, ...] = None,
# The logic axes of sharding constraint to the dot input.
dot_input_axes: Tuple[str, ...] = None,
quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray:
"""Apply layer normalization followed by dense layer transformation.
This function implements the following sequence of operations:
1. Layer normalization: (x - mean) / sqrt(var + epsilon) * gamma + beta
2. Linear transformation: y = x * kernel + bias
Args:
x: Input tensor with shape [batch..., hidden_in]
kernel: Weight matrix with shape [hidden_in, hidden_out]
gamma: Scale parameter for normalization with shape [hidden_in]
beta: Bias parameter for normalization with shape [hidden_in]
bias: Optional bias term for dense layer transformation with shape [hidden_out]
norm_type: Type of normalization ("layernorm" or "rmsnorm")
zero_centered_gamma: Whether to use zero-centered gamma for normalization
epsilon: Small constant for numerical stability in normalization
layernorm_input_axes: Logical axes for sharding the layernorm input
dot_input_axes: Logical axes for sharding the matrix multiplication input
quantizer_set: Set of quantizers for different tensor types
Returns:
Output tensor with shape [batch..., hidden_out]
Note:
- For RMSNorm (norm_type="rmsnorm"), beta must be None and zero_centered_gamma
must be False
- The function supports automatic differentiation through JAX's custom VJP
- Quantization is applied to both the normalized input and kernel
"""
output = _layernorm_dense(
x,
kernel,
gamma,
beta,
bias,
norm_type,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_input_axes,
quantizer_set,
)
return output
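A plain-JAX reference sketch (assumed shapes) of the unquantized math this wrapper fuses; it is not the TE implementation and omits quantization and sharding constraints:
import jax.numpy as jnp
def layernorm_dense_reference(x, kernel, gamma, beta, bias=None, epsilon=1e-6):
    # 1. LayerNorm over the last axis of x: [batch..., hidden_in].
    mu = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    y = (x - mu) / jnp.sqrt(var + epsilon) * gamma + beta
    # 2. Dense layer: contract hidden_in against kernel [hidden_in, hidden_out].
    out = jnp.einsum("...i,io->...o", y, kernel)
    return out if bias is None else out + bias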
@partial(
jax.custom_vjp,
nondiff_argnums=(
5,
6,
7,
8,
9,
),
)
def _layernorm_dense(
x: jnp.ndarray,
kernel: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
bias: jnp.ndarray,
norm_type: str,
zero_centered_gamma: bool,
epsilon: float,
layernorm_input_axes: Tuple[str, ...],
dot_input_axes: Tuple[str, ...],
quantizer_set,
):
"""Internal implementation of layernorm_dense with custom VJP.
This function implements the forward pass of layernorm_dense with support for
automatic differentiation. It handles the normalization and dense layer transformation
operations, including quantization and sharding constraints.
Args:
x: Input tensor
kernel: Weight matrix
gamma: Scale parameter for normalization
beta: Bias parameter for normalization
bias: Optional bias term
norm_type: Type of normalization
zero_centered_gamma: Whether to use zero-centered gamma
epsilon: Small constant for numerical stability
layernorm_input_axes: Logical axes for layernorm sharding
dot_input_axes: Logical axes for matrix multiplication sharding
quantizer_set: Set of quantizers
Returns:
Output tensor from the combined operations
"""
output, _ = _layernorm_dense_fwd_rule(
x,
kernel,
gamma,
beta,
bias,
norm_type,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_input_axes,
quantizer_set,
)
return output
def _layernorm_dense_fwd_rule(
x,
kernel,
gamma,
beta,
bias,
norm_type,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_input_axes,
quantizer_set,
):
"""Forward pass rule for layernorm_dense.
Implements the forward pass computation including:
1. Layer normalization with quantization
2. Matrix multiplication with quantized kernel
3. Optional bias addition
4. Sharding constraints
Returns:
Tuple of (output, context) for automatic differentiation
"""
x_contracting_dims = (len(x.shape) - 1,)
k_contracting_dims = (0,)
assert x.shape[-1] == kernel.shape[0]
assert len(kernel.shape) == 2 # Otherwise need to merge dims in quantize
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
casted_ln_out, mu, rsigma = tex.normalization_fwd(
x,
gamma,
beta,
zero_centered_gamma,
epsilon,
norm_type,
quantizer_set.x,
)
# Kernel in (hidden_in, hidden_out...)
casted_kernel = tex.quantize(kernel, quantizer_set.kernel)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)
# NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out...)
output = tex.gemm(
casted_ln_out.get_rowwise_tensor(),
casted_kernel.get_colwise_tensor(),
(x_contracting_dims, k_contracting_dims),
)
use_bias = bias is not None
if use_bias:
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape)
ctx = (
casted_ln_out.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None,
casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None,
x.shape,
kernel.shape,
mu,
rsigma,
x,
gamma,
beta,
x_contracting_dims,
k_contracting_dims,
use_bias,
quantizer_set,
)
return output, ctx
def _layernorm_dense_bwd_rule(
norm_type,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_input_axes, # pylint: disable=unused-argument
ctx,
grad,
):
"""Backward pass rule for layernorm_dense.
Implements the backward pass computation including:
1. Gradient computation for matrix multiplication
2. Gradient computation for layer normalization
3. Gradient computation for bias terms
4. Proper handling of quantization
Returns:
Tuple of gradients for all input parameters
"""
(
colwise_casted_ln_out,
rowwise_casted_kernel,
x_shape,
kernel_shape,
mu,
rsigma,
x,
gamma,
beta,
x_contracting_dims_in_fwd,
k_contracting_dims_in_fwd,
use_bias,
quantizer_set,
) = ctx
grad = with_sharding_constraint_by_logical_axes(grad, dot_input_axes)
casted_grad, dbias = tex.quantize_dbias(grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
g_constracting_dim = tuple(
range(grad.ndim - len(kernel_shape) + len(k_contracting_dims_in_fwd), grad.ndim)
)
# k_non_contracting_dims
k_constracting_dim = tuple(
dim for dim in range(len(kernel_shape)) if dim not in k_contracting_dims_in_fwd
)
# NT GEMM
dgrad = tex.gemm(
casted_grad.get_rowwise_tensor(),
rowwise_casted_kernel,
(g_constracting_dim, k_constracting_dim),
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
g_constracting_dim = x_constracting_dim = tuple(
range(0, len(x_shape) - len(x_contracting_dims_in_fwd))
)
# TN GEMM
wgrad = tex.gemm(
colwise_casted_ln_out,
casted_grad.get_colwise_tensor(),
(x_constracting_dim, g_constracting_dim),
)
dx, dgamma, dbeta = tex.normalization_bwd(
dgrad,
x,
mu,
rsigma,
gamma,
beta,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
norm_type=norm_type,
)
return dx, wgrad, dgamma, dbeta, dbias, quantizer_set
_layernorm_dense.defvjp(_layernorm_dense_fwd_rule, _layernorm_dense_bwd_rule)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
"""JAX MLP modules""" """Multi-layer perceptron (MLP) operations with layer normalization for Transformer Engine in JAX.
This module provides optimized implementations of MLP blocks commonly used in transformer
architectures. Each MLP block consists of:
1. Layer normalization
2. First dense layer transformation (GEMM1) with bias and activation
3. Second dense layer transformation (GEMM2) with bias
The implementation supports various normalization types, activation functions,
quantization, and distributed training through sharding constraints.
"""
from typing import List, Tuple, Sequence, Union, Callable from typing import List, Tuple, Sequence, Union, Callable
from functools import partial from functools import partial
...@@ -11,92 +21,81 @@ import jax.numpy as jnp ...@@ -11,92 +21,81 @@ import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name from jax.ad_checkpoint import checkpoint_name
from . import cpp_extensions as tex from . import cpp_extensions as tex
from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize from .layernorm import canonicalize_norm_type
from .layernorm import canonicalize_layernorm_type from .quantize import with_sharding_constraint_by_logical_axes, QuantizerSet, noop_quantizer_set
from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import with_sharding_constraint_by_logical_axes
def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
"""
Activation Unit
"""
if len(activation_type) > 1:
assert x.shape[-2] == 2 # Linear + GeLU
output = _activation_lu(x, activation_type)
return output
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
_output, _ = _activation_lu_fwd_rule(x, activation_type)
return _output
def _activation_lu_fwd_rule(x, activation_type):
fwd_output = tex.act_lu(x, activation_type)
return fwd_output, (x,)
def _activation_lu_bwd_rule(activation_type, ctx, g):
(x,) = ctx
assert x.dtype == g.dtype
dx = tex.dact_lu(g, x, activation_type)
dx = jnp.reshape(dx, x.shape)
return (dx,)
_activation_lu.defvjp(_activation_lu_fwd_rule, _activation_lu_bwd_rule)
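# Illustrative usage (a sketch only; the shapes, dtypes, and the gated activation pair below
# are assumptions, not taken from this file):
#   y1 = activation_lu(jnp.zeros((4, 128, 3072)), ("gelu",))                 # single activation
#   y2 = activation_lu(jnp.zeros((4, 128, 2, 3072)), ("gelu", "linear"))     # gated: shape[-2] == 2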
def layernorm_mlp(
x: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
kernels: List[jnp.ndarray],
biases: List[jnp.ndarray],
norm_type: str,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
norm_input_axes: Tuple[str, ...] = None,
dot_1_input_axes: Tuple[str, ...] = None,
dot_2_input_axes: Tuple[str, ...] = None,
ffn1_ckpt_name: str = "ffn1",
ffn2_ckpt_name: str = "ffn2",
activation_type: Sequence[Union[str, Callable]] = ("gelu",),
quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
) -> jnp.ndarray:
"""Apply layer normalization followed by MLP block.
This function implements the following sequence of operations:
1. Layer normalization: (x - mean) / sqrt(var + epsilon) * gamma + beta
2. First dense layer transformation: y1 = x * kernel1 + bias1
3. Activation function: y2 = activation(y1)
4. Second dense layer transformation: y3 = y2 * kernel2 + bias2
Args:
x: Input tensor with shape [batch..., hidden_in]
gamma: Scale parameter for normalization with shape [hidden_in]
beta: Bias parameter for normalization with shape [hidden_in]
kernels: List of two weight matrices:
- kernel1: [hidden_in, intermediate]
- kernel2: [intermediate, hidden_in]
biases: List of two bias terms:
- bias1: [intermediate]
- bias2: [hidden_in]
norm_type: Type of normalization ("layernorm" or "rmsnorm")
zero_centered_gamma: Whether to use zero-centered gamma for normalization
epsilon: Small constant for numerical stability in normalization
norm_input_axes: Logical axes for sharding the layernorm input
dot_1_input_axes: Logical axes for sharding the first matrix multiplication
dot_2_input_axes: Logical axes for sharding the second matrix multiplication
ffn1_ckpt_name: Name for checkpointing the first feed-forward network
ffn2_ckpt_name: Name for checkpointing the second feed-forward network
activation_type: Activation function(s) to apply after the first dense layer transformation
quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations
Returns:
Output tensor with shape [batch..., hidden_in]
Note:
- For RMSNorm (norm_type="rmsnorm"), beta must be None and zero_centered_gamma
must be False
- The function supports automatic differentiation through JAX's custom VJP
- Quantization is applied to both dense layer transformations
- Checkpointing is applied to both feed-forward networks for memory efficiency
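Example:
An illustrative sketch only; the shapes, dtypes, and parameter values below are
assumptions for exposition and are not taken from this change:
    import jax.numpy as jnp
    x = jnp.zeros((4, 128, 768), dtype=jnp.bfloat16)
    gamma = jnp.ones((768,))
    beta = jnp.zeros((768,))
    kernel_1 = jnp.zeros((768, 3072), dtype=jnp.bfloat16)   # one activation in activation_type
    kernel_2 = jnp.zeros((3072, 768), dtype=jnp.bfloat16)
    bias_1 = jnp.zeros((3072,))
    bias_2 = jnp.zeros((768,))
    out = layernorm_mlp(
        x, gamma, beta,
        kernels=[kernel_1, kernel_2],
        biases=[bias_1, bias_2],
        norm_type="layernorm",
        activation_type=("gelu",),
    )   # out.shape == (4, 128, 768)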
""" """
Layernorm + GEMM1 + bias + activation + GEMM2 + bias
"""
assert len(kernels) == 2
kernel_1 = kernels[0]
kernel_2 = kernels[1]
bias_1 = biases[0]
bias_2 = biases[1]
norm_type = canonicalize_norm_type(norm_type)
if norm_type == "rmsnorm":
assert beta is None, "beta should be None if norm_type is 'rmsnorm'"
assert (
not zero_centered_gamma
), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"
output = _layernorm_mlp(
x,
gamma,
beta,
...@@ -104,28 +103,22 @@ def fused_layernorm_fp8_mlp(
kernel_2,
bias_1,
bias_2,
norm_type,
zero_centered_gamma,
epsilon,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
quantizer_sets,
)
return output
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15))
def _layernorm_mlp(
x: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
...@@ -133,24 +126,46 @@ def _fused_layernorm_fp8_mlp(
kernel_2: jnp.ndarray,
bias_1: jnp.ndarray,
bias_2: jnp.ndarray,
norm_type: str,
zero_centered_gamma: bool,
epsilon: float,
norm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...],
dot_2_input_axes: Tuple[str, ...],
ffn1_ckpt_name: str,
ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]],
quantizer_sets,
):
"""Internal implementation of layernorm_mlp with custom VJP.
This function implements the forward pass of layernorm_mlp with support for
automatic differentiation. It handles the normalization, dense layer transformations,
activation, and quantization operations.
Args:
x: Input tensor
gamma: Scale parameter for normalization
beta: Bias parameter for normalization
kernel_1: First weight matrix
kernel_2: Second weight matrix
bias_1: First bias term
bias_2: Second bias term
norm_type: Type of normalization
zero_centered_gamma: Whether to use zero-centered gamma
epsilon: Small constant for numerical stability
norm_input_axes: Logical axes for layernorm sharding
dot_1_input_axes: Logical axes for first matrix multiplication sharding
dot_2_input_axes: Logical axes for second matrix multiplication sharding
ffn1_ckpt_name: Name for first feed-forward network checkpointing
ffn2_ckpt_name: Name for second feed-forward network checkpointing
activation_type: Activation function(s)
quantizer_sets: Tuple of quantizer sets
Returns:
Output tensor from the combined operations
"""
output, _ = _layernorm_mlp_fwd_rule(
x,
gamma,
beta,
...@@ -158,27 +173,21 @@ def _fused_layernorm_fp8_mlp(
kernel_2,
bias_1,
bias_2,
norm_type,
zero_centered_gamma,
epsilon,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
quantizer_sets,
)
return output
def _layernorm_mlp_fwd_rule(
x,
gamma,
beta,
...@@ -186,444 +195,257 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
kernel_2,
bias_1,
bias_2,
norm_type,
zero_centered_gamma,
epsilon,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
quantizer_sets,
):
"""Forward pass rule for layernorm_mlp.
Implements the forward pass computation including:
1. Layer normalization with quantization
2. First matrix multiplication with quantized kernel
3. Activation function application
4. Second matrix multiplication with quantized kernel
5. Optional bias additions
6. Sharding constraints
7. Checkpointing for memory efficiency
Returns:
Tuple of (output, context) for automatic differentiation
"""
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
# x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (hidden_in, activation_len * intermediate)
# Kernel_2 should be in shape of (intermediate, hidden_in)
assert len(kernel_1.shape) == 2
assert len(kernel_2.shape) == 2
x_contracting_dims = (len(x.shape) - 1,)
k_contracting_dims = (0,)
assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]
assert kernel_1.shape[1] == len(activation_type) * kernel_2.shape[0]
use_bias_1 = bias_1 is not None
use_bias_2 = bias_2 is not None
x = with_sharding_constraint_by_logical_axes(x, norm_input_axes)
casted_ln_out, mu, rsigma = tex.normalization_fwd(
x,
gamma,
beta,
zero_centered_gamma,
epsilon,
norm_type,
quantizer=ffn1_quantizer_set.x,
)
casted_kernel_1 = tex.quantize(kernel_1, quantizer=ffn1_quantizer_set.kernel)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)
# NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out)
dot_1_output = tex.gemm(
casted_ln_out.get_rowwise_tensor(),
casted_kernel_1.get_colwise_tensor(),
(x_contracting_dims, k_contracting_dims),
)
if use_bias_1:
bias_1_shape = bias_1.shape
bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)
dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)
# (batch..., hidden_in) -> (batch..., hidden)
casted_act_out = tex.act_lu(dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x)
casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)
casted_kernel_2 = tex.quantize(kernel_2, quantizer=ffn2_quantizer_set.kernel)
# NN GEMM
# (batch..., hidden_out) x (hidden_out, hidden_in)
dot_2_output = tex.gemm(
casted_act_out.get_rowwise_tensor(),
casted_kernel_2.get_colwise_tensor(),
(x_contracting_dims, k_contracting_dims),
)
if use_bias_2:
bias_2_shape = bias_2.shape
bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)
dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)
ctx = (
x,
mu,
rsigma,
gamma,
beta,
casted_ln_out.get_colwise_tensor(),
casted_kernel_1.get_rowwise_tensor(),
dot_1_output,
casted_act_out.get_colwise_tensor(),
casted_kernel_2.get_rowwise_tensor(),
x_contracting_dims,
k_contracting_dims,
kernel_1.shape,
kernel_2.shape,
use_bias_1,
use_bias_2,
quantizer_sets,
)
return dot_2_output, ctx
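# Illustrative sketch (assumed shapes, for exposition only): the shape flow through the two
# GEMMs above for a single ("gelu",) activation, x: (B, S, hidden_in),
# kernel_1: (hidden_in, intermediate), kernel_2: (intermediate, hidden_in):
#   dot_1_output   : (B, S, intermediate)
#   casted_act_out : (B, S, intermediate)   after tex.act_lu
#   dot_2_output   : (B, S, hidden_in)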
def _layernorm_mlp_bwd_rule(
norm_type,
zero_centered_gamma,
epsilon,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
ffn1_ckpt_name,  # pylint: disable=unused-argument
ffn2_ckpt_name,  # pylint: disable=unused-argument
activation_type,
ctx,
grad,
):
"""Backward pass rule for layernorm_mlp.
Implements the backward pass computation including:
1. Gradient computation for second matrix multiplication
2. Gradient computation for activation function
3. Gradient computation for first matrix multiplication
4. Gradient computation for layer normalization
5. Gradient computation for bias terms
6. Proper handling of quantization
Returns:
Tuple of gradients for all input parameters
"""
(
x,
mu,
rsigma,
gamma,
beta,
colwise_casted_ln_out,
rowwise_casted_kernel_1,
dot_1_output,
colwise_casted_act_out,
rowwise_casted_kernel_2,
x_contracting_dims_in_fwd,
k_contracting_dims_in_fwd,
kernel_1_shape,
kernel_2_shape,
use_bias_1,
use_bias_2,
quantizer_sets,
) = ctx
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
# Since the sharding of outputs should be the same as dot_1's input
grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
casted_grad, dbias_2 = tex.quantize_dbias(
grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_2.ndim
g_constracting_dim_2 = tuple(
range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim)
)
# k_non_contracting_dims
k_constracting_dim_2 = tuple(
dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd
)
# NT GEMM
# (batch..., hidden_out) x (hidden_in, hidden_out)
dgrad_2 = tex.gemm(
casted_grad.get_rowwise_tensor(),
rowwise_casted_kernel_2,
(g_constracting_dim_2, k_constracting_dim_2),
)
dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
x_constracting_dim = g_constracting_dim = tuple(
range(0, len(x.shape) - len(x_contracting_dims_in_fwd))
)
# TN GEMM
# (hidden, batch...,) x (hidden, batch...)
wgrad_2 = tex.gemm(
colwise_casted_act_out,
casted_grad.get_colwise_tensor(),
(x_constracting_dim, g_constracting_dim),
)
casted_dact_out, dbias_1 = tex.quantize_dact_dbias(
dgrad_2,
dot_1_output,
activation_type=activation_type,
is_dbias=use_bias_1,
quantizer=ffn2_quantizer_set.dgrad,
)
# k_non_contracting_dims calibrated with the shape difference of dgrad_2.ndim vs kernel_1.ndim
g_constracting_dim_1 = tuple(
range(dgrad_2.ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dgrad_2.ndim)
)
# k_non_contracting_dims
k_constracting_dim_1 = tuple(
dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd
)
# NT GEMM
dgrad_1 = tex.gemm(
casted_dact_out.get_rowwise_tensor(),
rowwise_casted_kernel_1,
(g_constracting_dim_1, k_constracting_dim_1),
)
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, norm_input_axes)
# TN GEMM
# (hidden, batch...) x (hidden, batch...)
wgrad_1 = tex.gemm(
colwise_casted_ln_out,
casted_dact_out.get_colwise_tensor(),
(x_constracting_dim, g_constracting_dim),
)
dx, dgamma, dbeta = tex.normalization_bwd(
dgrad_1,
x,
mu,
rsigma,
gamma,
beta,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
norm_type=norm_type,
)
return (dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, quantizer_sets)
_layernorm_mlp.defvjp(_layernorm_mlp_fwd_rule, _layernorm_mlp_bwd_rule)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Python interface for quantization helpers.
This module provides a high-level interface for tensor quantization in JAX,
including support for various scaling modes and quantization strategies.
It exports all the necessary classes and functions from the underlying
implementation modules.
"""
from .tensor import *
from .quantizer import *
from .dequantizer import *
from .scaling_modes import *
from .metadata import *
from .helper import *
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Dequantization utilities for TE/JAX.
This module provides utilities for dequantizing tensors that have been quantized
using various scaling modes, including delayed scaling and block scaling.
"""
import jax
import jax.numpy as jnp
from .scaling_modes import ScalingMode
__all__ = ["Dequantizer"]
class Dequantizer:
"""Encapsulation class for dequantization helpers.
This class provides static methods for dequantizing tensors that have been
quantized using different scaling modes. It supports both delayed scaling
and block scaling modes.
"""
@staticmethod
def _dq_func_tensor_scaling(scaled_tensor):
"""Dequantize a tensor using delayed scaling.
This function dequantizes a tensor that was quantized using delayed scaling
by multiplying the quantized data with the inverse scaling factor.
Args:
scaled_tensor: The quantized tensor to dequantize
Returns:
The dequantized tensor in the specified data type
"""
return jnp.asarray(
scaled_tensor.data.astype(jnp.float32) * scaled_tensor.scale_inv.astype(jnp.float32),
scaled_tensor.dq_dtype,
)
@staticmethod
def _dq_func_block_scaling(scaled_tensor):
"""Dequantize a tensor using block scaling.
This function dequantizes a tensor that was quantized using block scaling
by applying the inverse scaling factor to each block of data.
Args:
scaled_tensor: The quantized tensor to dequantize
Returns:
The dequantized tensor in the specified data type
"""
data = scaled_tensor.data.astype(jnp.float32)
data_shape = data.shape
scale = scaled_tensor.scale_inv.view(jnp.uint8).astype(jnp.float32)
scale_shape = scaled_tensor.scaling_mode.get_scale_shape(
scaled_tensor.data.shape, scaled_tensor.is_colwise, is_padded=False
)
scale = jax.lax.slice(scale, [0] * len(scale_shape), scale_shape) # slice out the padding
data = data.reshape(
*data_shape[:-2],
scale_shape[-2],
int(data_shape[-2] / scale_shape[-2]),
scale_shape[-1],
int(data_shape[-1] / scale_shape[-1]),
)
scale = jnp.expand_dims(scale, axis=(-1, -3))
# E8M0 has no sign bit: each scale byte stores a biased exponent (bias 127), so raw
# values below 127 encode scales smaller than 1; the scale decodes as 2 ** (value - 127).
return jnp.asarray(data * jnp.power(2, scale - 127), scaled_tensor.dq_dtype).reshape(
data_shape
)
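# Illustrative decode of the E8M0 scale bytes used above (values are examples only):
#   raw byte 130 -> 2 ** (130 - 127) = 8.0
#   raw byte 124 -> 2 ** (124 - 127) = 0.125
#   raw byte 127 -> 2 **  0          = 1.0  (no scaling)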
funcs = {
ScalingMode.NVTE_DELAYED_TENSOR_SCALING: _dq_func_tensor_scaling,
ScalingMode.NVTE_MXFP8_1D_SCALING: _dq_func_block_scaling,
}
@staticmethod
def dequantize(scaled_tensor):
"""Dequantize a scaled tensor using the appropriate scaling mode.
This method selects the appropriate dequantization function based on the
scaling mode used for quantization and applies it to the tensor.
Args:
scaled_tensor: The quantized tensor to dequantize
Returns:
The dequantized tensor in the specified data type
"""
dq_func = Dequantizer.funcs[scaled_tensor.scaling_mode]
return dq_func(scaled_tensor)
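# Illustrative usage (a minimal sketch; the import path and the way the ScaledTensor `qx`
# is produced are assumptions, not shown in this file):
#   from transformer_engine.jax.quantize import Dequantizer
#   # given a ScaledTensor `qx` created by one of the quantizers in this package:
#   x_hp = Dequantizer.dequantize(qx)   # returns an array in qx.dq_dtype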