Unverified Commit cf9a7c2f authored by Phuong Nguyen, committed by GitHub

[JAX] Refactor + MXFP8 + GroupedGEMM (#1627)



* refactor + mxfp8

* added grouped gemm

* rename linear to dense

* added cublas init phase for groupedGemm

* relax the tol of test encoder multiprocessing mxfp8 by 0.001
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: Hua Huang <huah@nvidia.com>
Co-authored-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent be055eb0
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused Layer normalization and dense layer transformation operations for Transformer Engine in JAX.
This module provides optimized implementations of layer normalization followed by
dense layer transformation (GEMM) operations, which are commonly used in transformer
architectures. It supports various normalization types, quantization, and
distributed training through sharding constraints.
"""
from functools import partial
from typing import Tuple
import jax
import jax.numpy as jnp
from . import cpp_extensions as tex
from .quantize import (
QuantizerSet,
noop_quantizer_set,
with_sharding_constraint_by_logical_axes,
)
def layernorm_dense(
x: jnp.ndarray,
kernel: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
bias: jnp.ndarray = None,
norm_type: str = "layernorm",
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
# The logical axes of the sharding constraint applied to the layernorm input.
layernorm_input_axes: Tuple[str, ...] = None,
# The logical axes of the sharding constraint applied to the dot input.
dot_input_axes: Tuple[str, ...] = None,
quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray:
"""Apply layer normalization followed by dense layer transformation.
This function implements the following sequence of operations:
1. Layer normalization: (x - mean) / sqrt(var + epsilon) * gamma + beta
2. Linear transformation: y = x * kernel + bias
Args:
x: Input tensor with shape [batch..., hidden_in]
kernel: Weight matrix with shape [hidden_in, hidden_out]
gamma: Scale parameter for normalization with shape [hidden_in]
beta: Bias parameter for normalization with shape [hidden_in]
bias: Optional bias term for dense layer transformation with shape [hidden_out]
norm_type: Type of normalization ("layernorm" or "rmsnorm")
zero_centered_gamma: Whether to use zero-centered gamma for normalization
epsilon: Small constant for numerical stability in normalization
layernorm_input_axes: Logical axes for sharding the layernorm input
dot_input_axes: Logical axes for sharding the matrix multiplication input
quantizer_set: Set of quantizers for different tensor types
Returns:
Output tensor with shape [batch..., hidden_out]
Note:
- For RMSNorm (norm_type="rmsnorm"), beta must be None and zero_centered_gamma
must be False
- The function supports automatic differentiation through JAX's custom VJP
- Quantization is applied to both the normalized input and kernel
"""
output = _layernorm_dense(
x,
kernel,
gamma,
beta,
bias,
norm_type,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_input_axes,
quantizer_set,
)
return output
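# Illustrative usage sketch (not part of this commit). The helper name, shapes, and
# dtypes below are made up for clarity; with the default noop_quantizer_set this
# exercises the un-quantized (higher-precision) path of layernorm_dense.
def _example_layernorm_dense_usage():
    x = jnp.ones((8, 128, 512), jnp.float32)       # (batch, seq, hidden_in)
    kernel = jnp.ones((512, 1024), jnp.float32)    # (hidden_in, hidden_out)
    gamma = jnp.ones((512,), jnp.float32)
    beta = jnp.zeros((512,), jnp.float32)
    out = layernorm_dense(x, kernel, gamma, beta)  # -> (8, 128, 1024)
    return out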
@partial(
jax.custom_vjp,
nondiff_argnums=(
5,
6,
7,
8,
9,
),
)
def _layernorm_dense(
x: jnp.ndarray,
kernel: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
bias: jnp.ndarray,
norm_type: str,
zero_centered_gamma: bool,
epsilon: float,
layernorm_input_axes: Tuple[str, ...],
dot_input_axes: Tuple[str, ...],
quantizer_set,
):
"""Internal implementation of layernorm_dense with custom VJP.
This function implements the forward pass of layernorm_dense with support for
automatic differentiation. It handles the normalization and dense layer transformation
operations, including quantization and sharding constraints.
Args:
x: Input tensor
kernel: Weight matrix
gamma: Scale parameter for normalization
beta: Bias parameter for normalization
bias: Optional bias term
norm_type: Type of normalization
zero_centered_gamma: Whether to use zero-centered gamma
epsilon: Small constant for numerical stability
layernorm_input_axes: Logical axes for layernorm sharding
dot_input_axes: Logical axes for matrix multiplication sharding
quantizer_set: Set of quantizers
Returns:
Output tensor from the combined operations
"""
output, _ = _layernorm_dense_fwd_rule(
x,
kernel,
gamma,
beta,
bias,
norm_type,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_input_axes,
quantizer_set,
)
return output
def _layernorm_dense_fwd_rule(
x,
kernel,
gamma,
beta,
bias,
norm_type,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_input_axes,
quantizer_set,
):
"""Forward pass rule for layernorm_dense.
Implements the forward pass computation including:
1. Layer normalization with quantization
2. Matrix multiplication with quantized kernel
3. Optional bias addition
4. Sharding constraints
Returns:
Tuple of (output, context) for automatic differentiation
"""
x_contracting_dims = (len(x.shape) - 1,)
k_contracting_dims = (0,)
assert x.shape[-1] == kernel.shape[0]
assert len(kernel.shape) == 2 # Otherwise need to merge dims in quantize
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
casted_ln_out, mu, rsigma = tex.normalization_fwd(
x,
gamma,
beta,
zero_centered_gamma,
epsilon,
norm_type,
quantizer_set.x,
)
# Kernel in (hidden_in, hidden_out...)
casted_kernel = tex.quantize(kernel, quantizer_set.kernel)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)
# NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out...)
output = tex.gemm(
casted_ln_out.get_rowwise_tensor(),
casted_kernel.get_colwise_tensor(),
(x_contracting_dims, k_contracting_dims),
)
use_bias = bias is not None
if use_bias:
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape)
ctx = (
casted_ln_out.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None,
casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None,
x.shape,
kernel.shape,
mu,
rsigma,
x,
gamma,
beta,
x_contracting_dims,
k_contracting_dims,
use_bias,
quantizer_set,
)
return output, ctx
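# For clarity, the NN GEMM in the forward rule above is equivalent (up to
# quantization) to a plain dot_general that contracts the last axis of x with the
# first axis of the kernel. Sketch only; the helper name is hypothetical.
def _reference_dense(x, kernel, bias=None):
    out = jax.lax.dot_general(x, kernel, (((x.ndim - 1,), (0,)), ((), ())))
    if bias is not None:
        out = out + jnp.reshape(bias, (1,) * (out.ndim - bias.ndim) + bias.shape)
    return out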
def _layernorm_dense_bwd_rule(
norm_type,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_input_axes, # pylint: disable=unused-argument
ctx,
grad,
):
"""Backward pass rule for layernorm_dense.
Implements the backward pass computation including:
1. Gradient computation for matrix multiplication
2. Gradient computation for layer normalization
3. Gradient computation for bias terms
4. Proper handling of quantization
Returns:
Tuple of gradients for all input parameters
"""
(
colwise_casted_ln_out,
rowwise_casted_kernel,
x_shape,
kernel_shape,
mu,
rsigma,
x,
gamma,
beta,
x_contracting_dims_in_fwd,
k_contracting_dims_in_fwd,
use_bias,
quantizer_set,
) = ctx
grad = with_sharding_constraint_by_logical_axes(grad, dot_input_axes)
casted_grad, dbias = tex.quantize_dbias(grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
g_constracting_dim = tuple(
range(grad.ndim - len(kernel_shape) + len(k_contracting_dims_in_fwd), grad.ndim)
)
# k_non_contracting_dims
k_constracting_dim = tuple(
dim for dim in range(len(kernel_shape)) if dim not in k_contracting_dims_in_fwd
)
# NT GEMM
dgrad = tex.gemm(
casted_grad.get_rowwise_tensor(),
rowwise_casted_kernel,
(g_constracting_dim, k_constracting_dim),
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
g_constracting_dim = x_constracting_dim = tuple(
range(0, len(x_shape) - len(x_contracting_dims_in_fwd))
)
# TN GEMM
wgrad = tex.gemm(
colwise_casted_ln_out,
casted_grad.get_colwise_tensor(),
(x_constracting_dim, g_constracting_dim),
)
dx, dgamma, dbeta = tex.normalization_bwd(
dgrad,
x,
mu,
rsigma,
gamma,
beta,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
norm_type=norm_type,
)
return dx, wgrad, dgamma, dbeta, dbias, quantizer_set
_layernorm_dense.defvjp(_layernorm_dense_fwd_rule, _layernorm_dense_bwd_rule)
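# Worked example of the contracting-dim bookkeeping in the backward rule above
# (sketch): with x.shape == (B, S, H_in), kernel.shape == (H_in, H_out), and
# grad.shape == (B, S, H_out):
#   g_constracting_dim = (2,)   # the H_out axis of grad
#   k_constracting_dim = (1,)   # the H_out axis of kernel  -> dgrad: (B, S, H_in)
#   x_constracting_dim = g_constracting_dim = (0, 1)  # batch axes -> wgrad: (H_in, H_out)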
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX MLP modules"""
"""Multi-layer perceptron (MLP) operations with layer normalization for Transformer Engine in JAX.
This module provides optimized implementations of MLP blocks commonly used in transformer
architectures. Each MLP block consists of:
1. Layer normalization
2. First dense layer transformation (GEMM1) with bias and activation
3. Second dense layer transformation (GEMM2) with bias
The implementation supports various normalization types, activation functions,
quantization, and distributed training through sharding constraints.
"""
from typing import List, Tuple, Sequence, Union, Callable
from functools import partial
@@ -11,92 +21,81 @@ import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name
from . import cpp_extensions as tex
from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize
from .layernorm import canonicalize_layernorm_type
from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import with_sharding_constraint_by_logical_axes
def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
"""
Activation Unit
"""
if len(activation_type) > 1:
assert x.shape[-2] == 2 # Linear + GeLU
output = _activation_lu(x, activation_type)
return output
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
_output, _ = _activation_lu_fwd_rule(x, activation_type)
return _output
def _activation_lu_fwd_rule(x, activation_type):
fwd_output = tex.act_lu(x, activation_type)
return fwd_output, (x,)
def _activation_lu_bwd_rule(activation_type, ctx, g):
(x,) = ctx
assert x.dtype == g.dtype
dx = tex.dact_lu(g, x, activation_type)
dx = jnp.reshape(dx, x.shape)
return (dx,)
from .layernorm import canonicalize_norm_type
from .quantize import with_sharding_constraint_by_logical_axes, QuantizerSet, noop_quantizer_set
_activation_lu.defvjp(_activation_lu_fwd_rule, _activation_lu_bwd_rule)
def fused_layernorm_fp8_mlp(
def layernorm_mlp(
x: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
kernels: List[jnp.ndarray],
biases: List[jnp.ndarray],
fp8_meta_pkgs: List[FP8MetaPackage],
layernorm_type: str,
norm_type: str,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
layernorm_input_axes: Tuple[str, ...] = None,
norm_input_axes: Tuple[str, ...] = None,
dot_1_input_axes: Tuple[str, ...] = None,
dot_2_input_axes: Tuple[str, ...] = None,
ffn1_ckpt_name: str = "ffn1",
ffn2_ckpt_name: str = "ffn2",
activation_type: Sequence[Union[str, Callable]] = ("gelu",),
use_bias: bool = True,
quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
) -> jnp.ndarray:
"""Apply layer normalization followed by MLP block.
This function implements the following sequence of operations:
1. Layer normalization: (x - mean) / sqrt(var + epsilon) * gamma + beta
2. First dense layer transformation: y1 = x * kernel1 + bias1
3. Activation function: y2 = activation(y1)
4. Second dense layer transformation: y3 = y2 * kernel2 + bias2
Args:
x: Input tensor with shape [batch..., hidden_in]
gamma: Scale parameter for normalization with shape [hidden_in]
beta: Bias parameter for normalization with shape [hidden_in]
kernels: List of two weight matrices:
- kernel1: [hidden_in, intermediate]
- kernel2: [intermediate, hidden_in]
biases: List of two bias terms:
- bias1: [intermediate]
- bias2: [hidden_in]
norm_type: Type of normalization ("layernorm" or "rmsnorm")
zero_centered_gamma: Whether to use zero-centered gamma for normalization
epsilon: Small constant for numerical stability in normalization
norm_input_axes: Logical axes for sharding the layernorm input
dot_1_input_axes: Logical axes for sharding the first matrix multiplication
dot_2_input_axes: Logical axes for sharding the second matrix multiplication
ffn1_ckpt_name: Name for checkpointing the first feed-forward network
ffn2_ckpt_name: Name for checkpointing the second feed-forward network
activation_type: Activation function(s) to apply after the first dense layer transformation
quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations
Returns:
Output tensor with shape [batch..., hidden_in]
Note:
- For RMSNorm (norm_type="rmsnorm"), beta must be None and zero_centered_gamma
must be False
- The function supports automatic differentiation through JAX's custom VJP
- Quantization is applied to both dense layer transformations
- Checkpointing is applied to both feed-forward networks for memory efficiency
"""
Layernorm + GEMM1 + bias + activation + GEMM2 + bias
"""
assert len(kernels) == 2
assert len(fp8_meta_pkgs) == len(kernels)
kernel_1 = kernels[0]
kernel_2 = kernels[1]
bias_1 = biases[0]
bias_2 = biases[1]
amax_list_1 = fp8_meta_pkgs[0].amax_list
amax_list_2 = fp8_meta_pkgs[1].amax_list
scale_list_1 = fp8_meta_pkgs[0].scale_list
scale_list_2 = fp8_meta_pkgs[1].scale_list
fwd_dtype = FP8Helper.FWD_DTYPE
bwd_dtype = FP8Helper.BWD_DTYPE
layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == "rmsnorm":
assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
norm_type = canonicalize_norm_type(norm_type)
if norm_type == "rmsnorm":
assert beta is None, "beta should be None if norm_type is 'rmsnorm'"
assert (
not zero_centered_gamma
), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"
), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"
output = _fused_layernorm_fp8_mlp(
output = _layernorm_mlp(
x,
gamma,
beta,
@@ -104,28 +103,22 @@ def fused_layernorm_fp8_mlp(
kernel_2,
bias_1,
bias_2,
amax_list_1,
amax_list_2,
scale_list_1,
scale_list_2,
fwd_dtype,
bwd_dtype,
layernorm_type,
norm_type,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
use_bias,
quantizer_sets,
)
return output
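# Illustrative usage sketch (not part of this commit). The helper name and shapes
# are made up; with the default noop quantizer sets this runs the un-quantized path.
# With activation_type=("gelu",), kernel_1 is (hidden_in, intermediate) and kernel_2
# is (intermediate, hidden_in), matching the asserts in the forward rule below.
def _example_layernorm_mlp_usage():
    hidden, intermediate = 512, 2048
    x = jnp.ones((8, 128, hidden), jnp.float32)
    gamma = jnp.ones((hidden,), jnp.float32)
    beta = jnp.zeros((hidden,), jnp.float32)
    kernel_1 = jnp.ones((hidden, intermediate), jnp.float32)
    kernel_2 = jnp.ones((intermediate, hidden), jnp.float32)
    bias_1 = jnp.zeros((intermediate,), jnp.float32)
    bias_2 = jnp.zeros((hidden,), jnp.float32)
    return layernorm_mlp(
        x,
        gamma,
        beta,
        [kernel_1, kernel_2],
        [bias_1, bias_2],
        norm_type="layernorm",
        activation_type=("gelu",),
    )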
@partial(jax.custom_vjp, nondiff_argnums=(11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22))
def _fused_layernorm_fp8_mlp(
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15))
def _layernorm_mlp(
x: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
@@ -133,24 +126,46 @@ def _fused_layernorm_fp8_mlp(
kernel_2: jnp.ndarray,
bias_1: jnp.ndarray,
bias_2: jnp.ndarray,
amax_list_1: List[jnp.ndarray],
amax_list_2: List[jnp.ndarray],
scale_list_1: List[jnp.ndarray],
scale_list_2: List[jnp.ndarray],
fwd_dtype: jnp.dtype,
bwd_dtype: jnp.dtype,
layernorm_type: str,
norm_type: str,
zero_centered_gamma: bool,
epsilon: float,
layernorm_input_axes: Tuple[str, ...],
norm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...],
dot_2_input_axes: Tuple[str, ...],
ffn1_ckpt_name: str,
ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]],
use_bias: bool,
quantizer_sets,
):
output, _ = _fused_layernorm_fp8_mlp_fwd_rule(
"""Internal implementation of layernorm_mlp with custom VJP.
This function implements the forward pass of layernorm_mlp with support for
automatic differentiation. It handles the normalization, dense layer transformations,
activation, and quantization operations.
Args:
x: Input tensor
gamma: Scale parameter for normalization
beta: Bias parameter for normalization
kernel_1: First weight matrix
kernel_2: Second weight matrix
bias_1: First bias term
bias_2: Second bias term
norm_type: Type of normalization
zero_centered_gamma: Whether to use zero-centered gamma
epsilon: Small constant for numerical stability
norm_input_axes: Logical axes for layernorm sharding
dot_1_input_axes: Logical axes for first matrix multiplication sharding
dot_2_input_axes: Logical axes for second matrix multiplication sharding
ffn1_ckpt_name: Name for first feed-forward network checkpointing
ffn2_ckpt_name: Name for second feed-forward network checkpointing
activation_type: Activation function(s)
quantizer_sets: Tuple of quantizer sets
Returns:
Output tensor from the combined operations
"""
output, _ = _layernorm_mlp_fwd_rule(
x,
gamma,
beta,
@@ -158,27 +173,21 @@ def _fused_layernorm_fp8_mlp(
kernel_2,
bias_1,
bias_2,
amax_list_1,
amax_list_2,
scale_list_1,
scale_list_2,
fwd_dtype,
bwd_dtype,
layernorm_type,
norm_type,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
use_bias,
quantizer_sets,
)
return output
def _fused_layernorm_fp8_mlp_fwd_rule(
def _layernorm_mlp_fwd_rule(
x,
gamma,
beta,
@@ -186,444 +195,257 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
kernel_2,
bias_1,
bias_2,
amax_list_1,
amax_list_2,
scale_list_1,
scale_list_2,
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
layernorm_type,
norm_type,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
use_bias,
quantizer_sets,
):
"""Forward pass rule for layernorm_mlp.
Implements the forward pass computation including:
1. Layer normalization with quantization
2. First matrix multiplication with quantized kernel
3. Activation function application
4. Second matrix multiplication with quantized kernel
5. Optional bias additions
6. Sharding constraints
7. Checkpointing for memory efficiency
Returns:
Tuple of (output, context) for automatic differentiation
"""
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
# x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (Hidden_in, 1, Hidden_out)
# Kernel_2 should be in shape of (Hidden_in, Hidden_out)
assert len(kernel_1.shape) == 3
assert kernel_1.shape[-2] == len(activation_type)
# Kernel_1 should be in shape of (hidden_in, activation_len * intermediate)
# Kernel_2 should be in shape of (intermediate, hidden_in)
assert len(kernel_1.shape) == 2
assert len(kernel_2.shape) == 2
assert kernel_1.shape[1] == kernel_2.shape[0] * len(activation_type)
x_contracting_dims = (len(x.shape) - 1,)
xt_batch_dims = tuple(range(1, x.ndim))
k_contracting_dims = (0,)
assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0]
assert kernel_1.shape[-1] == kernel_2.shape[0]
assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]
assert kernel_1.shape[1] == len(activation_type) * kernel_2.shape[0]
maybe_fm32_to_fp32, maybe_fp32_to_fm32 = FP8Helper.generate_fp8_meta_dtype_converter_pair(
*amax_list_1, *scale_list_1, *amax_list_2, *scale_list_2
)
amax_list_1 = maybe_fm32_to_fp32(*amax_list_1)
scale_list_1 = maybe_fm32_to_fp32(*scale_list_1)
amax_list_2 = maybe_fm32_to_fp32(*amax_list_2)
scale_list_2 = maybe_fm32_to_fp32(*scale_list_2)
fp8_dtype_list = [fwd_dtype, fwd_dtype, bwd_dtype]
scale_list_1, scale_inv_list_1 = FP8MetaPackage.update_fp8_scale(
amax_list_1, scale_list_1, fp8_dtype_list
)
amax_list_1 = FP8MetaPackage.update_amax_list(amax_list_1)
scale_list_2, scale_inv_list_2 = FP8MetaPackage.update_fp8_scale(
amax_list_2, scale_list_2, fp8_dtype_list
)
amax_list_2 = FP8MetaPackage.update_amax_list(amax_list_2)
x_amax = amax_list_1[FP8MetaPackage.INPUT_IDX][0:1]
x_scale = scale_list_1[FP8MetaPackage.INPUT_IDX]
x_scale_inv = scale_inv_list_1[FP8MetaPackage.INPUT_IDX]
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
if layernorm_type == "layernorm":
ln_out, mu, rsigma, updated_x_amax = tex.layernorm_fwd_fp8(
x,
gamma,
beta,
x_amax,
x_scale,
x_scale_inv,
out_dtype=fwd_dtype,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
)
else:
assert (
not zero_centered_gamma
), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"
ln_out, rsigma, updated_x_amax = tex.rmsnorm_fwd_fp8(
x, gamma, x_amax, x_scale, x_scale_inv, out_dtype=fwd_dtype, epsilon=epsilon
)
mu = None
assert x.shape == ln_out.shape
kernel_1_amax = amax_list_1[FP8MetaPackage.WEIGHT_IDX][0:1]
kernel_1_scale = scale_list_1[FP8MetaPackage.WEIGHT_IDX]
kernel_1_scale_inv = scale_inv_list_1[FP8MetaPackage.WEIGHT_IDX]
# Note (Ming Huang): Use cast only and let XLA handle the transpose, avoiding an
# unnecessary copy that would break FP8 GEMM pattern matching.
casted_kernel_1, updated_kernel_1_amax = tex.cast_fp8(
kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype
use_bias_1 = bias_1 is not None
use_bias_2 = bias_2 is not None
x = with_sharding_constraint_by_logical_axes(x, norm_input_axes)
casted_ln_out, mu, rsigma = tex.normalization_fwd(
x,
gamma,
beta,
zero_centered_gamma,
epsilon,
norm_type,
quantizer=ffn1_quantizer_set.x,
)
ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_1_input_axes)
casted_kernel_1 = tex.quantize(kernel_1, quantizer=ffn1_quantizer_set.kernel)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)
# NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out)
dot_1_output = fp8_dot_impl(
ln_out,
casted_kernel_1,
x_scale_inv,
kernel_1_scale_inv,
x.dtype,
(x_contracting_dims, (0,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP),
dot_1_output = tex.gemm(
casted_ln_out.get_rowwise_tensor(),
casted_kernel_1.get_colwise_tensor(),
(x_contracting_dims, k_contracting_dims),
)
if use_bias:
if use_bias_1:
bias_1_shape = bias_1.shape
bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)
else:
bias_1_shape = None
dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)
activation_lu_out_amax = amax_list_2[FP8MetaPackage.INPUT_IDX][0:1]
activation_lu_out_scale = scale_list_2[FP8MetaPackage.INPUT_IDX]
activation_lu_out_scale_inv = scale_inv_list_2[FP8MetaPackage.INPUT_IDX]
dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)
# (batch..., hidden_in) -> (batch..., hidden)
casted_activation_lu_out, updated_activation_lu_amax = tex.act_lu_fp8(
dot_1_output,
activation_lu_out_amax,
activation_lu_out_scale,
activation_lu_out_scale_inv,
fwd_dtype,
activation_type,
)
casted_act_out = tex.act_lu(dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x)
casted_activation_lu_out = with_sharding_constraint_by_logical_axes(
casted_activation_lu_out, dot_2_input_axes
)
casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)
kernel_2_scale = scale_list_2[FP8MetaPackage.WEIGHT_IDX]
kernel_2_scale_inv = scale_inv_list_2[FP8MetaPackage.WEIGHT_IDX]
# Note (Ming Huang): Use native cast and let XLA handle the transpose, avoiding an
# unnecessary copy that would break FP8 GEMM pattern matching.
casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale)
casted_kernel_2 = tex.quantize(kernel_2, quantizer=ffn2_quantizer_set.kernel)
# NN GEMM
# (batch..., hidden_in) x (hidden_out, hidden_in)
dot_2_output = fp8_dot_impl(
casted_activation_lu_out,
casted_kernel_2,
activation_lu_out_scale_inv,
kernel_2_scale_inv,
x.dtype,
(x_contracting_dims, (0,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP),
dot_2_output = tex.gemm(
casted_act_out.get_rowwise_tensor(),
casted_kernel_2.get_colwise_tensor(),
(x_contracting_dims, k_contracting_dims),
)
if use_bias:
if use_bias_2:
bias_2_shape = bias_2.shape
bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)
else:
bias_2_shape = None
dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)
ctx = (
x,
ln_out,
mu,
rsigma,
gamma,
beta,
casted_ln_out.get_colwise_tensor(),
casted_kernel_1.get_rowwise_tensor(),
dot_1_output,
casted_activation_lu_out,
casted_kernel_1,
casted_kernel_2,
amax_list_1,
amax_list_2,
scale_list_1,
scale_list_2,
scale_inv_list_1,
scale_inv_list_2,
updated_x_amax,
updated_activation_lu_amax,
updated_kernel_1_amax,
updated_kernel_2_amax,
casted_act_out.get_colwise_tensor(),
casted_kernel_2.get_rowwise_tensor(),
x_contracting_dims,
xt_batch_dims,
bias_1_shape,
bias_2_shape,
maybe_fp32_to_fm32,
k_contracting_dims,
kernel_1.shape,
kernel_2.shape,
use_bias_1,
use_bias_2,
quantizer_sets,
)
return dot_2_output, ctx
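# Un-quantized reference of the forward sequence above (sketch only; the helper name
# and the use of jax.nn.gelu are illustrative). Assumes a non-gated ("gelu",)
# activation so that kernel_1 is simply (hidden_in, intermediate).
def _reference_mlp_forward(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, eps=1e-6):
    mu = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    ln = (x - mu) * jax.lax.rsqrt(var + eps) * gamma + beta   # layernorm
    h = jax.nn.gelu(jnp.dot(ln, kernel_1) + bias_1)           # GEMM1 + bias + activation
    return jnp.dot(h, kernel_2) + bias_2                      # GEMM2 + bias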
def _fused_layernorm_fp8_mlp_bwd_rule(
fwd_dtype, # pylint: disable=unused-argument
bwd_dtype,
layernorm_type,
def _layernorm_mlp_bwd_rule(
norm_type,
zero_centered_gamma,
epsilon,
layernorm_input_axes,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
ffn1_ckpt_name, # pylint: disable=unused-argument
ffn2_ckpt_name, # pylint: disable=unused-argument
activation_type,
use_bias,
ctx,
grad,
):
"""Backward pass rule for layernorm_mlp.
Implements the backward pass computation including:
1. Gradient computation for second matrix multiplication
2. Gradient computation for activation function
3. Gradient computation for first matrix multiplication
4. Gradient computation for layer normalization
5. Gradient computation for bias terms
6. Proper handling of quantization
Returns:
Tuple of gradients for all input parameters
"""
(
x,
ln_out,
mu,
rsigma,
gamma,
beta,
colwise_casted_ln_out,
rowwise_casted_kernel_1,
dot_1_output,
casted_activation_lu_out,
casted_kernel_1,
casted_kernel_2,
amax_list_1,
amax_list_2,
scale_list_1,
scale_list_2,
scale_inv_list_1,
scale_inv_list_2,
updated_x_amax,
updated_activation_lu_amax,
updated_kernel_1_amax,
updated_kernel_2_amax,
x_contracting_dims,
xt_batch_dims,
bias_1_shape,
bias_2_shape,
maybe_fp32_to_fm32,
colwise_casted_act_out,
rowwise_casted_kernel_2,
x_contracting_dims_in_fwd,
k_contracting_dims_in_fwd,
kernel_1_shape,
kernel_2_shape,
use_bias_1,
use_bias_2,
quantizer_sets,
) = ctx
grad_amax = amax_list_2[FP8MetaPackage.GRAD_IDX][0:1]
grad_scale = scale_list_2[FP8MetaPackage.GRAD_IDX]
grad_scale_inv = scale_inv_list_2[FP8MetaPackage.GRAD_IDX]
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
# The output (and thus the incoming gradient) is sharded the same way as dot_1's input
grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
if use_bias:
casted_grad, casted_grad_t, dbias_2, updated_grad_amax = tex.dbias_cast_transpose(
grad,
grad_amax,
grad_scale,
grad_scale_inv,
bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=-1,
)
dbias_2 = jnp.reshape(dbias_2, bias_2_shape)
else:
casted_grad, casted_grad_t, updated_grad_amax = tex.cast_transpose(
grad,
grad_amax,
grad_scale,
grad_scale_inv,
bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=-1,
)
dbias_2 = None
casted_activation_lu_out_t = tex.transpose(
casted_activation_lu_out, static_axis_boundary=-1, transpose_axis_boundary=-1
casted_grad, dbias_2 = tex.quantize_dbias(
grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad
)
# (hidden, batch...,) x (hidden, batch...)
gemm2_x_scale_inv = scale_inv_list_2[FP8MetaPackage.INPUT_IDX]
wgrad_2 = fp8_dot_impl(
casted_activation_lu_out_t,
casted_grad_t,
gemm2_x_scale_inv,
grad_scale_inv,
grad.dtype,
(xt_batch_dims, xt_batch_dims),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD),
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_2.ndim
g_constracting_dim_2 = tuple(
range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim)
)
# k_non_contracting_dims
k_constracting_dim_2 = tuple(
dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd
)
# NT GEMM
# (batch..., hidden_out) x (hidden_in, hidden_out)
kernel_2_scale_inv = scale_inv_list_2[FP8MetaPackage.WEIGHT_IDX]
dgrad_2 = fp8_dot_impl(
casted_grad,
casted_kernel_2,
grad_scale_inv,
kernel_2_scale_inv,
grad.dtype,
(x_contracting_dims, (1,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD),
dgrad_2 = tex.gemm(
casted_grad.get_rowwise_tensor(),
rowwise_casted_kernel_2,
(g_constracting_dim_2, k_constracting_dim_2),
)
dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
dactivation_lu_amax = amax_list_1[FP8MetaPackage.GRAD_IDX][0:1]
dactivation_lu_scale = scale_list_1[FP8MetaPackage.GRAD_IDX]
dactivation_lu_scale_inv = scale_inv_list_1[FP8MetaPackage.GRAD_IDX]
if len(activation_type) > 1: # if gated
if use_bias:
dactivation_lu = tex.dact_lu(dgrad_2, dot_1_output, activation_type)
casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = (
tex.dbias_cast_transpose(
dactivation_lu,
dactivation_lu_amax,
dactivation_lu_scale,
dactivation_lu_scale_inv,
bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=-2,
)
)
dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
else:
casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = (
tex.dgated_act_lu_cast_transpose(
dgrad_2,
dot_1_output,
dactivation_lu_amax,
dactivation_lu_scale,
dactivation_lu_scale_inv,
bwd_dtype,
static_axis_boundary=-1,
activation_type=activation_type,
)
)
dbias_1 = None
else:
if use_bias:
casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = (
tex.dact_lu_dbias_cast_transpose(
dgrad_2,
dot_1_output,
dactivation_lu_amax,
dactivation_lu_scale,
dactivation_lu_scale_inv,
bwd_dtype,
static_axis_boundary=-1,
activation_type=activation_type,
)
)
dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
else:
dactivation_lu = tex.dact_lu(dgrad_2, dot_1_output, activation_type)
casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = (
tex.cast_transpose(
dactivation_lu,
dactivation_lu_amax,
dactivation_lu_scale,
dactivation_lu_scale_inv,
bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=-2,
)
)
dbias_1 = None
ln_out_t = tex.transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)
# (hidden, batch...) x (hidden, batch...)
gemm1_x_scale_inv = scale_inv_list_1[FP8MetaPackage.INPUT_IDX]
xt_batch_dims_2 = tuple(i + 1 for i in xt_batch_dims)
wgrad_1 = fp8_dot_impl(
ln_out_t,
casted_dactivation_lu_t,
gemm1_x_scale_inv,
dactivation_lu_scale_inv,
grad.dtype,
(xt_batch_dims, xt_batch_dims_2),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD),
x_constracting_dim = g_constracting_dim = tuple(
range(0, len(x.shape) - len(x_contracting_dims_in_fwd))
)
x_contracting_dims = (
(min(x_contracting_dims),) + tuple(i + 1 for i in x_contracting_dims),
(1, 2),
)
kernel_1_scale_inv = scale_inv_list_1[FP8MetaPackage.WEIGHT_IDX]
dgrad_1 = fp8_dot_impl(
casted_dactivation_lu,
casted_kernel_1,
dactivation_lu_scale_inv,
kernel_1_scale_inv,
grad.dtype,
x_contracting_dims,
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD),
# TN GEMM
# (hidden, batch...,) x (hidden, batch...)
wgrad_2 = tex.gemm(
colwise_casted_act_out,
casted_grad.get_colwise_tensor(),
(x_constracting_dim, g_constracting_dim),
)
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes)
if layernorm_type == "layernorm":
dx, dgamma, dbeta = tex.layernorm_bwd(
dgrad_1,
x,
mu,
rsigma,
gamma,
beta,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
)
else:
assert (
not zero_centered_gamma
), "zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"
dx, dgamma = tex.rmsnorm_bwd(dgrad_1, x, rsigma, gamma, epsilon=epsilon)
dbeta = None
amax_list_1[FP8MetaPackage.INPUT_IDX] = (
amax_list_1[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax[0])
)
amax_list_1[FP8MetaPackage.WEIGHT_IDX] = (
amax_list_1[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_1_amax[0])
casted_dact_out, dbias_1 = tex.quantize_dact_dbias(
dgrad_2,
dot_1_output,
activation_type=activation_type,
is_dbias=use_bias_1,
quantizer=ffn2_quantizer_set.dgrad,
)
amax_list_1[FP8MetaPackage.GRAD_IDX] = (
amax_list_1[FP8MetaPackage.GRAD_IDX].at[0].set(updated_dactivation_lu_amax[0])
# k_non_contracting_dims calibrated with the shape difference of dgrad_2.ndim vs kernel_1.ndim
g_constracting_dim_1 = tuple(
range(dgrad_2.ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dgrad_2.ndim)
)
amax_list_2[FP8MetaPackage.INPUT_IDX] = (
amax_list_2[FP8MetaPackage.INPUT_IDX].at[0].set(updated_activation_lu_amax[0])
# k_non_contracting_dims
k_constracting_dim_1 = tuple(
dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd
)
amax_list_2[FP8MetaPackage.WEIGHT_IDX] = (
amax_list_2[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_2_amax)
# NT GEMM
dgrad_1 = tex.gemm(
casted_dact_out.get_rowwise_tensor(),
rowwise_casted_kernel_1,
(g_constracting_dim_1, k_constracting_dim_1),
)
amax_list_2[FP8MetaPackage.GRAD_IDX] = (
amax_list_2[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0])
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, norm_input_axes)
# TN GEMM
# (hidden, batch...) x (hidden, batch...)
wgrad_1 = tex.gemm(
colwise_casted_ln_out,
casted_dact_out.get_colwise_tensor(),
(x_constracting_dim, g_constracting_dim),
)
amax_list_1 = maybe_fp32_to_fm32(*amax_list_1)
scale_list_1 = maybe_fp32_to_fm32(*scale_list_1)
amax_list_2 = maybe_fp32_to_fm32(*amax_list_2)
scale_list_2 = maybe_fp32_to_fm32(*scale_list_2)
return (
dx,
dgamma,
dbeta,
wgrad_1,
wgrad_2,
dbias_1,
dbias_2,
amax_list_1,
amax_list_2,
scale_list_1,
scale_list_2,
dx, dgamma, dbeta = tex.normalization_bwd(
dgrad_1,
x,
mu,
rsigma,
gamma,
beta,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
norm_type=norm_type,
)
return (dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, quantizer_sets)
_fused_layernorm_fp8_mlp.defvjp(
_fused_layernorm_fp8_mlp_fwd_rule, _fused_layernorm_fp8_mlp_bwd_rule
)
_layernorm_mlp.defvjp(_layernorm_mlp_fwd_rule, _layernorm_mlp_bwd_rule)
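# Note (sketch of intent, not stated in this commit): quantizer_sets is a
# differentiable argument of _layernorm_mlp, so the backward rule returns it as the
# final "gradient". This lets any state the quantizers carry as pytree leaves
# (e.g. delayed-scaling scale / amax_history) flow back to the caller after each step.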
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Python interface for quantization helpers.
This module provides a high-level interface for tensor quantization in JAX,
including support for various scaling modes and quantization strategies.
It exports all the necessary classes and functions from the underlying
implementation modules.
"""
from .tensor import *
from .quantizer import *
from .dequantizer import *
from .scaling_modes import *
from .metadata import *
from .helper import *
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Dequantization utilities for TE/JAX.
This module provides utilities for dequantizing tensors that have been quantized
using various scaling modes, including delayed scaling and block scaling.
"""
import jax
import jax.numpy as jnp
from .scaling_modes import ScalingMode
__all__ = ["Dequantizer"]
class Dequantizer:
"""Encapsulation class for dequantization helpers.
This class provides static methods for dequantizing tensors that have been
quantized using different scaling modes. It supports both delayed scaling
and block scaling modes.
"""
@staticmethod
def _dq_func_tensor_scaling(scaled_tensor):
"""Dequantize a tensor using delayed scaling.
This function dequantizes a tensor that was quantized using delayed scaling
by multiplying the quantized data with the inverse scaling factor.
Args:
scaled_tensor: The quantized tensor to dequantize
Returns:
The dequantized tensor in the specified data type
"""
return jnp.asarray(
scaled_tensor.data.astype(jnp.float32) * scaled_tensor.scale_inv.astype(jnp.float32),
scaled_tensor.dq_dtype,
)
@staticmethod
def _dq_func_block_scaling(scaled_tensor):
"""Dequantize a tensor using block scaling.
This function dequantizes a tensor that was quantized using block scaling
by applying the inverse scaling factor to each block of data.
Args:
scaled_tensor: The quantized tensor to dequantize
Returns:
The dequantized tensor in the specified data type
"""
data = scaled_tensor.data.astype(jnp.float32)
data_shape = data.shape
scale = scaled_tensor.scale_inv.view(jnp.uint8).astype(jnp.float32)
scale_shape = scaled_tensor.scaling_mode.get_scale_shape(
scaled_tensor.data.shape, scaled_tensor.is_colwise, is_padded=False
)
scale = jax.lax.slice(scale, [0] * len(scale_shape), scale_shape) # slice out the padding
data = data.reshape(
*data_shape[:-2],
scale_shape[-2],
int(data_shape[-2] / scale_shape[-2]),
scale_shape[-1],
int(data_shape[-1] / scale_shape[-1]),
)
scale = jnp.expand_dims(scale, axis=(-1, -3))
# E8M0 has no sign bit: the stored byte is a biased exponent, so values below 127 encode negative powers of two.
return jnp.asarray(data * jnp.power(2, scale - 127), scaled_tensor.dq_dtype).reshape(
data_shape
)
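# Worked example of the E8M0 decode above: the stored scale byte is a biased
# exponent, so 127 -> 2**0 = 1.0, 130 -> 2**3 = 8.0, and 124 -> 2**-3 = 0.125.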
funcs = {
ScalingMode.NVTE_DELAYED_TENSOR_SCALING: _dq_func_tensor_scaling,
ScalingMode.NVTE_MXFP8_1D_SCALING: _dq_func_block_scaling,
}
@staticmethod
def dequantize(scaled_tensor):
"""Dequantize a scaled tensor using the appropriate scaling mode.
This method selects the appropriate dequantization function based on the
scaling mode used for quantization and applies it to the tensor.
Args:
scaled_tensor: The quantized tensor to dequantize
Returns:
The dequantized tensor in the specified data type
"""
dq_func = Dequantizer.funcs[scaled_tensor.scaling_mode]
return dq_func(scaled_tensor)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Config module for quantization metadata management
This module provides configuration and helper functions for managing quantization metadata
in JAX, including support for different scaling modes and datatypes.
"""
from contextlib import contextmanager
from enum import Enum
from typing import Optional, Tuple, Dict, Union
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from transformer_engine_jax import DType
from transformer_engine_jax import get_cublasLt_version
from transformer_engine_jax import (
get_cuda_version,
get_device_compute_capability,
)
from transformer_engine.common import recipe
from transformer_engine.jax.sharding import global_shard_guard, MeshResource
from .scaling_modes import ScalingMode
from .. import cpp_extensions as tex
__all__ = ["QuantizeConfig", "fp8_autocast", "is_fp8_available", "update_collections"]
_is_fp8_available = None
_reason_for_no_fp8 = ""
Collection = Union[Dict, FrozenDict]
def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]:
"""Check if delayed scaling FP8 is supported on the given GPU architecture.
Args:
gpu_arch: The GPU architecture version
Returns:
A tuple of (bool, str) indicating support and any error message
"""
if gpu_arch >= 90: # hopper and above
return True, ""
if gpu_arch < 89: # pre-ada
return False, "Device compute capability 8.9 or higher required for FP8 execution."
if get_cublasLt_version() < 120103:
return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
if get_cuda_version() < 12010:
return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
return True, ""
def _check_block_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]:
"""Check if block scaling FP8 is supported on the given GPU architecture.
Args:
gpu_arch: The GPU architecture version
Returns:
A tuple of (bool, str) indicating support and any error message
"""
if gpu_arch >= 100: # blackwell and above
return True, ""
if gpu_arch < 99: # pre-blackwell
return False, "Device compute capability 9.9 or higher required for MXFP8 execution."
if get_cublasLt_version() < 120800:
return False, "CublasLt version 12.8.0 or higher required for MXFP8 execution."
if get_cuda_version() < 12080:
return False, "Cuda version 12.8 or higher required for MXFP8 execution."
if not tex.jax_version_meet_requirement("0.5.3"):
return False, "Jax version 0.5.3 or higher required for MXFP8 execution."
return True, ""
def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]:
"""Check if FP8 is supported for the given scaling mode and GPU.
Args:
scaling_mode: The scaling mode to check support for
gpu_id: The ID of the GPU to check
Returns:
A tuple of (bool, str) indicating support and any error message
"""
gpu_arch = get_device_compute_capability(gpu_id)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
return _check_delayed_scaling_fp8_support(gpu_arch)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
return _check_block_scaling_fp8_support(gpu_arch)
return (False, "Unsupported scaling_mode!")
def is_fp8_available(
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
gpu_id=None,
) -> Tuple[bool, str]:
"""Check if FP8 is available for the given scaling mode and GPU.
Args:
scaling_mode: The scaling mode to check availability for (default: DELAYED_TENSOR_SCALING)
gpu_id: Optional GPU ID to check specific device (default: None)
Returns:
A tuple of (bool, str) indicating availability and any error message
"""
if gpu_id is not None:
return _check_fp8_support(scaling_mode, gpu_id)
global _is_fp8_available, _reason_for_no_fp8
if _is_fp8_available is None:
_is_fp8_available = {}
_reason_for_no_fp8 = {}
if scaling_mode not in _is_fp8_available:
_is_fp8_available[scaling_mode] = True
_reason_for_no_fp8[scaling_mode] = ""
# JAX doesn't provide the local GPU id.
for local_gpu_id in range(len(jax.local_devices())):
ret, msg = _check_fp8_support(scaling_mode, local_gpu_id)
if ret is False:
_is_fp8_available[scaling_mode] = ret
_reason_for_no_fp8[scaling_mode] = msg
return ret, msg
return _is_fp8_available[scaling_mode], _reason_for_no_fp8[scaling_mode]
def _format2dtypes(format_: recipe.Format):
"""Convert recipe.Format.dtype to corresponding JAX dtypes.
Args:
format_: The FP8 format to convert
Returns:
A tuple of (forward_dtype, backward_dtype) for the given format
"""
if format_ == recipe.Format.E4M3:
return jnp.float8_e4m3fn, jnp.float8_e4m3fn
if format_ == recipe.Format.E5M2:
return jnp.float8_e5m2, jnp.float8_e5m2
if format_ == recipe.Format.HYBRID:
return jnp.float8_e4m3fn, jnp.float8_e5m2
return jnp.bfloat16, jnp.bfloat16
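# For example, recipe.Format.HYBRID (the default used below) maps to
# (jnp.float8_e4m3fn for the forward pass, jnp.float8_e5m2 for the backward pass),
# while recipe.Format.E4M3 uses float8_e4m3fn in both directions.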
class AmaxComputeAlgo(Enum):
"""Enumeration for AMAX computation algorithms.
Attributes:
MAX: Use maximum value for AMAX computation
MOST_RECENT: Use most recent value for AMAX computation
"""
MAX = "max"
MOST_RECENT = "most_recent"
def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode:
"""Convert recipe.Recipe to ScalingMode.
Args:
fp8_recipe: The FP8 recipe to convert
Returns:
The corresponding ScalingMode
Raises:
ValueError: If the recipe type is not supported
"""
if isinstance(fp8_recipe, recipe.DelayedScaling):
return ScalingMode.NVTE_DELAYED_TENSOR_SCALING
if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
return ScalingMode.NVTE_MXFP8_1D_SCALING
raise ValueError("Invalid fp8_recipe!")
def update_collections(new: Collection, original: Collection) -> Collection:
"""Update collections with new values while preserving original structure.
Args:
new: New collection of values to add/update
original: Original collection to update
Returns:
Updated collection with new values merged with original
Raises:
AssertionError: If either collection is not a dict or FrozenDict
"""
assert isinstance(original, (dict, FrozenDict))
assert isinstance(new, (dict, FrozenDict))
frozen_original = FrozenDict(original) if not isinstance(original, FrozenDict) else original
for key in new:
if key in frozen_original:
frozen_original, _ = frozen_original.pop(key)
new_coll = FrozenDict({**new, **frozen_original})
if not isinstance(original, FrozenDict):
new_coll = new_coll.unfreeze()
return new_coll
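# Worked example: update_collections({"a": 1}, {"a": 0, "b": 2}) == {"a": 1, "b": 2};
# keys present in `new` replace the originals, every other entry is preserved, and
# the result is a FrozenDict only if `original` was one.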
class QuantizeConfig:
"""Configuration class for quantization settings.
This class manages global quantization settings including FP8 formats,
scaling modes, and accumulation settings.
Attributes:
INITIALIZED: Whether the config has been initialized
MARGIN: Margin value for quantization
COLLECTION_NAME: Name of the collection for quantization metadata
FP8_FORMAT: FP8 format to use
FWD_DTYPE: Forward pass data type
BWD_DTYPE: Backward pass data type
FP8_2X_ACC_FPROP: Whether to use 2x accumulation for forward pass
FP8_2X_ACC_DGRAD: Whether to use 2x accumulation for data gradients
FP8_2X_ACC_WGRAD: Whether to use 2x accumulation for weight gradients
IF_QUANTIZE_2X: Whether 2x quantization is enabled
SCALING_MODE: Scaling mode
AMAX_HISTORY_LEN: Length of AMAX history for delayed scaling
AMAX_COMPUTE_ALGO: Algorithm for AMAX computation
"""
INITIALIZED = False
MARGIN: float = 0.0
COLLECTION_NAME: str = "quantize_meta"
FP8_FORMAT: recipe.Format = recipe.Format.HYBRID
FWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[0]
BWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[1]
FP8_2X_ACC_FPROP: bool = False
FP8_2X_ACC_DGRAD: bool = False
FP8_2X_ACC_WGRAD: bool = False
IF_QUANTIZE_2X: bool = False
SCALING_MODE: ScalingMode = ScalingMode.NVTE_NO_SCALING
# DelayedScaling
AMAX_HISTORY_LEN: int = 1024
AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX
@staticmethod
def is_fp8_enabled():
"""Check if FP8 quantization is enabled.
Returns:
bool: True if quantization is enabled, False otherwise
"""
return QuantizeConfig.INITIALIZED
@classmethod
def initialize(cls, fp8_recipe: recipe.Recipe) -> None:
"""Initialize the quantization configuration.
Args:
fp8_recipe: The FP8 recipe to use for initialization
"""
cls.INITIALIZED = True
cls.MARGIN = fp8_recipe.margin
cls.FP8_FORMAT = fp8_recipe.fp8_format
cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT)
cls.SCALING_MODE = _get_scaling_mode(fp8_recipe)
cls.IF_QUANTIZE_2X = True
@classmethod
def finalize(cls) -> None:
"""Reset the quantization configuration to default values."""
cls.INITIALIZED = False
cls.MARGIN = 0.0
cls.FP8_FORMAT = recipe.Format.HYBRID
cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT)
cls.SCALING_MODE = ScalingMode.NVTE_NO_SCALING
cls.FP8_2X_ACC_FPROP = False
cls.FP8_2X_ACC_DGRAD = False
cls.FP8_2X_ACC_WGRAD = False
cls.SCALING_MODE = ScalingMode.NVTE_NO_SCALING
cls.IF_QUANTIZE_2X = False
# DelayedScaling
cls.AMAX_HISTORY_LEN = 1024
cls.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX
class DelayedScalingQuantizeConfig:
"""Configuration class for delayed scaling FP8 recipe.
This class provides specific initialization and finalization for delayed scaling
FP8 quantization mode.
"""
@staticmethod
def initialize(fp8_recipe: recipe.Recipe) -> None:
"""Initialize delayed scaling FP8 configuration.
Args:
fp8_recipe: The FP8 recipe to use for initialization
Raises:
AssertionError: If recipe parameters are not supported
"""
assert fp8_recipe.amax_compute_algo in [
"max",
"most_recent",
], "DelayedScaling amax_compute_algo only supports max and most_recent with TE/JAX."
assert (
fp8_recipe.scaling_factor_compute_algo is None
), "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX."
assert fp8_recipe.reduce_amax, "DelayedScaling reduce_amax should be enabled for TE/JAX."
cls = QuantizeConfig
cls.initialize(fp8_recipe)
cls.AMAX_HISTORY_LEN = fp8_recipe.amax_history_len
string_to_amax_compute_algo = {
"max": AmaxComputeAlgo.MAX,
"most_recent": AmaxComputeAlgo.MOST_RECENT,
}
cls.AMAX_COMPUTE_ALGO = string_to_amax_compute_algo[fp8_recipe.amax_compute_algo]
cls.FP8_2X_ACC_DGRAD = True
cls.FP8_2X_ACC_WGRAD = True
@staticmethod
def finalize() -> None:
"""Reset the delayed scaling configuration."""
QuantizeConfig.finalize()
class BlockScalingQuantizeConfig:
"""Configuration class for block scaling FP8 recipe.
This class provides specific initialization and finalization for block scaling
FP8 quantization mode.
"""
@staticmethod
def initialize(fp8_recipe: recipe.Recipe) -> None:
"""Initialize block scaling FP8 configuration.
Args:
fp8_recipe: The FP8 recipe to use for initialization
"""
cls = QuantizeConfig
cls.initialize(fp8_recipe)
cls.AMAX_HISTORY_LEN = 0
@staticmethod
def finalize() -> None:
"""Reset the block scaling configuration."""
QuantizeConfig.finalize()
@contextmanager
def fp8_autocast(
enabled: bool = False,
fp8_recipe: Optional[recipe.Recipe] = None,
mesh_resource: Optional[MeshResource] = None,
) -> None:
r"""Context manager for FP8 automatic mixed precision.
This context manager enables FP8 quantization for the duration of its context.
.. code-block:: python
mesh_shape = (4, 2)
dp_mesh_axis_name = 'data_parallel'
tp_mesh_axis_name = 'tensor_parallel'
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)):
mesh_resource=MeshResource(dp_mesh_axis_name, tp_mesh_axis_name)
with fp8_autocast(enabled=True, mesh_resource=mesh_resource):
rules = extend_logical_axis_rules(tuple())
transformer = TransformerLayer()
with partitioning.axis_rules(rules):
pjit(transformer.init, ...)(...)
.. note::
We only support :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len`,
and :attr:`amax_compute_algo` (with value 'max' and 'most_recent') in
recipe.DelayedScaling currently. Other parameters in recipe.DelayedScaling
will trigger an assertion.
Parameters
----------
enabled: bool, default = False
Whether or not to enable fp8
fp8_recipe: recipe.DelayedScaling, default = None
Recipe used for FP8 training.
mesh_resource: MeshResource, default = None
Specify the mesh axes for data and tensor parallelism to shard along.
If set to None, then no data or tensor parallelism will be used.
"""
if fp8_recipe is None:
fp8_recipe = recipe.DelayedScaling()
if mesh_resource is None:
mesh_resource = MeshResource()
Config = DelayedScalingQuantizeConfig
if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
Config = BlockScalingQuantizeConfig
try:
with global_shard_guard(mesh_resource):
if enabled:
fp8_available, reason_for_no_fp8 = is_fp8_available(_get_scaling_mode(fp8_recipe))
assert fp8_available, reason_for_no_fp8
Config.initialize(fp8_recipe)
yield
finally:
Config.finalize()
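# Illustrative sketch (not part of this commit): enabling the MXFP8 block-scaling
# path added by this PR. The helper name is made up; availability is checked first
# because MXFP8 needs Blackwell-class hardware, cuBLAS >= 12.8, and a recent JAX.
def _example_mxfp8_autocast():
    available, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
    if not available:
        print(f"MXFP8 unavailable: {reason}")
        return
    with fp8_autocast(enabled=True, fp8_recipe=recipe.MXFP8BlockScaling()):
        # Inside the context, QuantizeConfig is initialized for MXFP8 1D scaling.
        assert QuantizeConfig.SCALING_MODE == ScalingMode.NVTE_MXFP8_1D_SCALING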
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Metadata classes for quantization in JAX.
This module provides classes for managing quantization metadata, including
scale factors and amax history for different tensor types.
"""
from dataclasses import dataclass
import jax.numpy as jnp
__all__ = ["QuantizeMeta", "QuantizeMetaSet"]
@dataclass
class QuantizeMeta:
"""Metadata for quantization parameters.
Attributes:
scale: The scaling factor for quantization
amax_history: History of maximum absolute values
"""
scale: jnp.ndarray
amax_history: jnp.ndarray
@dataclass
class QuantizeMetaSet:
"""Set of quantization metadata for different tensor types.
Attributes:
x: Quantization metadata for input tensors
kernel: Quantization metadata for kernel tensors
grad: Quantization metadata for gradient tensors
"""
x: QuantizeMeta
kernel: QuantizeMeta
grad: QuantizeMeta
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Tensor quantization classes for TE/JAX.
This module provides classes and utilities for quantizing tensors in JAX.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import partial
from typing import Union, Optional
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeAxis
from .scaling_modes import ScalingMode
from .tensor import ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory
from .helper import (
QuantizeConfig,
AmaxComputeAlgo,
)
__all__ = [
"QuantizeAxis",
"Quantizer",
"QuantizerSet",
"DelayedScaleQuantizer",
"BlockScaleQuantizer",
"QuantizerFactory",
"noop_quantizer_set",
]
@register_pytree_node_class
@dataclass
class Quantizer(ABC):
"""Base class for quantizers.
This abstract class defines the interface for tensor quantization, providing
methods for quantization and scale management.
Attributes:
q_dtype: The data type for quantized values
scaling_mode: The scaling mode to use for quantization
q_axis: The quantization axis (row-wise, column-wise, or both)
"""
q_dtype: jnp.dtype
scaling_mode: ScalingMode
q_axis: QuantizeAxis
def tree_flatten(self):
"""Flatten the quantizer for JAX tree operations.
Returns:
Tuple of (children, aux_data) for tree operations
"""
children = ()
aux_data = (self.q_dtype, self.scaling_mode, self.q_axis)
return (children, aux_data)
@classmethod
def tree_unflatten(cls, aux_data, children):
"""Reconstruct a quantizer from its flattened representation.
Args:
aux_data: Auxiliary data containing quantizer parameters
children: Unused children data
Returns:
A reconstructed Quantizer instance
"""
return cls(*aux_data, *children)
def update(self, *args, **kwargs):
"""Update quantizer state (no-op in base class)."""
del args, kwargs
def is_2x2x(self) -> bool:
"""Check if quantizer uses both row-wise and column-wise quantization.
Returns:
True if using both row-wise and column-wise quantization
"""
return self.q_axis == QuantizeAxis.ROWWISE_COLWISE
@abstractmethod
def get_layout(self) -> str:
"""Get the data layout.
Returns:
Data layout in string format
"""
@abstractmethod
def _quantize_func(self, x, is_colwise=False, dq_dtype=None) -> ScaledTensor1x:
"""Core quantization function to be implemented by subclasses.
Args:
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values, default is x.dtype
Returns:
A ScaledTensor1x containing the quantized data
"""
def quantize(self, x, is_rowwise=False, is_colwise=False, dq_dtype=None):
"""Quantize a tensor using the internal _quantize_func().
Args:
x: Input tensor to quantize
is_rowwise: Whether to use row-wise quantization
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
Returns:
A ScaledTensor1x or ScaledTensor2x containing the quantized data
"""
if (is_rowwise and is_colwise) or self.is_2x2x():
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype)
colwise_tensor = self._quantize_func(x, is_colwise=True, dq_dtype=dq_dtype)
return ScaledTensor2x(rowwise_tensor, colwise_tensor)
if is_colwise:
return self._quantize_func(x, is_colwise=True, dq_dtype=dq_dtype)
return self._quantize_func(x, dq_dtype=dq_dtype)
def get_scale_shapes(self, data_shape, is_padded=True):
"""Get shapes for scale tensors.
Args:
data_shape: Shape of the input tensor
is_padded: Whether to use padded shapes
Returns:
Tuple of (rowwise_scale_shape, colwise_scale_shape)
"""
return self.scaling_mode.get_scale_shape_2x(data_shape, is_padded)
def get_scale_dtype(self):
"""Get the data type for scale tensors.
Returns:
The data type for scale tensors
"""
return self.scaling_mode.get_scale_dtype()
@register_pytree_node_class
@dataclass
class DelayedScaleQuantizer(Quantizer):
"""Quantizer implementation using delayed scaling.
This quantizer uses delayed scaling mode with float32 scales and maintains
a history of maximum absolute values for dynamic scaling.
Attributes:
scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
q_axis: Quantization axis (default: ROWWISE_COLWISE)
scale: Current scaling factor
amax_history: History of maximum absolute values
"""
scaling_mode: ScalingMode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING
q_axis: QuantizeAxis = QuantizeAxis.ROWWISE_COLWISE
scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
amax_history: jnp.ndarray = field(
default_factory=lambda: jnp.zeros((QuantizeConfig.AMAX_HISTORY_LEN,), jnp.float32)
)
def tree_flatten(self):
"""Flatten the quantizer for JAX tree operations.
Returns:
Tuple of (children, aux_data) for tree operations
"""
children = (self.scale, self.amax_history)
aux_data = (self.q_dtype, self.scaling_mode, self.q_axis)
return (children, aux_data)
def get_layout(self) -> str:
"""Get the data layout string.
Returns:
Data layout in string format
Raises:
ValueError: If quantization axis is invalid
"""
layout = "NT"
if self.q_axis == QuantizeAxis.ROWWISE_COLWISE:
return layout
if self.q_axis == QuantizeAxis.ROWWISE:
return layout[0]
if self.q_axis == QuantizeAxis.COLWISE:
return layout[1]
raise ValueError(f"Invalid q_axis: {self.q_axis}")
def _quantize_func(self, x: jnp.ndarray, is_colwise=False, dq_dtype=None) -> ScaledTensor1x:
"""Quantize function helper for delayed scaling FP8.
Args:
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
Returns:
A ScaledTensor1x containing the quantized data
"""
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
compute_dtype = self.scale.dtype
dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
scaled_x = x.astype(compute_dtype) * self.scale
# quantize() in the old dot.py did it this way; leave this code block here for future debugging
# compute_dtype = x.dtype
# dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
# scaled_x = x * self.scale.astype(compute_dtype)
clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
scale_inv = 1.0 / self.scale
self.update(jnp.max(jnp.abs(x)).reshape((1,)))
return ScaledTensorFactory.create_1x(
data=clipped_scaled_x,
scale_inv=scale_inv,
scaling_mode=self.scaling_mode,
dq_dtype=dq_dtype,
)
def quantize(self, x, is_rowwise: bool = None, is_colwise: bool = None, dq_dtype=None):
"""Quantize a tensor using the internal _quantize_func().
Args:
x: Input tensor to quantize
is_rowwise: Whether to use row-wise quantization
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
Returns:
A ScaledTensor1x or ScaledTensor2x containing the quantized data
"""
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
is_rowwise = (
is_rowwise
if is_rowwise is not None
else (self.q_axis == QuantizeAxis.ROWWISE or self.is_2x2x())
)
is_colwise = (
is_colwise
if is_colwise is not None
else (self.q_axis == QuantizeAxis.COLWISE or self.is_2x2x())
)
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype)
colwise_tensor = None
if is_colwise:
colwise_tensor = ScaledTensorFactory.create_1x(
data=jnp.transpose(rowwise_tensor.data, (-1, *range(rowwise_tensor.data.ndim - 1))),
scale_inv=rowwise_tensor.scale_inv,
scaling_mode=self.scaling_mode,
dq_dtype=dq_dtype,
is_colwise=True,
layout="T",
)
if is_colwise and is_rowwise:
return ScaledTensor2x(rowwise_tensor, colwise_tensor)
if is_colwise:
return colwise_tensor
return rowwise_tensor
@staticmethod
@jax.jit
def _update_amax_history(amax_history, new_amax):
"""Update AMAX history with new maximum value.
Args:
amax_history: Current AMAX history
new_amax: New maximum value to add
Returns:
Updated AMAX history
"""
amax_history = amax_history.at[0].set(new_amax[0])
return amax_history
@staticmethod
@partial(jax.jit, static_argnums=(2,))
def _compute_scale(amax_history, scale, q_dtype):
"""Compute new scale based on AMAX history.
Args:
amax_history: History of maximum absolute values
scale: Current scale
q_dtype: Quantization data type
Returns:
Updated scale value
"""
        # Calculate the current scale from the amax history
fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32)
if QuantizeConfig.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
amax = jnp.max(amax_history, axis=-1, keepdims=True)
else:
amax = amax_history[0:1]
sf = (fp8_max / amax) / (2**QuantizeConfig.MARGIN)
sf = jnp.where(amax > 0.0, sf, scale)
sf = jnp.where(jnp.isfinite(amax), sf, scale)
scale = scale.at[0].set(sf[0])
return scale
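    # Worked example of the formula above (illustrative; assumes
    # QuantizeConfig.MARGIN == 0 and q_dtype == float8_e4m3fn, whose finfo max
    # is 448): with amax == 2.0, sf = (448 / 2.0) / 2**0 = 224.0, so scale[0]
    # becomes 224.0. When amax is zero or non-finite, the jnp.where guards keep
    # the previous scale unchanged.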
@staticmethod
@jax.jit
def _roll_and_reset_amax_history(amax_history):
"""Roll AMAX history and reset first element.
Args:
amax_history: Current AMAX history
Returns:
Updated AMAX history
"""
updated_amax_history = jnp.roll(amax_history, -1, -1)
amax_history = updated_amax_history.at[0].set(0.0)
return amax_history
def update(self, new_amax: jnp.ndarray):
"""Update AMAX history and compute new scale.
Args:
new_amax: New maximum absolute value to add to history
"""
amax_history = self._update_amax_history(self.amax_history, new_amax)
self.scale = self._compute_scale(amax_history, self.scale, self.q_dtype)
self.amax_history = self._roll_and_reset_amax_history(amax_history)
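# Illustrative usage sketch (not part of the original module): quantize a
# hypothetical bfloat16 activation with delayed scaling and dequantize the
# row-wise component. The shape and the e4m3 dtype below are arbitrary choices.
def _example_delayed_scaling_round_trip():
    x = jnp.ones((16, 32), jnp.bfloat16)
    quantizer = DelayedScaleQuantizer(q_dtype=jnp.float8_e4m3fn)
    scaled = quantizer.quantize(x)  # ScaledTensor2x: rowwise + colwise copies
    return scaled.get_rowwise_tensor().dequantize()  # approximately x, in bfloat16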
@register_pytree_node_class
@dataclass
class BlockScaleQuantizer(Quantizer):
"""Quantizer implementation using block-based scaling.
This quantizer uses block scaling mode with FP8 scales and block-based
quantization for improved efficiency.
Attributes:
scaling_mode: Set to NVTE_MXFP8_1D_SCALING
q_axis: Quantization axis (default: ROWWISE_COLWISE)
"""
scaling_mode: ScalingMode = ScalingMode.NVTE_MXFP8_1D_SCALING
q_axis: QuantizeAxis = QuantizeAxis.ROWWISE_COLWISE
def get_layout(self) -> str:
"""Get the data layout string.
Returns:
Data layout in string format
"""
if self.is_2x2x():
return "NN"
return "N"
def _quantize_func(self, x, is_colwise=False, dq_dtype=None) -> ScaledTensor1x:
"""Quantize function helper for block scaling FP8.
Args:
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
Returns:
A ScaledTensor1x containing the quantized data
"""
# TODO(Phuong): use quantize_func from JAX
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
x_shape = x.shape
scale_shape = self.scaling_mode.get_scale_shape(x_shape, is_colwise, is_padded=False)
scale_dtype = self.scaling_mode.get_scale_dtype()
x = x.reshape(
*x_shape[:-2],
scale_shape[-2],
int(x_shape[-2] / scale_shape[-2]),
scale_shape[-1],
int(x_shape[-1] / scale_shape[-1]),
)
amax = jnp.max(jnp.abs(x), axis=(-3, -1), keepdims=True)
MAX = jnp.finfo(self.q_dtype).max.astype(jnp.float32)
scales = amax.astype(jnp.float32) / MAX
scales_q = self._cast_to_e8m0_with_rounding_up(scales)
scaled_x = x / self._e8m0_to_dtype(scales_q, jnp.float32)
clipped_x = jnp.clip(scaled_x, -MAX, MAX)
x_q = clipped_x.astype(self.q_dtype).reshape(x_shape)
scales_q = scales_q.reshape(scale_shape).view(scale_dtype)
return ScaledTensorFactory.create_1x(
x_q,
scales_q,
self.scaling_mode,
is_colwise=is_colwise,
dq_dtype=dq_dtype,
)
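    # Shape walk-through of the reshape above (illustrative): for a hypothetical
    # rowwise x of shape (128, 64) with 1x32 blocks, the unpadded scale shape is
    # (128, 2), so x is viewed as (128, 1, 2, 32) and the amax reduction over
    # axes (-3, -1) yields one scale per 1x32 block, i.e. shape (128, 1, 2, 1).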
def _cast_to_e8m0_with_rounding_up(self, scales):
"""Cast scales to E8M0 format with rounding up.
Args:
scales: Input scales to convert
Returns:
Scales in E8M0 format
"""
temp = scales.astype(jnp.float32).view(jnp.uint32)
exp = temp >> 23
mant = temp & 0x7FFFFF
is_ru = jnp.logical_and(
jnp.logical_and((mant > 0), (exp != 0xFE)),
~jnp.logical_and((exp == 0), (mant <= 0x400000)),
)
exp = jnp.where(is_ru, exp + 1, exp)
new_scales = exp.astype(jnp.uint8)
return new_scales
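    # Worked example of the round-up cast (illustrative): 0.75 has float32
    # exponent bits 126 and a non-zero mantissa, so it rounds up to exponent
    # 127, i.e. the E8M0 encoding of 2**0 = 1.0; an exact power of two such as
    # 0.5 (exponent 126, mantissa 0) is kept as-is and still decodes to 0.5.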
def _e8m0_to_dtype(self, x, dtype):
"""Convert E8M0 format to specified data type.
Args:
x: Input in E8M0 format
dtype: Target data type
Returns:
Converted values in target data type
"""
temp = x.astype(jnp.uint32)
exp = temp << 23
new_x = exp.view(jnp.float32)
near_zero_value = 2**-15 if dtype == jnp.float16 else 2**-127
new_x = jnp.where(new_x == 0, jnp.array(near_zero_value, jnp.float32), new_x)
return new_x.astype(dtype)
@register_pytree_node_class
@dataclass
class QuantizerSet:
"""Set of quantizers for different tensor types.
This class manages quantizers for input tensors, kernel tensors, and
gradient tensors.
Attributes:
x: Quantizer for input tensors
kernel: Quantizer for kernel tensors
dgrad: Quantizer for gradient tensors
"""
x: Optional[Quantizer]
kernel: Optional[Quantizer]
dgrad: Optional[Quantizer]
def tree_flatten(self):
"""Flatten the quantizer set for JAX tree operations.
Returns:
Tuple of (children, aux_data) for tree operations
"""
children = (self.x, self.kernel, self.dgrad)
aux_data = ()
return (children, aux_data)
@classmethod
def tree_unflatten(cls, aux_data, children):
"""Reconstruct a quantizer set from its flattened representation.
Args:
aux_data: Unused auxiliary data
children: Tuple of quantizers
Returns:
A reconstructed QuantizerSet instance
"""
return cls(*aux_data, *children)
@dataclass
class QuantizerFactory:
"""Factory class for creating quantizers.
This class provides static methods to create individual quantizers and
sets of quantizers with various configurations.
"""
quantizer_type_map = {
ScalingMode.NVTE_DELAYED_TENSOR_SCALING: DelayedScaleQuantizer,
ScalingMode.NVTE_MXFP8_1D_SCALING: BlockScaleQuantizer,
}
@staticmethod
def create(
n_quantizers: int = 1,
scaling_mode: ScalingMode = None,
q_dtype: jnp.dtype = None,
q_axis: QuantizeAxis = None,
**kwargs,
) -> Quantizer:
"""Create one or more quantizers with specified parameters.
Args:
n_quantizers: Number of quantizers to create
scaling_mode: Scaling mode to use
q_dtype: Quantization data type
q_axis: Quantization axis
**kwargs: Additional arguments for quantizer initialization
Returns:
A single quantizer or tuple of quantizers
"""
        # (Phuong): add this assert back when NVTE_NO_SCALING is fully implemented
# assert scaling_mode != ScalingMode.NVTE_INVALID_SCALING
if scaling_mode in (ScalingMode.NVTE_NO_SCALING, ScalingMode.NVTE_INVALID_SCALING):
quantizers = [None] * n_quantizers
else:
quantizers = []
for _ in range(n_quantizers):
quantizer_type = QuantizerFactory.quantizer_type_map.get(scaling_mode)
quantizers.append(
quantizer_type(
q_dtype=q_dtype, scaling_mode=scaling_mode, q_axis=q_axis, **kwargs
)
)
return quantizers[0] if len(quantizers) == 1 else tuple(quantizers)
@staticmethod
def _create_set(scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, **kwargs) -> QuantizerSet:
"""Create a set of quantizers for forward and backward passes.
Args:
scaling_mode: Scaling mode to use
fwd_dtype: Data type for forward pass
bwd_dtype: Data type for backward pass
is_2x2x: Whether to use 2x2x quantization
**kwargs: Additional arguments for quantizer initialization
Returns:
A QuantizerSet instance
"""
if is_2x2x:
q_axis_x = q_axis_kernel = q_axis_dgrad = QuantizeAxis.ROWWISE_COLWISE
else:
q_axis_x = QuantizeAxis.ROWWISE
q_axis_kernel = QuantizeAxis.COLWISE
q_axis_dgrad = None
if "quantize_meta_set" in kwargs:
quantize_meta_set = kwargs.get("quantize_meta_set")
args_x = {
"scale": quantize_meta_set.x.scale,
"amax_history": quantize_meta_set.x.amax_history,
}
args_kernel = {
"scale": quantize_meta_set.kernel.scale,
"amax_history": quantize_meta_set.kernel.amax_history,
}
args_grad = {
"scale": quantize_meta_set.grad.scale,
"amax_history": quantize_meta_set.grad.amax_history,
}
else:
args_x = args_kernel = args_grad = {}
q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_axis_x, **args_x)
q_kernel = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_axis_kernel, **args_kernel)
q_dgrad = QuantizerFactory.create(1, scaling_mode, bwd_dtype, q_axis_dgrad, **args_grad)
return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad)
@staticmethod
def create_set(
n_quantizer_sets: int = 1,
scaling_mode: ScalingMode = None,
fwd_dtype: jnp.dtype = None,
bwd_dtype: jnp.dtype = None,
is_2x2x: bool = None,
**kwargs,
) -> tuple[Union[tuple[Quantizer], None]]:
"""Create one or more sets of quantizers.
Args:
n_quantizer_sets: Number of quantizer sets to create
scaling_mode: Scaling mode to use, default is QuantizeConfig.SCALING_MODE
fwd_dtype: Data type for forward pass, default is QuantizeConfig.FWD_DTYPE
bwd_dtype: Data type for backward pass, default is QuantizeConfig.BWD_DTYPE
is_2x2x: Whether to use 2x2x quantization, default is QuantizeConfig.IF_QUANTIZE_2X
**kwargs: Additional arguments for quantizer initialization
Returns:
A single quantizer set or tuple of quantizer sets
"""
scaling_mode = scaling_mode or QuantizeConfig.SCALING_MODE
fwd_dtype = fwd_dtype or QuantizeConfig.FWD_DTYPE
bwd_dtype = bwd_dtype or QuantizeConfig.BWD_DTYPE
is_2x2x = is_2x2x or QuantizeConfig.IF_QUANTIZE_2X
q_set = []
for _ in range(n_quantizer_sets):
q_set.append(
QuantizerFactory._create_set(scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, **kwargs)
)
return q_set[0] if len(q_set) == 1 else tuple(q_set)
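# Illustrative sketch (hypothetical dtype choices): build a single quantizer
# set for a dense layer, quantizing activations and kernel in e4m3 on the
# forward pass and gradients in e5m2 on the backward pass.
def _example_build_quantizer_set():
    return QuantizerFactory.create_set(
        scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
        fwd_dtype=jnp.float8_e4m3fn,
        bwd_dtype=jnp.float8_e5m2,
        is_2x2x=True,
    )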
noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NVTE_NO_SCALING)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Scaling mode implementations for quantization in JAX.
This module provides implementations of different scaling modes for tensor quantization,
including delayed scaling and block scaling strategies.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Tuple, Dict
from functools import reduce
import operator
from jax.tree_util import register_pytree_node_class
import jax.numpy as jnp
__all__ = ["ScalingMode"]
class ScalingModeMetadataImpl(ABC):
"""Base class for scaling mode implementations.
This abstract class defines the interface for different scaling mode implementations,
providing methods to get scale data types and shapes.
"""
@abstractmethod
def get_scale_dtype(self) -> jnp.dtype:
"""Get the data type for scale tensors.
Returns:
The data type used for scale tensors
"""
@abstractmethod
def get_scale_shape(
self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True
) -> Tuple[int, ...]:
"""Get the shape for scale tensors.
Args:
data_shape: The shape of the tensor being quantized
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
Returns:
The shape for scale tensors
"""
class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""Implementation for delayed scaling mode.
This implementation provides metadata for delayed scaling mode, including scale data type and shape.
"""
def get_scale_dtype(self) -> jnp.dtype:
"""Get the data type for scale tensors in delayed scaling.
Returns:
The data type used for scale tensors (float32)
"""
return jnp.float32
def get_scale_shape(
self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True
) -> Tuple[int, ...]:
"""Get the shape for scale tensors in delayed scaling.
Args:
data_shape: The shape of the tensor being scaled
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
Returns:
The shape for scale tensors - (1,)
"""
del data_shape, is_colwise
return (1,)
class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""Implementation for block scaling mode.
This implementation provides metadata for block scaling mode, which uses
block-based scaling with specific alignment requirements.
Attributes:
_block_dims: Dimensions of the scaling blocks
_block_alignment: Alignment requirements for blocks
"""
def __init__(self, block_dims: Tuple[int]):
"""Initialize block scaling mode implementation.
Args:
block_dims: Dimensions of the scaling blocks
"""
self._block_dims = block_dims
self._block_alignment = (128, 4)
def get_scale_dtype(self) -> jnp.dtype:
"""Get the data type for scale tensors in block scaling.
Returns:
The data type used for scale tensors (float8_e8m0fnu)
"""
return jnp.float8_e8m0fnu
def get_scale_shape(
self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True
) -> Tuple[int, ...]:
"""Get the shape for scale tensors in block scaling.
Args:
data_shape: The shape of the tensor being quantized
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
Returns:
The shape for scale tensors
"""
block_alignment = self._block_alignment if is_padded else (1, 1)
if is_colwise:
block_y, block_x = self._block_dims
alignment_y, alignment_x = block_alignment
else:
block_x, block_y = self._block_dims
alignment_x, alignment_y = block_alignment
seq_axis = len(data_shape) - 2
assert (
data_shape[seq_axis] % block_x == 0
), f"Input data of shape {data_shape} should be padded by {block_x} in axes={seq_axis}"
assert (
data_shape[-1] % block_y == 0
), f"Input data of shape {data_shape} should be padded by {block_y} in axis -1"
        # NOTE: this overpads if ndim > 2 and the dims before seq_axis are greater than 1
n_block_seq = data_shape[seq_axis] // block_x
n_block_y = data_shape[-1] // block_y
n_flat_first_dim = reduce(operator.mul, data_shape[:seq_axis], 1) * n_block_seq
# Padding
n_flat_first_dim = ((n_flat_first_dim + alignment_x - 1) // alignment_x) * alignment_x
n_block_y = ((n_block_y + alignment_y - 1) // alignment_y) * alignment_y
out_shape = ()
for i in range(seq_axis):
d = data_shape[i]
out_shape += (d,)
assert n_flat_first_dim % d == 0
n_flat_first_dim //= d
out_shape += (n_flat_first_dim, n_block_y)
return out_shape
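    # Worked example (illustrative): with block_dims (1, 32) and alignment
    # (128, 4), a (128, 256) tensor gets an unpadded rowwise scale shape of
    # (128, 8) and a colwise shape of (4, 256); padding to the alignment leaves
    # these unchanged, while a smaller (8, 64) tensor would have its rowwise
    # (8, 2) scale shape padded up to (128, 4).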
# (Phuong): Map the NVTEScalingMode value to the ScalingMode
@dataclass(frozen=True)
@register_pytree_node_class
class ScalingMode(Enum):
"""Enumeration of tensor scaling modes with their corresponding metadata implementations.
This class defines the available scaling modes for tensor quantization:
- NVTE_DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales
- NVTE_MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales
- NVTE_INVALID_SCALING: Invalid scaling mode
- NVTE_NO_SCALING: No scaling applied
"""
NVTE_DELAYED_TENSOR_SCALING = 0
NVTE_MXFP8_1D_SCALING = 1
NVTE_INVALID_SCALING = 2
NVTE_NO_SCALING = 3
def _get_impl(self) -> ScalingModeMetadataImpl:
"""Get the implementation for this scaling mode.
Returns:
The scaling mode implementation
Raises:
ValueError: If the scaling mode is invalid
"""
impl = SCALING_MODES_TO_IMPL.get(self)
if impl is None:
raise ValueError("Invalid scaling mode")
return impl
def get_scale_dtype(self):
"""Get the data type for scale tensors in this mode.
Returns:
The data type for scale tensors
"""
return self._get_impl().get_scale_dtype()
def get_scale_shape_2x(self, data_shape, is_padded=True) -> Tuple[Tuple[int]]:
"""Get shapes for both row-wise and column-wise scaling.
Args:
data_shape: Shape of the data tensor
is_padded: Whether to use padded shapes
Returns:
Tuple of (rowwise_scale_shape, colwise_scale_shape)
"""
rowwise_scale_shape = self.get_scale_shape(
data_shape, is_colwise=False, is_padded=is_padded
)
colwise_scale_shape = self.get_scale_shape(data_shape, is_colwise=True, is_padded=is_padded)
return (rowwise_scale_shape, colwise_scale_shape)
def get_scale_shape(self, data_shape, is_colwise, is_padded=True) -> Tuple[int]:
"""Get the shape for scale tensors in this mode.
Args:
data_shape: Shape of the data tensor
is_colwise: Whether to use column-wise scaling
is_padded: Whether to use padded shapes
Returns:
The shape for scale tensors
"""
return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded)
def __eq__(self, other):
"""Compare this scaling mode with another.
Args:
other: The other scaling mode to compare with
Returns:
True if the modes are equal, False otherwise
"""
if not isinstance(other, ScalingMode):
return False
return self.value == other.value
def tree_flatten(self):
"""Flatten this scaling mode for JAX tree operations.
Returns:
Tuple of (children, aux_data) for tree operations
"""
return (), (self.value)
@classmethod
def tree_unflatten(cls, aux_data, _children):
"""Reconstruct a scaling mode from its flattened representation.
Args:
aux_data: Auxiliary data containing the mode value
_children: Unused children data
Returns:
A reconstructed ScalingMode instance
"""
return cls(aux_data)
SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = {
ScalingMode.NVTE_DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(),
ScalingMode.NVTE_MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)),
# WAR
ScalingMode.NVTE_NO_SCALING: DelayedScalingModeMetadataImpl(),
}
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Tensor classes for TE/JAX
This module provides tensor classes for handling quantized tensors in JAX, including
both single-scale (1x) and double-scale (2x) quantization schemes. It supports
rowwise and colwise quantization modes with proper scaling and dequantization.
"""
from dataclasses import dataclass
from typing import Callable, Tuple
from abc import ABC, abstractmethod
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeAxis
from .scaling_modes import ScalingMode
from .dequantizer import Dequantizer
from ..sharding import (
with_sharding_constraint_by_logical_axes as original_with_sharding_constraint_by_logical_axes,
)
__all__ = [
"ScaledTensor",
"ScaledTensor1x",
"ScaledTensor2x",
"ScaledTensorFactory",
"with_sharding_constraint_by_logical_axes",
]
@register_pytree_node_class
@dataclass
class ScaledTensor(ABC):
"""Abstract base class for scaled tensors.
This class defines the interface for all scaled tensor implementations,
providing methods for dequantization and accessing row/column-wise components.
"""
@classmethod
def tree_unflatten(cls, aux_data, children):
"""Reconstructs the tensor from its flattened representation.
Args:
aux_data: Auxiliary data needed for reconstruction
children: The flattened tensor components
Returns:
A reconstructed tensor instance
"""
return cls(*children, *aux_data)
@abstractmethod
def dequantize(self):
"""Dequantizes the tensor back to its original precision.
Returns:
The dequantized tensor
"""
@abstractmethod
def get_rowwise_tensor(self):
"""Returns the row-wise component of the tensor.
Returns:
The row-wise tensor component
Raises:
ValueError: If called on a tensor that doesn't support row-wise access
"""
@abstractmethod
def get_colwise_tensor(self):
"""Returns the column-wise component of the tensor.
Returns:
The column-wise tensor component
Raises:
ValueError: If called on a tensor that doesn't support column-wise access
"""
@register_pytree_node_class
@dataclass
class ScaledTensor1x(ScaledTensor):
"""Single-scale quantized tensor implementation.
This class represents a tensor quantized with a single scaling factor,
supporting both row-wise and column-wise quantization modes.
Attributes:
data: The quantized tensor data
scale_inv: The inverse scaling factors
scaling_mode: The scaling mode used for quantization
dq_dtype: The data type for dequantized values
_dq_func: The dequantization function
is_colwise: Whether the tensor uses column-wise quantization
layout: The layout specification for the tensor
"""
data: jnp.ndarray
scale_inv: jnp.ndarray
scaling_mode: ScalingMode
dq_dtype: jnp.dtype
_dq_func: Callable
is_colwise: bool
layout: str
def __post_init__(self):
"""Validates and adjusts the scale_inv shape after initialization.
Ensures the scale_inv shape matches the expected shape based on the scaling mode
and quantization direction. Pads the scale_inv if necessary.
"""
expected_scale_shape = self.scaling_mode.get_scale_shape(
self.data.shape, self.is_colwise, is_padded=True
)
expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape(
self.data.shape, self.is_colwise, is_padded=False
)
if self.scale_inv.shape != expected_scale_shape:
assert self.scale_inv.shape == expected_unpadded_scale_shape, (
f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded"
f" scale_inv or {expected_unpadded_scale_shape} for unpadded scale_inv, got"
f" {self.scale_inv.shape}"
)
pad_width = tuple(
(0, a - b) for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape)
)
            # This actually pads scale_inv with NaN; should we pad it with 127 directly instead?
self.scale_inv = jnp.pad(
self.scale_inv, pad_width=pad_width, mode="constant", constant_values=0
)
def tree_flatten(self):
"""Flattens the tensor for JAX tree operations.
Returns:
A tuple containing (children, aux_data) for tree operations
"""
children = (self.data, self.scale_inv)
aux_data = (self.scaling_mode, self.dq_dtype, self._dq_func, self.is_colwise, self.layout)
return (children, aux_data)
def dequantize(self):
"""Dequantizes the tensor using the stored dequantization function.
Returns:
The dequantized tensor
"""
return self._dq_func(self)
def get_rowwise_tensor(self):
"""Returns the tensor if it's row-wise quantized.
Returns:
The row-wise tensor
Raises:
ValueError: If called on a column-wise quantized tensor
"""
if not self.is_colwise:
return self
raise ValueError("Calling get_rowwise_tensor() from a colwise ScaledTensor1x!")
def get_colwise_tensor(self):
"""Returns the tensor if it's column-wise quantized.
Returns:
The column-wise tensor
Raises:
ValueError: If called on a row-wise quantized tensor
"""
if self.is_colwise:
return self
raise ValueError("Calling get_colwise_tensor() from a rowwise ScaledTensor1x!")
@register_pytree_node_class
@dataclass
class ScaledTensor2x(ScaledTensor):
"""Double-scale quantized tensor implementation.
This class represents a tensor quantized with both row-wise and column-wise scaling factors.
Attributes:
rowwise_tensor: The row-wise quantized component
colwise_tensor: The column-wise quantized component
"""
rowwise_tensor: ScaledTensor1x
colwise_tensor: ScaledTensor1x
def tree_flatten(self):
"""Flattens the tensor for JAX tree operations.
Returns:
A tuple containing (children, aux_data) for tree operations
"""
children = (self.rowwise_tensor, self.colwise_tensor)
aux_data = ()
return (children, aux_data)
def dequantize(self):
"""Dequantizes the tensor using the row-wise component's dequantization.
Returns:
The dequantized tensor
"""
return self.rowwise_tensor.dequantize()
def get_rowwise_tensor(self):
"""Returns the row-wise quantized component.
Returns:
The row-wise tensor component
"""
return self.rowwise_tensor
def get_colwise_tensor(self):
"""Returns the column-wise quantized component.
Returns:
The column-wise tensor component
"""
return self.colwise_tensor
@dataclass
class ScaledTensorFactory:
"""Factory class for creating scaled tensor instances.
Provides static methods to create both single-scale (1x) and double-scale (2x)
quantized tensors with various configurations.
"""
@staticmethod
def create_1x(
data, scale_inv, scaling_mode, dq_dtype=jnp.bfloat16, is_colwise=False, layout="N"
):
"""Creates a single-scale quantized tensor.
Args:
data: The quantized tensor data
scale_inv: The inverse scaling factors
scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16)
is_colwise: Whether to use column-wise quantization (default: False)
layout: The layout specification (default: "N")
Returns:
A ScaledTensor1x instance
"""
dq_func = Dequantizer.funcs.get(scaling_mode)
return ScaledTensor1x(data, scale_inv, scaling_mode, dq_dtype, dq_func, is_colwise, layout)
@staticmethod
def create_2x(
data,
scale_inv,
colwise_data,
colwise_scale_inv,
scaling_mode,
dq_dtype=jnp.bfloat16,
layout="NN",
):
"""Creates a double-scale quantized tensor.
Args:
data: The row-wise quantized data
scale_inv: The row-wise inverse scaling factors
colwise_data: The column-wise quantized data
colwise_scale_inv: The column-wise inverse scaling factors
scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16)
layout: The layout specification (default: "NN")
Returns:
A ScaledTensor2x instance
"""
dq_func = Dequantizer.funcs.get(scaling_mode)
rowwise_tensor = ScaledTensor1x(
data,
scale_inv,
scaling_mode,
dq_dtype,
dq_func,
is_colwise=False,
layout=layout[0],
)
colwise_tensor = ScaledTensor1x(
colwise_data,
colwise_scale_inv,
scaling_mode,
dq_dtype,
dq_func,
is_colwise=True,
layout=layout[1],
)
return ScaledTensor2x(rowwise_tensor, colwise_tensor)
@staticmethod
def create(
data: jnp.ndarray,
scale_inv: jnp.ndarray,
colwise_data: jnp.ndarray,
colwise_scale_inv: jnp.ndarray,
scaling_mode: ScalingMode,
dq_dtype: jnp.dtype = jnp.bfloat16,
layout: str = "NN",
q_axis: QuantizeAxis = QuantizeAxis.ROWWISE,
):
"""Creates a scaled tensor based on the quantization axis.
Args:
data: The quantized tensor data
scale_inv: The inverse scaling factors
colwise_data: The column-wise quantized data
colwise_scale_inv: The column-wise inverse scaling factors
scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16)
layout: The layout specification (default: "NN")
q_axis: The quantization axis (default: ROWWISE)
Returns:
Either a ScaledTensor1x or ScaledTensor2x instance depending on q_axis
"""
if q_axis == QuantizeAxis.ROWWISE_COLWISE:
return ScaledTensorFactory.create_2x(
data,
scale_inv,
colwise_data,
colwise_scale_inv,
scaling_mode,
dq_dtype,
layout=layout,
)
is_colwise = q_axis == QuantizeAxis.COLWISE
return ScaledTensorFactory.create_1x(
data, scale_inv, scaling_mode, dq_dtype, is_colwise=is_colwise, layout=layout[0]
)
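# Illustrative sketch (hypothetical shapes and dtypes): wrap pre-quantized FP8
# data and its inverse scale into a rowwise ScaledTensor1x using delayed tensor
# scaling, then recover a bfloat16 copy.
def _example_wrap_prequantized_data():
    data = jnp.zeros((16, 32), jnp.float8_e4m3fn)
    scale_inv = jnp.ones((1,), jnp.float32)  # delayed scaling uses a (1,) scale
    tensor = ScaledTensorFactory.create_1x(
        data, scale_inv, ScalingMode.NVTE_DELAYED_TENSOR_SCALING, dq_dtype=jnp.bfloat16
    )
    return tensor.dequantize()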
def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, ...]):
"""Applies sharding constraints to a tensor based on logical axis names.
Args:
x: The tensor to apply sharding constraints to
logical_axis_names: Tuple of logical axis names for sharding
Returns:
The tensor with applied sharding constraints
"""
if isinstance(x, ScaledTensor1x):
return ScaledTensor1x(
data=with_sharding_constraint_by_logical_axes(x.data, logical_axis_names),
scale_inv=x.scale_inv,
scaling_mode=x.scaling_mode,
dq_dtype=x.dq_dtype,
_dq_func=x._dq_func,
is_colwise=x.is_colwise,
layout=x.layout,
)
if isinstance(x, ScaledTensor2x):
return ScaledTensor2x(
rowwise_tensor=with_sharding_constraint_by_logical_axes(
x.rowwise_tensor, logical_axis_names
),
colwise_tensor=with_sharding_constraint_by_logical_axes(
x.colwise_tensor, logical_axis_names
),
)
return original_with_sharding_constraint_by_logical_axes(x, logical_axis_names)
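# Illustrative sketch (hypothetical logical axis names): the helper above works
# uniformly for plain arrays and for ScaledTensor1x/2x, constraining the
# quantized data while leaving scale_inv unconstrained.
def _example_constrain_activation(x):
    return with_sharding_constraint_by_logical_axes(x, ("batch", "seq", "hidden"))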
@@ -2,7 +2,22 @@
#
# See LICENSE for license information.
"""Installation script for TE jax extensions."""
"""Installation script for Transformer Engine JAX extensions.
This module handles the build and installation of the JAX-specific components
of Transformer Engine. It manages:
- JAX extension compilation with pybind11
- Common header file management
- Build tool dependencies
- Package metadata and dependencies
The script supports both development and release builds, with different
behaviors for:
- Build tool management
- Header file copying
- Extension compilation
- Package distribution
"""
# pylint: disable=wrong-import-position,wrong-import-order
@@ -41,6 +56,34 @@ CMakeBuildExtension = get_build_ext(BuildExtension, True)
if __name__ == "__main__":
"""Main entry point for JAX extension installation.
This section handles:
1. Common header file management
- Creates a temporary directory for common headers
- Copies necessary header files from the common library
2. Extension module setup
- Configures the JAX-specific C++ extension
- Sets up build paths and dependencies
3. Package configuration
- Sets package metadata
- Configures build and install requirements
- Sets up extension modules
4. Cleanup
- Removes temporary directories after build
- Cleans up build tools if not in release mode
Environment variables:
- NVTE_RELEASE_BUILD: Controls release build behavior
- NVTE_PROJECT_BUILDING: Set to "1" during build
Note:
The script requires JAX to be installed for building.
It will raise a RuntimeError if JAX is not available.
"""
# Extensions
common_headers_dir = "common_headers"
copy_common_headers(current_file_path.parent, str(current_file_path / common_headers_dir))
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Sharding Meta for xmap with CustomCall
"""Sharding utilities for Transformer Engine in JAX.
This module provides utilities for managing tensor sharding in distributed training,
including support for various parallelism strategies like data parallelism (DP),
tensor parallelism (TP), pipeline parallelism (PP), and full-sharded data
parallelism (FSDP). It includes functions for sharding constraints, mesh management,
and collective operations.
"""
import os
from contextlib import contextmanager
@@ -181,27 +186,17 @@ def get_mesh_axis_rank(axis: str, mesh=None):
@dataclass
class MeshResource:
"""
A data container to indicate which axis in Mesh for data parallelism and
which for tensor parallelism.
Parameters
----------
dp_resource : str, default = None
The axis name in Mesh used to shard batches along.
If it is None, then data parallelism is disabled.
tp_resource : str, default = None
The axis name in Mesh used to split the hidden dimensions along.
If it is None, then tensor parallelism is disabled.
fsdp_resource : str, default = None
The axis name in Mesh used to split the batch and weights along.
If it is None, then full-sharded data parallelism is disabled.
pp_resource : str, default = None
The axis name in Mesh used to split model layers along.
If it is None, then pipeline parallelism is disabled.
cp_resource : str, default = None
The axis name in Mesh used to split sequence (context) dimensions along
in the attention. If it is None, then context parallelism is disabled.
"""A data container for managing mesh resources in distributed training.
This class defines the mapping between logical axes and physical mesh axes
for different types of parallelism in distributed training.
Attributes:
dp_resource: Axis name for data parallelism (batch sharding), default is None
tp_resource: Axis name for tensor parallelism (hidden dimension sharding), default is None
fsdp_resource: Axis name for full-sharded data parallelism, default is None
pp_resource: Axis name for pipeline parallelism (layer sharding), default is None
cp_resource: Axis name for context parallelism (sequence sharding), default is None
"""
dp_resource: str = None
@@ -216,36 +211,55 @@ _GLOBAL_MESH_RESOURCE = MeshResource()
@contextmanager
def global_shard_guard(resource: MeshResource):
"""
A context manager to switch the global MeshResource
"""Context manager for setting global sharding configuration.
This context manager allows temporarily setting the global mesh resource
configuration for sharding operations.
Args:
resource: MeshResource instance defining the sharding configuration
"""
global _GLOBAL_MESH_RESOURCE
prev_gmr = _GLOBAL_MESH_RESOURCE
old_resources = _GLOBAL_MESH_RESOURCE
try:
_GLOBAL_MESH_RESOURCE = resource
yield
finally:
_GLOBAL_MESH_RESOURCE = prev_gmr
_GLOBAL_MESH_RESOURCE = old_resources
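# Illustrative usage sketch (hypothetical mesh axis names): temporarily map
# data parallelism to the "data" mesh axis and tensor parallelism to "model".
def _example_shard_guard_usage():
    resource = MeshResource(dp_resource="data", tp_resource="model")
    with global_shard_guard(resource):
        # Inside the guard, sharding helpers resolve axes against `resource`.
        return global_mesh_resource().tp_resource  # -> "model"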
def global_mesh_resource() -> MeshResource:
"""
A getter of the global MeshResource
"""Get the current global mesh resource configuration.
Returns:
The current MeshResource instance
"""
return _GLOBAL_MESH_RESOURCE
def all_reduce_sum_along_dp_fsdp(x: jnp.array, mesh: jax.sharding.Mesh):
"""
All-Reduce (Sum) along DP and FSDP mesh axes.
"""Perform all-reduce sum operation along data parallelism and FSDP axes.
Args:
x: Input tensor to reduce
mesh: JAX mesh for distributed computation
Returns:
Reduced tensor
"""
x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource, mesh)
return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource, mesh)
def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mesh):
"""
All-Reduce (Max) along all mesh axes.
"""Perform all-reduce max operation along all axes except pipeline parallelism.
Args:
x: Input tensor to reduce
mesh: JAX mesh for distributed computation
Returns:
Reduced tensor
"""
all_axes = get_all_mesh_axes()
for axis in all_axes:
@@ -261,21 +275,16 @@ global_shard_resource = global_mesh_resource
class MajorShardingType(Enum):
r"""
The major sharding type to indicate sharding pattern.
.. warning::
MajorShardingType is deprecating in the near feature.
Values
----------
SINGLE:
Single process training.
DP:
Data parallel training.
TP:
Standard tensor parallel training.
DPTP:
Data and Standard tensor parallel training.
"""Enumeration of major sharding types for distributed training.
This enum defines the basic sharding patterns available for distributed
training. Note that this class is deprecated and will be removed in the future.
Values:
SINGLE: Single process training
DP: Data parallel training
TP: Standard tensor parallel training
DPTP: Data and standard tensor parallel training
"""
SINGLE = 0
@@ -285,25 +294,19 @@ class MajorShardingType(Enum):
class ShardingType(Enum):
"""
The sharding type to indicate sharding pattern.
.. warning::
ShardingType is deprecating in the near feature.
Values
----------
SINGLE:
No sharding.
DP:
Sharding along data parallelism.
TP_COL:
Sharding along column-split tensor parallelism.
TP_ROW:
Sharding along row-split tensor parallelism.
DP_TP_COL:
Sharding along data and column-split tensor parallelism.
DP_TP_ROW:
Sharding along data and row-split tensor parallelism.
"""Enumeration of detailed sharding types for distributed training.
This enum defines specific sharding patterns for distributed training,
including combinations of data parallelism and different tensor parallelism
strategies. Note that this class is deprecated and will be removed in the future.
Values:
SINGLE: No sharding
DP: Sharding along data parallelism
TP_COL: Sharding along column-split tensor parallelism
TP_ROW: Sharding along row-split tensor parallelism
DP_TP_COL: Sharding along data and column-split tensor parallelism
DP_TP_ROW: Sharding along data and row-split tensor parallelism
"""
SINGLE = (MajorShardingType.SINGLE, "single")