Commit 44740c6c authored by yuguo

Merge commit '7a9a0825' of https://github.com/NVIDIA/TransformerEngine
parents 8113d9e0 7a9a0825
......@@ -261,6 +261,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
bool is_delayed_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING;
bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
bool const is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
size_t input_dtype_bytes = te_dtype_bytes(in_dtype);
size_t output_dtype_bytes = te_dtype_bytes(out_dtype);
......@@ -314,6 +315,8 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
size_t colwise_sinv_size = 0;
size_t non_group_m = flatten_axis > 1 ? product(input_dims, 1, flatten_axis) : 1;
size_t num_non_empty_groups = 0;
size_t total_rowwise_sinv_size = 0;
size_t total_colwise_sinv_size = 0;
for (size_t i = 0; i < num_groups; i++) {
size_t m_i = dim_list_host[i] * non_group_m;
// Skip zero-size inputs and shift the scale ptr
......@@ -379,6 +382,12 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
sinv_ptr += sinv_size * sinv_dtype_bytes;
colwise_sinv_ptr += colwise_sinv_size * colwise_sinv_dtype_bytes;
amax_ptr += amax_dtype_bytes;
total_rowwise_sinv_size += sinv_size;
total_colwise_sinv_size += colwise_sinv_size;
}
if (is_mxfp8_scaling) {
nvte_memset(scale_invs->untyped_data(), 0, total_rowwise_sinv_size, stream);
nvte_memset(colwise_scale_invs->untyped_data(), 0, total_colwise_sinv_size, stream);
}
QuantizationConfigWrapper quant_config;
......
......@@ -8,7 +8,7 @@ architectures, including support for quantization and automatic differentiation.
It implements matrix multiplication with optional bias addition and supports
customizable contracting dimensions for flexible tensor operations.
"""
import warnings
from typing import Tuple, Sequence
from functools import partial
import jax
......@@ -19,9 +19,20 @@ from .quantize import (
QuantizerSet,
noop_quantizer_set,
with_sharding_constraint_by_logical_axes,
TensorUsage,
)
DENSE_BATCH_FIRST_WARNING_ISSUED = False
def _issue_batch_first_warning(msg):
global DENSE_BATCH_FIRST_WARNING_ISSUED
if not DENSE_BATCH_FIRST_WARNING_ISSUED:
warnings.warn(msg, UserWarning)
DENSE_BATCH_FIRST_WARNING_ISSUED = True
def dense(
x: jnp.ndarray,
kernel: jnp.ndarray,
......@@ -29,6 +40,7 @@ def dense(
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
batch_first: bool = True,
quantizer_set: QuantizerSet = noop_quantizer_set,
):
"""Perform dense layer transformation with optional quantization.
......@@ -42,25 +54,28 @@ def dense(
kernel: Weight matrix for the dense layer transformation
bias: Optional bias tensor to add after the transformation
contracting_dims: Tuple of sequences specifying which dimensions to contract
batch_first: Assume that X is batched in the first dimension.
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
Transformed output tensor
"""
# Remove when tex.quantize() can handle quantizer=None
if quantizer_set == noop_quantizer_set:
if quantizer_set == noop_quantizer_set and tex.gemm_uses_jax_dot():
x = with_sharding_constraint_by_logical_axes(x, input_axes)
output = tex.gemm(x, kernel, contracting_dims)
output = tex.gemm(x, kernel, contracting_dims=contracting_dims)
if bias is not None:
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape)
else:
output = _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set)
output = _dense(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set
)
return output
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set):
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6))
def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set):
"""Internal implementation of dense layer transformation with custom VJP.
This function implements the core dense layer transformation logic with support
......@@ -74,81 +89,126 @@ def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer
input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: QuantizerSet which contains quantizers for different tensor types
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
Returns:
Transformed output tensor
"""
output, _ = _dense_fwd_rule(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set
x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set
)
return output
def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set):
def _dense_fwd_rule(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set
):
"""Forward pass rule for dense layer transformation.
Returns:
Tuple of (output, context) for backward pass
"""
x_contracting_dims, k_contracting_dims = contracting_dims
x_contracting_dims, k_contracting_dims = map(
tex.sanitize_dims, (x.ndim, kernel.ndim), contracting_dims
)
# Check supported input layout
x_is_transposed = x.ndim - 1 not in x_contracting_dims
k_is_transposed = kernel.ndim - 1 in k_contracting_dims
assert (
not x_is_transposed and not k_is_transposed
), "Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel."
# Determine X batch dimension
# - If `batch_first=True` -> (batch, leading..., contracting...)
# - Otherwise -> (leading..., batch, contracting...)
# NOTE: Always assume a single batch dimension
x_bdim = None
num_cdims = len(x_contracting_dims)
if x.ndim >= num_cdims + 2:
# Assume X is batched if it has at least two more dimensions than the number of
# contracting dimensions.
if not batch_first:
_issue_batch_first_warning(
"TE/JAX `dense()` layer implementation does not officially support sequence-first "
"inputs and may produce incorrect results when `batch_first=False`. Use "
"sequence-first inputs at your own discretion.",
)
x_bdim = 0 if batch_first else x.ndim - num_cdims - 1
flatten_axis_x = -len(x_contracting_dims)
flatten_axis_k = len(k_contracting_dims) - len(kernel.shape)
casted_x = tex.quantize(x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x)
casted_x = tex.quantize(
x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, noop_scaled_tensor=True
)
casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)
casted_kernel = tex.quantize(
kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel
kernel,
flatten_axis=flatten_axis_k,
quantizer=quantizer_set.kernel,
noop_scaled_tensor=True,
)
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
# GEMM NN
use_bias = bias is not None
output = tex.gemm(
casted_x.get_rowwise_tensor(),
casted_kernel.get_colwise_tensor(),
(x_contracting_dims, k_contracting_dims),
casted_x.get_tensor(usage=TensorUsage.LHS),
casted_kernel.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
)
use_bias = bias is not None
if use_bias:
if use_bias and tex.gemm_uses_jax_dot():
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape)
ctx = (
casted_x.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None,
casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None,
casted_x.get_tensor(usage=TensorUsage.LHS_TRANS),
casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS),
x.shape,
kernel.shape,
use_bias,
quantizer_set,
flatten_axis_k,
x_bdim,
)
return output, ctx
def _dense_bwd_rule(
contracting_dims, input_axes, kernel_axes, ctx, grad
contracting_dims, input_axes, kernel_axes, batch_first, ctx, grad
): # pylint: disable=unused-argument
"""Backward pass rule for dense layer transformation.
Returns:
Tuple of gradients with respect to inputs
"""
fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims
(
colwise_casted_x,
rowwise_casted_kernel,
casted_x_lhs,
casted_kernel_rhs,
x_shape,
kernel_shape,
use_bias,
quantizer_set,
flatten_axis_k,
x_bdim,
) = ctx
fwd_x_contracting_dims, fwd_k_contracting_dims = map(
tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims
)
casted_grad, dbias = tex.quantize_dbias(
grad, is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad
grad,
is_dbias=use_bias,
flatten_axis=flatten_axis_k,
quantizer=quantizer_set.dgrad,
noop_scaled_tensor=True,
)
# GEMM NT
......@@ -161,9 +221,10 @@ def _dense_bwd_rule(
dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
)
dgrad = tex.gemm(
casted_grad.get_rowwise_tensor(),
rowwise_casted_kernel,
(g_contracting_dim, k_contracting_dim),
casted_grad.get_tensor(usage=TensorUsage.LHS),
casted_kernel_rhs,
contracting_dims=(g_contracting_dim, k_contracting_dim),
batched_dims=((x_bdim,), ()),
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
......@@ -174,7 +235,10 @@ def _dense_bwd_rule(
)
wgrad = tex.gemm(
colwise_casted_x, casted_grad.get_colwise_tensor(), (x_contracting_dim, g_contracting_dim)
casted_x_lhs,
casted_grad.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dim, g_contracting_dim),
batched_dims=((x_bdim,), (x_bdim,)),
)
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
......@@ -287,7 +351,6 @@ def _grouped_dense_fwd_rule(
"and k_contracting_dims=(1,) for now, "
f"got {x_contracting_dims=} and {k_contracting_dims=}"
)
k_contracting_dims = (0,)
casted_x = tex.grouped_quantize(
x, quantizer_set.x, group_sizes, flatten_axis=flatten_axis_x
......@@ -300,11 +363,10 @@ def _grouped_dense_fwd_rule(
# For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have
# rowwise_casted_x.original_shape == (M, K)
# colwise_casted_kernel.original_shape == (G, N, K)
grouped_gemm_x = casted_x.get_rowwise_tensor()
grouped_gemm_kernel = casted_kernel.get_colwise_tensor()
# TODO(Hua): Shall we give warning/error if not quantizer_set.x.is_2x2x()?
ctx_x = casted_x.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None
ctx_kernel = casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None
grouped_gemm_x = casted_x.get_tensor(usage=TensorUsage.LHS)
grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS)
ctx_x = casted_x.get_tensor(usage=TensorUsage.LHS_TRANS)
ctx_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS)
output = tex.grouped_gemm(
grouped_gemm_x,
......@@ -382,17 +444,17 @@ def _grouped_dense_bwd_rule(
g_contracting_dim = (1,)
k_contracting_dim = (2,)
dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
dgrad_grad = casted_grad.get_rowwise_tensor()
dgrad_grad = casted_grad.get_tensor(usage=TensorUsage.LHS)
dgrad_kernel_T = ctx_kernel
# We need to use g_contracting_dim = (0,) and x_contracting_dim = (1,) to make it work
# We need to use g_contracting_dim = (0,) and x_contracting_dim = (0,) to make it work
# after the extra transpose for FP8 in grouped_gemm
# TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
g_contracting_dim = (0,)
x_contracting_dim = (0,)
wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
wgrad_x_T = ctx_x
wgrad_grad = casted_grad.get_colwise_tensor()
wgrad_grad = casted_grad.get_tensor(usage=TensorUsage.RHS)
dgrad = tex.grouped_gemm(
dgrad_grad,
......
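For reference, a minimal usage sketch of the updated `dense()` call with the new `batch_first` argument. The import path, shapes, and dtypes below are illustrative assumptions rather than part of this diff, and `quantizer_set` is left at its no-op default:

import jax.numpy as jnp
from transformer_engine.jax.dense import dense  # assumed import path

x = jnp.zeros((8, 128, 256), dtype=jnp.bfloat16)    # (batch, seq, hidden_in)
kernel = jnp.zeros((256, 512), dtype=jnp.bfloat16)  # (hidden_in, hidden_out)
bias = jnp.zeros((512,), dtype=jnp.bfloat16)

# batch_first=True (the default) treats dim 0 of x as the single batch dimension, so the
# forward GEMM is launched with batched_dims=((0,), ()).
out = dense(x, kernel, bias=bias, contracting_dims=((2,), (0,)), batch_first=True)
assert out.shape == (8, 128, 512)
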
......@@ -6,7 +6,7 @@ Wrapper module for Transformer related layers with FP8 support.
"""
from functools import reduce
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType
import numpy as np
import jax.numpy as jnp
......@@ -15,12 +15,12 @@ from jax import lax
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name
from ..dense import dense
from ..dense import dense, _issue_batch_first_warning as _dense_warning
from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense
from ..layernorm_mlp import layernorm_mlp
from ..layernorm_dense import layernorm_dense, _issue_batch_first_warning as _ln_dense_warning
from ..layernorm_mlp import layernorm_mlp, _issue_batch_first_warning as _ln_mlp_warning
from ..activation import activation
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
......@@ -35,8 +35,8 @@ from ..sharding import get_non_contracting_logical_axes
PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
DType = NewType("DType", jnp.dtype)
Array = NewType("Array", jnp.ndarray)
PrecisionLike = Union[
None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
......@@ -441,6 +441,12 @@ class DenseGeneral(TransformerEngineBase):
input_axes: Tuple[str, ...] = ()
def __post_init__(self):
if self.transpose_batch_sequence:
_dense_warning(
"TE/JAX DenseGeneral() module does not officially support sequence-first inputs "
"and may produce incorrect results when `transpose_batch_sequence=True`. Use "
"sequence-first inputs at your own discretion."
)
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype
......@@ -657,6 +663,12 @@ class LayerNormDenseGeneral(TransformerEngineBase):
depth_scaling: float = None
def __post_init__(self):
if self.transpose_batch_sequence:
_ln_dense_warning(
"TE/JAX LayerNormDenseGeneral() module does not officially support sequence-first "
"inputs and may produce incorrect results when `transpose_batch_sequence=True`. "
"Use sequence-first inputs at your own discretion."
)
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0,
......@@ -967,6 +979,12 @@ class LayerNormMLP(TransformerEngineBase):
dot_2_input_axes: Tuple[str, ...] = None
def __post_init__(self):
if self.transpose_batch_sequence:
_ln_mlp_warning(
"TE/JAX LayerNormMLP() module does not officially support sequence-first inputs "
"and may produce incorrect results when `transpose_batch_sequence=True`. Use "
"sequence-first inputs at your own discretion."
)
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype
......
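Illustrative only: the new `__post_init__` checks surface as a one-time UserWarning when any of these Flax modules is constructed with `transpose_batch_sequence=True`. The constructor arguments below are assumptions:

import warnings
from transformer_engine.jax.flax import DenseGeneral  # assumed import path

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = DenseGeneral(features=512, transpose_batch_sequence=True)
# On the first such construction in a process, `caught` contains the sequence-first UserWarning;
# the shared _issue_batch_first_warning() helpers suppress repeats afterwards.
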
......@@ -180,8 +180,9 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
attn_weights_without_groups_shape = (b, h * g, q, k)
attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)
# (b, h, q, k): Last two axes are always replicated
attn_weights = with_sharding_constraint_by_logical_axes(
attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES)
attn_weights, (BATCH_AXES, HEAD_AXES, None, None)
)
# When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias)
......
......@@ -9,6 +9,7 @@ architectures. It supports various normalization types, quantization, and
distributed training through sharding constraints.
"""
import warnings
from functools import partial
from typing import Tuple
......@@ -21,9 +22,20 @@ from .quantize import (
QuantizerSet,
noop_quantizer_set,
with_sharding_constraint_by_logical_axes,
TensorUsage,
)
LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = False
def _issue_batch_first_warning(msg):
global LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED
if not LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED:
warnings.warn(msg, UserWarning)
LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = True
def layernorm_dense(
x: jnp.ndarray,
kernel: jnp.ndarray,
......@@ -36,6 +48,7 @@ def layernorm_dense(
layernorm_input_axes: Tuple[str, ...] = None,
dot_input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
batch_first: bool = True,
quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray:
"""Apply layer normalization followed by dense layer transformation.
......@@ -56,6 +69,7 @@ def layernorm_dense(
layernorm_input_axes: Logical axes for sharding the layernorm input
dot_input_axes: Logical axes for sharding the matrix multiplication input
kernel_axes: Logical axes for sharding the weight matrix
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
quantizer_set: Set of quantizers for different tensor types
Returns:
......@@ -79,6 +93,7 @@ def layernorm_dense(
layernorm_input_axes,
dot_input_axes,
kernel_axes,
batch_first,
quantizer_set,
)
return output
......@@ -93,6 +108,7 @@ def layernorm_dense(
8,
9,
10,
11,
),
)
def _layernorm_dense(
......@@ -107,6 +123,7 @@ def _layernorm_dense(
layernorm_input_axes: Tuple[str, ...],
dot_input_axes: Tuple[str, ...],
kernel_axes: Tuple[str, ...],
batch_first: bool,
quantizer_set,
):
"""Internal implementation of layernorm_dense with custom VJP.
......@@ -126,6 +143,7 @@ def _layernorm_dense(
epsilon: Small constant for numerical stability
layernorm_input_axes: Logical axes for layernorm sharding
dot_input_axes: Logical axes for matrix multiplication sharding
batch_first: Assume that X is batched in the first dimension.
quantizer_set: Set of quantizers
Returns:
......@@ -143,6 +161,7 @@ def _layernorm_dense(
layernorm_input_axes,
dot_input_axes,
kernel_axes,
batch_first,
quantizer_set,
)
return output
......@@ -160,6 +179,7 @@ def _layernorm_dense_fwd_rule(
layernorm_input_axes,
dot_input_axes,
kernel_axes,
batch_first,
quantizer_set,
):
"""Forward pass rule for layernorm_dense.
......@@ -177,6 +197,17 @@ def _layernorm_dense_fwd_rule(
k_contracting_dims = (0,)
assert x.shape[-1] == kernel.shape[0]
x_bdim = None
if x.ndim > 2:
if not batch_first:
_issue_batch_first_warning(
"TE/JAX `layernorm_dense()` fused-layer implementation does not officially "
"support sequence-first inputs and may produce incorrect results when "
"`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first "
"inputs at your own discretion."
)
x_bdim = 0 if batch_first else x.ndim - 2
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
casted_ln_out, mu, rsigma = tex.normalization_fwd(
......@@ -186,31 +217,37 @@ def _layernorm_dense_fwd_rule(
zero_centered_gamma,
epsilon,
norm_type,
quantizer_set.x,
quantizer=quantizer_set.x,
noop_scaled_tensor=True,
)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)
# Kernel in (hidden_in, hidden_out...)
flatten_axis = 1 - len(kernel.shape)
casted_kernel = tex.quantize(kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel)
casted_kernel = tex.quantize(
kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, noop_scaled_tensor=True
)
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
# NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out...)
use_bias = bias is not None
output = tex.gemm(
casted_ln_out.get_rowwise_tensor(),
casted_kernel.get_colwise_tensor(),
(x_contracting_dims, k_contracting_dims),
casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
)
use_bias = bias is not None
if use_bias:
if use_bias and tex.gemm_uses_jax_dot():
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape)
ctx = (
casted_ln_out.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None,
casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None,
casted_ln_out.get_tensor(TensorUsage.LHS_TRANS),
casted_kernel.get_tensor(TensorUsage.RHS_TRANS),
x.shape,
kernel.shape,
mu,
......@@ -223,6 +260,7 @@ def _layernorm_dense_fwd_rule(
use_bias,
quantizer_set,
flatten_axis,
x_bdim,
)
return output, ctx
......@@ -235,6 +273,7 @@ def _layernorm_dense_bwd_rule(
layernorm_input_axes,
dot_input_axes, # pylint: disable=unused-argument
kernel_axes,
batch_first, # pylint: disable=unused-argument
ctx,
grad,
):
......@@ -250,8 +289,8 @@ def _layernorm_dense_bwd_rule(
Tuple of gradients for all input parameters
"""
(
colwise_casted_ln_out,
rowwise_casted_kernel,
casted_ln_out,
casted_kernel,
x_shape,
kernel_shape,
mu,
......@@ -264,10 +303,15 @@ def _layernorm_dense_bwd_rule(
use_bias,
quantizer_set,
flatten_axis,
x_bdim,
) = ctx
casted_grad, dbias = tex.quantize_dbias(
grad, is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad
grad,
is_dbias=use_bias,
flatten_axis=flatten_axis,
quantizer=quantizer_set.dgrad,
noop_scaled_tensor=True,
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
......@@ -281,9 +325,10 @@ def _layernorm_dense_bwd_rule(
# NT GEMM
dgrad = tex.gemm(
casted_grad.get_rowwise_tensor(),
rowwise_casted_kernel,
(g_constracting_dim, k_constracting_dim),
casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel,
contracting_dims=(g_constracting_dim, k_constracting_dim),
batched_dims=((x_bdim,), ()),
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
......@@ -294,9 +339,10 @@ def _layernorm_dense_bwd_rule(
# TN GEMM
wgrad = tex.gemm(
colwise_casted_ln_out,
casted_grad.get_colwise_tensor(),
(x_constracting_dim, g_constracting_dim),
casted_ln_out,
casted_grad.get_tensor(TensorUsage.RHS),
contracting_dims=(x_constracting_dim, g_constracting_dim),
batched_dims=((x_bdim,), (x_bdim,)),
)
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
......
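A hedged usage sketch of the fused layer with the new argument; the argument order, defaults, and import path are assumed from the surrounding diff rather than verified here:

import jax.numpy as jnp
from transformer_engine.jax.layernorm_dense import layernorm_dense  # assumed import path

x = jnp.zeros((8, 128, 256), jnp.bfloat16)     # (batch, seq, hidden_in)
kernel = jnp.zeros((256, 512), jnp.bfloat16)   # (hidden_in, hidden_out)
gamma = jnp.ones((256,), jnp.bfloat16)
beta = jnp.zeros((256,), jnp.bfloat16)

# batch_first=True (default) treats dim 0 as batch; batch_first=False now emits the
# one-time sequence-first warning added above.
out = layernorm_dense(x, kernel, gamma, beta, norm_type="layernorm", batch_first=True)
assert out.shape == (8, 128, 512)
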
......@@ -13,6 +13,7 @@ The implementation supports various normalization types, activation functions,
quantization, and distributed training through sharding constraints.
"""
import warnings
from typing import List, Tuple, Sequence, Union, Callable
from functools import partial
......@@ -22,10 +23,25 @@ from jax.ad_checkpoint import checkpoint_name
from . import cpp_extensions as tex
from .layernorm import canonicalize_norm_type
from .quantize import with_sharding_constraint_by_logical_axes, QuantizerSet, noop_quantizer_set
from .quantize import (
with_sharding_constraint_by_logical_axes,
QuantizerSet,
noop_quantizer_set,
TensorUsage,
)
from .sharding import get_non_contracting_logical_axes
LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = False
def _issue_batch_first_warning(msg):
global LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED
if not LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED:
warnings.warn(msg, UserWarning)
LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = True
def layernorm_mlp(
x: jnp.ndarray,
gamma: jnp.ndarray,
......@@ -43,6 +59,7 @@ def layernorm_mlp(
ffn1_ckpt_name: str = "ffn1",
ffn2_ckpt_name: str = "ffn2",
activation_type: Sequence[Union[str, Callable]] = ("gelu",),
batch_first: bool = True,
quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
) -> jnp.ndarray:
"""Apply layer normalization followed by MLP block.
......@@ -74,6 +91,7 @@ def layernorm_mlp(
ffn1_ckpt_name: Name for checkpointing the first feed-forward network
ffn2_ckpt_name: Name for checkpointing the second feed-forward network
activation_type: Activation function(s) to apply after the first dense layer transformation
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations
Returns:
......@@ -119,12 +137,13 @@ def layernorm_mlp(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
batch_first,
quantizer_sets,
)
return output
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18))
def _layernorm_mlp(
x: jnp.ndarray,
gamma: jnp.ndarray,
......@@ -144,6 +163,7 @@ def _layernorm_mlp(
ffn1_ckpt_name: str,
ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]],
batch_first: bool,
quantizer_sets,
):
"""Internal implementation of layernorm_mlp with custom VJP.
......@@ -169,6 +189,7 @@ def _layernorm_mlp(
ffn1_ckpt_name: Name for first feed-forward network checkpointing
ffn2_ckpt_name: Name for second feed-forward network checkpointing
activation_type: Activation function(s)
batch_first: Assume that X is batched in the first dimension.
quantizer_sets: Tuple of quantizer sets
Returns:
......@@ -193,6 +214,7 @@ def _layernorm_mlp(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
batch_first,
quantizer_sets,
)
return output
......@@ -217,6 +239,7 @@ def _layernorm_mlp_fwd_rule(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
batch_first,
quantizer_sets,
):
"""Forward pass rule for layernorm_mlp.
......@@ -249,6 +272,17 @@ def _layernorm_mlp_fwd_rule(
assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]
x_bdim = None
if x.ndim > 2:
if not batch_first:
_issue_batch_first_warning(
"TE/JAX `layernorm_mlp()` fused-layer implementation does not officially "
"support sequence-first inputs and may produce incorrect results when "
"`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first "
"inputs at your own discretion."
)
x_bdim = 0 if batch_first else x.ndim - 2
use_bias_1 = bias_1 is not None
use_bias_2 = bias_2 is not None
......@@ -262,17 +296,23 @@ def _layernorm_mlp_fwd_rule(
epsilon,
norm_type,
quantizer=ffn1_quantizer_set.x,
noop_scaled_tensor=True,
)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)
casted_kernel_1 = tex.quantize(kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel)
casted_kernel_1 = tex.quantize(
kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, noop_scaled_tensor=True
)
# NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out)
dot_1_output = tex.gemm(
casted_ln_out.get_rowwise_tensor(),
casted_kernel_1.get_colwise_tensor(),
(x_contracting_dims, k_contracting_dims),
casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel_1.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias_1 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
)
if dot_1_input_axes is not None and kernel_1_axes is not None:
......@@ -282,7 +322,7 @@ def _layernorm_mlp_fwd_rule(
)
dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes)
if use_bias_1:
if use_bias_1 and tex.gemm_uses_jax_dot():
bias_1_shape = bias_1.shape
bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)
......@@ -290,21 +330,28 @@ def _layernorm_mlp_fwd_rule(
dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)
# (batch..., hidden_in) -> (batch..., hidden)
casted_act_out = tex.act_lu(dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x)
casted_act_out = tex.act_lu(
dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, noop_scaled_tensor=True
)
casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)
casted_kernel_2 = tex.quantize(kernel_2, quantizer=ffn2_quantizer_set.kernel)
casted_kernel_2 = tex.quantize(
kernel_2, quantizer=ffn2_quantizer_set.kernel, noop_scaled_tensor=True
)
# NN GEMM
# (batch..., hidden_in) x (hidden_out, hidden_in)
dot_2_output = tex.gemm(
casted_act_out.get_rowwise_tensor(),
casted_kernel_2.get_colwise_tensor(),
(x_contracting_dims, k_contracting_dims),
casted_act_out.get_tensor(TensorUsage.LHS),
casted_kernel_2.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
)
if use_bias_2:
if use_bias_2 and tex.gemm_uses_jax_dot():
bias_2_shape = bias_2.shape
bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)
......@@ -317,11 +364,11 @@ def _layernorm_mlp_fwd_rule(
rsigma,
gamma,
beta,
casted_ln_out.get_colwise_tensor(),
casted_kernel_1.get_rowwise_tensor(),
casted_ln_out.get_tensor(TensorUsage.LHS_TRANS),
casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS),
dot_1_output,
casted_act_out.get_colwise_tensor(),
casted_kernel_2.get_rowwise_tensor(),
casted_act_out.get_tensor(TensorUsage.LHS_TRANS),
casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS),
x_contracting_dims,
k_contracting_dims,
kernel_1.shape,
......@@ -329,6 +376,7 @@ def _layernorm_mlp_fwd_rule(
use_bias_1,
use_bias_2,
quantizer_sets,
x_bdim,
)
return dot_2_output, ctx
......@@ -346,6 +394,7 @@ def _layernorm_mlp_bwd_rule(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
batch_first,
ctx,
grad,
):
......@@ -362,18 +411,18 @@ def _layernorm_mlp_bwd_rule(
Returns:
Tuple of gradients for all input parameters
"""
del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name
del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, batch_first
(
x,
mu,
rsigma,
gamma,
beta,
colwise_casted_ln_out,
rowwise_casted_kernel_1,
casted_ln_out,
casted_kernel_1,
dot_1_output,
colwise_casted_act_out,
rowwise_casted_kernel_2,
casted_act_out,
casted_kernel_2,
x_contracting_dims_in_fwd,
k_contracting_dims_in_fwd,
kernel_1_shape,
......@@ -381,6 +430,7 @@ def _layernorm_mlp_bwd_rule(
use_bias_1,
use_bias_2,
quantizer_sets,
x_bdim,
) = ctx
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
......@@ -389,7 +439,7 @@ def _layernorm_mlp_bwd_rule(
grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
casted_grad, dbias_2 = tex.quantize_dbias(
grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad
grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, noop_scaled_tensor=True
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
......@@ -404,9 +454,10 @@ def _layernorm_mlp_bwd_rule(
# NT GEMM
# (batch..., hidden_out) x (hidden_in, hidden_out)
dgrad_2 = tex.gemm(
casted_grad.get_rowwise_tensor(),
rowwise_casted_kernel_2,
(g_contracting_dims_2, k_contracting_dims_2),
casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel_2,
contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
batched_dims=((x_bdim,), ()),
)
dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
......@@ -418,9 +469,10 @@ def _layernorm_mlp_bwd_rule(
# TN GEMM
# (hidden, batch...,) x (hidden, batch...)
wgrad_2 = tex.gemm(
colwise_casted_act_out,
casted_grad.get_colwise_tensor(),
(x_contracting_dims, g_contracting_dims),
casted_act_out,
casted_grad.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, g_contracting_dims),
batched_dims=((x_bdim,), (x_bdim,)),
)
wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)
......@@ -430,10 +482,11 @@ def _layernorm_mlp_bwd_rule(
activation_type=activation_type,
is_dbias=use_bias_1,
quantizer=ffn2_quantizer_set.dgrad,
noop_scaled_tensor=True,
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
dact_out_ndim = casted_dact_out.get_rowwise_tensor().data.ndim
dact_out_ndim = casted_dact_out.get_tensor(TensorUsage.LHS).data.ndim
g_contracting_dims_1 = tuple(
range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim)
)
......@@ -444,9 +497,10 @@ def _layernorm_mlp_bwd_rule(
# NT GEMM
dgrad_1 = tex.gemm(
casted_dact_out.get_rowwise_tensor(),
rowwise_casted_kernel_1,
(g_contracting_dims_1, k_contracting_dims_1),
casted_dact_out.get_tensor(TensorUsage.LHS),
casted_kernel_1,
contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
batched_dims=((x_bdim,), ()),
)
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)
......@@ -454,9 +508,10 @@ def _layernorm_mlp_bwd_rule(
# TN GEMM
# (hidden, batch...) x (hidden, batch...)
wgrad_1 = tex.gemm(
colwise_casted_ln_out,
casted_dact_out.get_colwise_tensor(),
(x_contracting_dims, g_contracting_dims),
casted_ln_out,
casted_dact_out.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, g_contracting_dims),
batched_dims=((x_bdim,), (x_bdim,)),
)
wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)
......
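As a rough reference for what `layernorm_mlp()` computes once quantization, sharding constraints, and checkpoint names are stripped away, an unfused approximation is sketched below. It is simplified: `kernel_1` is shown as a plain 2D matrix and `activation_type=("gelu",)` is assumed.

import jax
import jax.numpy as jnp

def layernorm_mlp_reference(x, gamma, beta, kernel_1, bias_1, kernel_2, bias_2, eps=1e-6):
    # LayerNorm over the last axis
    mu = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)
    ln = (x - mu) * jax.lax.rsqrt(var + eps) * gamma + beta
    # FFN1 (NN GEMM) + activation; the fused path also fuses the bias when not using jax.dot
    h = jax.nn.gelu(jnp.dot(ln, kernel_1) + bias_1)
    # FFN2 (NN GEMM) back to the model hidden size
    return jnp.dot(h, kernel_2) + bias_2
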
......@@ -15,3 +15,4 @@ from .dequantizer import *
from .scaling_modes import *
from .metadata import *
from .helper import *
from .device_utils import *
......@@ -36,6 +36,22 @@ class Dequantizer(ABC):
"""Dequantizing given tensor to higher precision."""
@dataclass
class NoopDequantizer(Dequantizer):
"""No-op Dequantizer Class"""
@staticmethod
def _dequantize_func(data, *args, **kwargs):
"""A no-op dequantize function that returns the data without any changes."""
del args, kwargs
return data
@staticmethod
def dequantize(scaled_tensor):
"""A no-op dequantize function that simply returns the data array in the ScaledTensor."""
return scaled_tensor.data
class TensorScaleDequantizer(Dequantizer):
"""
TensorScaling Dequantizer Class
......@@ -105,9 +121,6 @@ class BlockScaleDequantizer(Dequantizer):
scale_shape = scaling_mode.get_scale_shape(
data_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis
)
scale_inv = jax.lax.slice(
scale_inv, [0] * len(scale_shape), scale_shape
) # slice out the padding
data = data.reshape(
*data_shape[: flatten_axis - 1],
......@@ -152,6 +165,7 @@ ScalingModeToDequantizerMap = {
ScalingMode.DELAYED_TENSOR_SCALING: TensorScaleDequantizer,
ScalingMode.CURRENT_TENSOR_SCALING: TensorScaleDequantizer,
ScalingMode.MXFP8_1D_SCALING: BlockScaleDequantizer,
ScalingMode.NO_SCALING: NoopDequantizer,
}
......@@ -194,28 +208,38 @@ def _grouped_dequantize(grouped_scaled_tensor):
f"math.prod({data_shape_i}) = {math.prod(data_shape_i)} which is not equal to"
f" {data_i.size}"
)
scale_shape_i = scaling_mode.get_scale_shape(
padded_scale_shape_i = scaling_mode.get_scale_shape(
data_shape_i,
grouped_scaled_tensor.is_colwise,
is_padded=True,
flatten_axis=flatten_axis,
)
scale_shape_i_size = math.prod(scale_shape_i)
scale_inv_i = scale_inv[scale_inv_ptr : scale_inv_ptr + scale_shape_i_size]
unpadded_scale_shape_i = scaling_mode.get_scale_shape(
data_shape_i,
grouped_scaled_tensor.is_colwise,
is_padded=False,
flatten_axis=flatten_axis,
)
scale_inv_i = scale_inv[
scale_inv_ptr : scale_inv_ptr + math.prod(padded_scale_shape_i)
].reshape(padded_scale_shape_i)
scale_inv_i = jax.lax.slice(
scale_inv_i, [0] * len(unpadded_scale_shape_i), unpadded_scale_shape_i
)
dequantizer_type = ScalingModeToDequantizerMap.get(grouped_scaled_tensor.scaling_mode)
if len(data_i) == 0:
out_i = []
else:
out_i = dequantizer_type._dequantize_func(
data_i.reshape(data_shape_i),
scale_inv_i.reshape(scale_shape_i),
scale_inv_i,
grouped_scaled_tensor.dq_dtype,
scaling_mode=grouped_scaled_tensor.scaling_mode,
is_colwise=grouped_scaled_tensor.is_colwise,
flatten_axis=grouped_scaled_tensor.flatten_axis,
)
output.append(out_i)
scale_inv_ptr += scale_shape_i_size
scale_inv_ptr += math.prod(padded_scale_shape_i)
return output
......
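A small self-contained sketch of the per-group scale handling this hunk implements: each group's scales are read from the flat buffer at their padded size, reshaped, and only then sliced down to the unpadded shape. The 128x4 padding granularity and 1x32 block size below are assumptions for illustration:

import math
import jax
import jax.numpy as jnp

padded_shape = (128, 4)    # per-group allocation in the flat scale_inv buffer (assumed granularity)
unpadded_shape = (100, 2)  # one scale per 1x32 block of a (100, 64) group

flat_scales = jnp.ones((math.prod(padded_shape),), jnp.float32)  # stand-in for the group's slice
scale_i = flat_scales.reshape(padded_shape)
scale_i = jax.lax.slice(scale_i, (0, 0), unpadded_shape)         # drop padding before dequantizing
assert scale_i.shape == unpadded_shape
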
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Device utility functions for JAX quantization.
This module provides utility functions for checking device capabilities and compatibility
for quantization operations in JAX.
"""
import functools
import transformer_engine_jax
__all__ = [
"get_device_compute_capability",
"is_fp8_gemm_with_all_layouts_supported",
]
@functools.lru_cache(maxsize=None)
def get_device_compute_capability(gpu_id: int = 0) -> int:
"""
Get the compute capability of the device.
"""
return transformer_engine_jax.get_device_compute_capability(gpu_id)
@functools.lru_cache(maxsize=None)
def is_fp8_gemm_with_all_layouts_supported() -> bool:
"""Return True if using Blackwell architecture, False otherwise."""
compute_capability = get_device_compute_capability()
return 100 <= compute_capability < 120
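Usage sketch for the new helpers (the import path is an assumption; the quantize package also re-exports them via its __init__):

from transformer_engine.jax.quantize.device_utils import (  # assumed import path
    get_device_compute_capability,
    is_fp8_gemm_with_all_layouts_supported,
)

cc = get_device_compute_capability(0)
if is_fp8_gemm_with_all_layouts_supported():
    # Blackwell-class GPUs: FP8 GEMM accepts all operand layouts, so a single row-wise
    # (1x) quantized copy is enough.
    needs_colwise_copy = False
else:
    # Hopper/Ada: a transposed (column-wise) copy is still materialized for FP8 GEMMs.
    needs_colwise_copy = True
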
......@@ -9,23 +9,21 @@ in JAX, including support for different scaling modes and datatypes.
"""
from contextlib import contextmanager
from enum import Enum
from typing import Optional, Tuple, Dict, Union
from typing import Optional, Tuple, Dict, Union, Sequence
from functools import reduce
import operator
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from transformer_engine_jax import DType
from transformer_engine_jax import get_cublasLt_version
from transformer_engine_jax import (
get_cuda_version,
get_device_compute_capability,
)
from transformer_engine_jax import DType, get_cublasLt_version, get_cuda_version
from transformer_engine.common import recipe
from transformer_engine.jax.sharding import global_shard_guard, MeshResource
from .scaling_modes import ScalingMode
from .. import cpp_extensions as tex
from .device_utils import get_device_compute_capability
__all__ = [
"QuantizeConfig",
......@@ -33,6 +31,8 @@ __all__ = [
"is_fp8_available",
"update_collections",
"get_delayed_scaling",
"apply_padding_to_scale_inv",
"remove_padding_from_scale_inv",
"NVTE_FP8_COLLECTION_NAME",
]
......@@ -203,7 +203,7 @@ class QuantizeConfig:
FP8_2X_ACC_FPROP: Whether to use 2x accumulation for forward pass
FP8_2X_ACC_DGRAD: Whether to use 2x accumulation for data gradients
FP8_2X_ACC_WGRAD: Whether to use 2x accumulation for weight gradients
IF_QUANTIZE_2X: Whether 2x quantization is enabled
INFERENCE_MODE: Whether to enable optimization for inference
SCALING_MODE: Scaling mode
AMAX_HISTORY_LEN: Length of AMAX history for delayed scaling
AMAX_COMPUTE_ALGO: Algorithm for AMAX computation
......@@ -218,7 +218,7 @@ class QuantizeConfig:
FP8_2X_ACC_FPROP: bool = False
FP8_2X_ACC_DGRAD: bool = False
FP8_2X_ACC_WGRAD: bool = False
IF_QUANTIZE_2X: bool = False
INFERENCE_MODE: bool = False
SCALING_MODE: ScalingMode = ScalingMode.NO_SCALING
# DelayedScaling
......@@ -246,7 +246,6 @@ class QuantizeConfig:
cls.FP8_FORMAT = fp8_recipe.fp8_format
cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT)
cls.SCALING_MODE = _get_scaling_mode(fp8_recipe)
cls.IF_QUANTIZE_2X = True
@classmethod
def finalize(cls) -> None:
......@@ -260,7 +259,7 @@ class QuantizeConfig:
cls.FP8_2X_ACC_DGRAD = False
cls.FP8_2X_ACC_WGRAD = False
cls.SCALING_MODE = ScalingMode.NO_SCALING
cls.IF_QUANTIZE_2X = False
cls.INFERENCE_MODE = False
# DelayedScaling
cls.AMAX_HISTORY_LEN = 1024
cls.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX
......@@ -476,4 +475,115 @@ def update_collections(new: Collection, original: Collection) -> Collection:
return new_coll
def remove_padding_from_scale_inv(
scale_inv: jax.Array,
scaling_mode: ScalingMode,
data_shape: Sequence[int],
is_colwise: bool = False,
flatten_axis: int = -1,
):
"""
Slice padding out of padded inverse scale factors.
Args:
scale_inv: Inverse scale factor.
data_shape: Shape of the quantized data the inverse scale belongs to.
scaling_mode: ScalingMode representing the quantization method.
is_colwise: Whether the data was quantized column-wise.
flatten_axis: The axis along which the data can be flattened to 2D.
Returns:
Inverse scale factor without padding.
"""
# Get expected unpadded scale shape and check if inverse scale already matches
unpadded_scale_shape = scaling_mode.get_scale_shape(
data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis
)
if scaling_mode == ScalingMode.NO_SCALING or scale_inv.shape == unpadded_scale_shape:
return scale_inv
# Get the padded scale shape and make sure inverse scale matches
padded_scale_shape = scaling_mode.get_scale_shape(
data_shape,
is_colwise=is_colwise,
is_padded=True,
flatten_axis=flatten_axis,
)
assert scale_inv.shape == padded_scale_shape, (
f"Padded inverse scale factor has wrong shape, expected {padded_scale_shape} but got "
f"{scale_inv.shape} instead."
)
# Reshape scale inverse to 2D in two stages to preserve the flatten axis
padded_scale_shape_2d = (
reduce(operator.mul, padded_scale_shape[:flatten_axis]),
reduce(operator.mul, padded_scale_shape[flatten_axis:]),
)
scale_inv_2d = jnp.reshape(
jnp.reshape(scale_inv, (padded_scale_shape_2d[0], *scale_inv.shape[flatten_axis:])),
padded_scale_shape_2d,
)
# Slice reshaped 2D scale inverse using collapsed 2D unpadded_scale_shape
unpadded_scale_shape_2d = (
reduce(operator.mul, unpadded_scale_shape[:flatten_axis]),
reduce(operator.mul, unpadded_scale_shape[flatten_axis:]),
)
scale_inv_2d_unpadded = jnp.asarray(
scale_inv_2d[: unpadded_scale_shape_2d[0], : unpadded_scale_shape_2d[1]]
)
# Reshape 2D scale inverse back in two stages in order to preserve the flatten axis
scale_inv_unpadded = jnp.reshape(
jnp.reshape(
scale_inv_2d_unpadded,
(*unpadded_scale_shape[:flatten_axis], scale_inv_2d_unpadded.shape[1]),
),
unpadded_scale_shape,
)
return scale_inv_unpadded
def apply_padding_to_scale_inv(
scale_inv: jax.Array,
scaling_mode: ScalingMode,
data_shape: Sequence[int],
is_colwise: bool = False,
flatten_axis: int = -1,
):
"""
Pad the inverse scale factor with 2^-127 entries to match the padded shape required by
this scaling mode.
Args:
scale_inv: Inverse scale factor.
data_shape: Shape of the quantized data the inverse scale belongs to.
scaling_mode: ScalingMode representing the quantization method.
is_colwise: Whether the data was quantized column-wise.
flatten_axis: The axis along which the data can be flattened to 2D.
Returns:
Padded inverse scale factor.
"""
# Get the expected padded scale shape and check if inverse scale already matches
padded_scale_shape = scaling_mode.get_scale_shape(
data_shape, is_colwise=is_colwise, is_padded=True, flatten_axis=flatten_axis
)
if scaling_mode == ScalingMode.NO_SCALING or scale_inv.shape == padded_scale_shape:
return scale_inv
# Get the expected unpadded scale shape and make sure inverse scales match
unpadded_scale_shape = scaling_mode.get_scale_shape(
data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis
)
assert scale_inv.shape == unpadded_scale_shape, (
f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} but got "
f"{scale_inv.shape}."
)
# Pad the scales with the lowest representable value (2^-127) and return
pad_width = tuple((0, a - b) for a, b in zip(padded_scale_shape, unpadded_scale_shape))
return jnp.pad(scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127)
NVTE_FP8_COLLECTION_NAME = QuantizeConfig.COLLECTION_NAME
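Round-trip sketch for the two new helpers; the import path is assumed, and the (128, 4) padded shape assumes the usual MXFP8 padding granularity:

import jax.numpy as jnp
from transformer_engine.jax.quantize import (  # assumed import path
    ScalingMode,
    apply_padding_to_scale_inv,
    remove_padding_from_scale_inv,
)

data_shape = (100, 64)                               # row-wise MXFP8 data
scale_inv = jnp.ones((100, 2), jnp.float32)          # unpadded: one scale per 1x32 block
padded = apply_padding_to_scale_inv(scale_inv, ScalingMode.MXFP8_1D_SCALING, data_shape)
# padded entries are filled with 2^-127; shape becomes e.g. (128, 4)
restored = remove_padding_from_scale_inv(padded, ScalingMode.MXFP8_1D_SCALING, data_shape)
assert restored.shape == scale_inv.shape
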
......@@ -23,6 +23,7 @@ from .helper import (
QuantizeConfig,
AmaxComputeAlgo,
)
from .device_utils import is_fp8_gemm_with_all_layouts_supported
__all__ = [
"QuantizeLayout",
......@@ -607,9 +608,10 @@ class GroupedQuantizer(Quantizer):
def __post_init__(self):
if self.quantizers[0] is None:
self.quantizers = QuantizerFactory.create(
quantizers = QuantizerFactory.create(
self.n_groups, self.scaling_mode, self.q_dtype, self.q_layout
)
self.quantizers = (quantizers,) if not isinstance(quantizers, tuple) else quantizers
self.data_layout = self.quantizers[0].data_layout
def _create_grouped_tensor_from_tensor_list(
......@@ -841,8 +843,10 @@ class QuantizerFactory:
if is_2x2x:
q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE_COLWISE
else:
q_layout_x = QuantizeLayout.ROWWISE
q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE
if scaling_mode.is_1d_block_scaling():
q_layout_kernel = QuantizeLayout.COLWISE
if QuantizeConfig.INFERENCE_MODE:
q_layout_dgrad = None
if "quantize_meta_set" in kwargs:
......@@ -898,7 +902,15 @@ class QuantizerFactory:
scaling_mode = scaling_mode or QuantizeConfig.SCALING_MODE
fwd_dtype = fwd_dtype or QuantizeConfig.FWD_DTYPE
bwd_dtype = bwd_dtype or QuantizeConfig.BWD_DTYPE
is_2x2x = is_2x2x or QuantizeConfig.IF_QUANTIZE_2X
if is_2x2x is None:
if scaling_mode.is_1d_block_scaling():
is_2x2x = True
elif scaling_mode.is_tensor_scaling():
is_2x2x = not is_fp8_gemm_with_all_layouts_supported()
else: # NO_SCALING ignores is_2x2x for now
is_2x2x = False
is_inference_mode = QuantizeConfig.INFERENCE_MODE
assert not is_inference_mode, "Inference mode is not supported yet!"
q_set = []
for _ in range(n_quantizer_sets):
......@@ -911,4 +923,4 @@ class QuantizerFactory:
return q_set[0] if len(q_set) == 1 else tuple(q_set)
noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NO_SCALING)
noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NO_SCALING, is_2x2x=False)
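With this change, leaving `is_2x2x` unset lets the factory pick the layout from the scaling mode and the GPU. A hedged sketch, where the dtypes and import path are assumptions:

import jax.numpy as jnp
from transformer_engine.jax.quantize import QuantizerFactory, ScalingMode  # assumed import path

# is_2x2x defaults per mode: MXFP8 block scaling -> 2x; tensor scaling -> 2x only when the GPU
# cannot run FP8 GEMMs in all layouts; NO_SCALING -> 1x.
quantizer_set = QuantizerFactory.create_set(
    scaling_mode=ScalingMode.MXFP8_1D_SCALING,
    fwd_dtype=jnp.float8_e4m3fn,
    bwd_dtype=jnp.float8_e5m2,
)
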
......@@ -13,18 +13,52 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Tuple, Dict
from functools import reduce
from functools import reduce, lru_cache
import operator
import numpy as np
from jax.experimental.custom_partitioning import CompoundFactor
from jax.experimental.custom_partitioning import BATCHING
from jax.tree_util import register_pytree_node_class
import jax.numpy as jnp
from transformer_engine_jax import JAXX_Scaling_Mode
from transformer_engine_jax import JAXX_Scaling_Mode, QuantizeLayout
from .device_utils import is_fp8_gemm_with_all_layouts_supported
__all__ = ["QuantizeShardyRules", "ScalingMode"]
__all__ = [
"QuantizeShardyRules",
"ScalingMode",
"TensorUsage",
]
class TensorUsage(Enum):
"""Enum indicating tensor usage in GEMM operations.
Given a GEMM operation: C = A * B in which A and B can be in the normal or transposed form.
The tensor usage can be:
- LHS: A is in the normal form
- LHS_TRANS: A is in the transposed form
- RHS: B is in the normal form
- RHS_TRANS: B is in the transposed form
The tensor usage is used in the ScaledTensor.get_tensor() method.
"""
# LHS: Left-hand side, RHS: Right-hand side
# LHS_TRANS: Left-hand side transposed, RHS_TRANS: Right-hand side transposed
LHS = 0
LHS_TRANS = 1
RHS = 2
RHS_TRANS = 3
def __eq__(self, other):
if not isinstance(other, TensorUsage):
return False
return self.value == other.value
def __hash__(self):
return hash(self.value)
def DIVUP(a, b):
......@@ -104,6 +138,18 @@ class ScalingModeMetadataImpl(ABC):
The shape for scale tensors
"""
@lru_cache(maxsize=4)
@abstractmethod
def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
"""Get the quantize layout for the tensor usage.
Args:
usage: The usage of the tensor
Returns:
The quantize layout for the tensor usage
"""
@abstractmethod
def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis
......@@ -157,6 +203,23 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
return (0,)
return (1,)
@lru_cache(maxsize=4)
def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
"""Get the quantize layout for the tensor usage.
Args:
usage: The usage of the tensor
Returns:
The quantize layout for the tensor usage
"""
if is_fp8_gemm_with_all_layouts_supported():
return QuantizeLayout.ROWWISE
if usage in (TensorUsage.LHS, TensorUsage.RHS_TRANS):
return QuantizeLayout.ROWWISE
return QuantizeLayout.COLWISE
def get_grouped_scale_shape(
self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
) -> Tuple[int]:
......@@ -189,8 +252,9 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
The Shardy rules for the scaling mode
"""
del flatten_axis
input_spec = tuple(f"x{i}" for i in range(input_rank))
return QuantizeShardyRules(input_spec, (unique_var,), (unique_var,), {})
input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank))
scale_var = BATCHING + unique_var + "_scale_inv"
return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {})
class DelayedScalingModeMetadataImpl(CurrentScalingModeMetadataImpl):
......@@ -321,6 +385,27 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
return (*first_dim_scale_shape, *last_dim_scale_shape)
@lru_cache(maxsize=4)
def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
"""Get the quantize layout for the tensor usage.
Args:
usage: The usage of the tensor
Returns:
The quantize layout for the tensor usage
"""
# If we need to support 1x1x for inference in the future
# if QuantizeConfig.INFERENCE_MODE:
# assert usage not in (TensorUsage.LHS_TRANS, TensorUsage.RHS_TRANS), (f"Invalid usage {usage} as we are in MXFP8_1D_SCALING 1x1x (FWD only) mode so no transposed usage is needed!")
# if usage == TensorUsage.LHS:
# return QuantizeLayout.ROWWISE
# return QuantizeLayout.COLWISE
if usage in (TensorUsage.LHS, TensorUsage.RHS_TRANS):
return QuantizeLayout.ROWWISE
return QuantizeLayout.COLWISE
def get_grouped_scale_shape(
self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
) -> Tuple[int]:
......@@ -404,31 +489,41 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
Returns:
The Shardy rules for the scaling mode
"""
input_spec = [f"x{i}" for i in range(input_rank)]
# We have to use two different factors in the two CompoundFactors because of Shardy
# verifier requirements, even though they are the same.
rowwise_var = unique_var
colwise_var = f"{unique_var}_"
input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise")
input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise")
# The rowwise and colwise scale tensors should be sharded the same way as the input.
# However, we need to adjust the dimensions where the block scaling factor applies.
rowwise = input_spec.copy()
rowwise[-1] = rowwise_var
colwise = input_spec.copy()
colwise[flatten_axis - 1] = colwise_var
# This implementation needs to be updated for different block dims.
assert self._block_dims == (1, 32)
del flatten_axis
input_spec = [f"{unique_var}{i}" for i in range(input_rank)]
rowwise = [f"{unique_var}scale_inv_rowwise{i}" for i in range(input_rank)]
colwise = [f"{unique_var}scale_inv_colwise{i}" for i in range(input_rank)]
# NOTE (Alp): Padding the scales breaks the size relationship in CompoundFactors.
# Unfortunately, because Shardy rules are applied to the inner primitive, the
# only way to preserve the relationship is to lower unpadded scales to the
# underlying custom call and pad them in C++. Until that's implemented, the
# Shardy rules for block scales have to be completely disconnected from the
# Shardy rules for the tensor they belong to.
# # We have to use two different factors in the two CompoundFactors because of Shardy
# # verifier requirements, even though they are the same.
# rowwise_var = unique_var
# colwise_var = f"{unique_var}_"
# input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise")
# input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise")
# # The rowwise and colwise scale tensors should be sharded the same way as the input.
# # However, we need to adjust the dimensions where the block scaling factor applies.
# rowwise = input_spec.copy()
# rowwise[-1] = rowwise_var
# colwise = input_spec.copy()
# colwise[flatten_axis - 1] = colwise_var
# # This implementation needs to be updated for different block dims.
# assert self._block_dims == (1, 32)
return QuantizeShardyRules(
tuple(input_spec),
tuple(rowwise),
tuple(colwise),
{"block_size_rowwise": 32, "block_size_colwise": 32},
{}, # {"block_size_rowwise": 32, "block_size_colwise": 32},
)
......@@ -506,6 +601,17 @@ class ScalingMode(Enum):
"""
return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded, flatten_axis)
def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
"""Get the quantize layout for the tensor usage.
Args:
usage: The usage of the tensor
Returns:
The quantize layout for the tensor usage
"""
return self._get_impl().get_quantize_layout(usage)
def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis=-1
) -> Tuple[Tuple[str]]:
......
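The practical effect of `TensorUsage` plus `get_quantize_layout()`: callers ask for a usage and the scaling mode decides which quantized copy backs it. A short sketch follows; the `transformer_engine.jax.quantize` import path is assumed:

from transformer_engine_jax import QuantizeLayout
from transformer_engine.jax.quantize import ScalingMode, TensorUsage  # assumed import path

# MXFP8 block scaling: LHS and RHS_TRANS map to the row-wise copy, RHS and LHS_TRANS to the
# column-wise one, regardless of hardware.
assert ScalingMode.MXFP8_1D_SCALING.get_quantize_layout(TensorUsage.LHS) == QuantizeLayout.ROWWISE
assert ScalingMode.MXFP8_1D_SCALING.get_quantize_layout(TensorUsage.RHS) == QuantizeLayout.COLWISE

# Tensor scaling: on GPUs where FP8 GEMM supports all layouts (Blackwell), every usage maps to
# ROWWISE, so no transposed copy is kept; otherwise RHS and LHS_TRANS fall back to COLWISE.
lhs_layout = ScalingMode.CURRENT_TENSOR_SCALING.get_quantize_layout(TensorUsage.LHS)
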
......@@ -17,13 +17,14 @@ from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeLayout
from .scaling_modes import ScalingMode
from .scaling_modes import ScalingMode, TensorUsage
from .dequantizer import ScalingModeToDequantizerMap
from ..sharding import (
with_sharding_constraint_by_logical_axes as original_with_sharding_constraint_by_logical_axes,
)
__all__ = [
"TensorUsage",
"ScaledTensor",
"ScaledTensor1x",
"ScaledTensor2x",
......@@ -55,6 +56,11 @@ class ScaledTensor(ABC):
"""
return cls(*children, *aux_data)
@property
@abstractmethod
def ndim(self):
"""Number of dimensions of the underlying quantized array."""
@abstractmethod
def dequantize(self):
"""Dequantizes the tensor back to its original precision.
......@@ -64,25 +70,15 @@ class ScaledTensor(ABC):
"""
@abstractmethod
def get_rowwise_tensor(self):
"""Returns the row-wise component of the tensor.
def get_tensor(self, usage: TensorUsage):
"""Returns the appropriate tensor based on the tensor usage and the scaling mode.
If the tensor usage is not valid for the scaling mode, an error is raised.
Returns:
The row-wise tensor component
Raises:
ValueError: If called on a tensor that doesn't support row-wise access
"""
@abstractmethod
def get_colwise_tensor(self):
"""Returns the column-wise component of the tensor.
Args:
usage: The usage of the tensor
Returns:
The column-wise tensor component
Raises:
ValueError: If called on a tensor that doesn't support column-wise access
The tensor based on the usage
"""
@abstractmethod
......@@ -136,24 +132,18 @@ class ScaledTensor1x(ScaledTensor):
0 < self.flatten_axis < len(self.data.shape)
), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}"
expected_scale_shape = self.scaling_mode.get_scale_shape(
self.data.shape, self.is_colwise, is_padded=True, flatten_axis=self.flatten_axis
)
expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape(
self.data.shape, self.is_colwise, is_padded=False, flatten_axis=self.flatten_axis
)
if self.scale_inv.shape != expected_scale_shape:
assert self.scale_inv.shape == expected_unpadded_scale_shape, (
f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded"
f" scale_inv or {expected_unpadded_scale_shape} for unpadded scale_inv, got"
f" {self.scale_inv.shape}"
)
pad_width = tuple(
(0, a - b) for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape)
if self.scaling_mode == ScalingMode.NO_SCALING:
self.scale_inv = jnp.empty((0,), dtype=jnp.float32)
else:
unpadded_scale_shape = self.scaling_mode.get_scale_shape(
self.data.shape,
is_colwise=self.is_colwise,
is_padded=False,
flatten_axis=self.flatten_axis,
)
# This actually pad scale_inv with nan, should we pad it with 127 directly instead?
self.scale_inv = jnp.pad(
self.scale_inv, pad_width=pad_width, mode="constant", constant_values=0
assert self.scale_inv.shape == unpadded_scale_shape, (
"Unpadded inverse scale factor has wrong shape, expected"
f" {unpadded_scale_shape} but got {self.scale_inv.shape}."
)
def tree_flatten(self):
......@@ -173,6 +163,10 @@ class ScaledTensor1x(ScaledTensor):
)
return (children, aux_data)
@property
def ndim(self):
return self.data.ndim
def dequantize(self):
"""Dequantizes the tensor using the stored dequantization function.
......@@ -181,33 +175,19 @@ class ScaledTensor1x(ScaledTensor):
"""
return self._dq_func(self)
def get_rowwise_tensor(self):
"""Returns the tensor if it's row-wise quantized.
Returns:
The row-wise tensor
Raises:
ValueError: If called on a column-wise quantized tensor
"""
if not self.is_colwise:
return self
raise ValueError("Calling get_rowwise_tensor() from a colwise ScaledTensor1x!")
def get_tensor(self, usage: TensorUsage):
"""Returns the tensor based on the tensor usage."""
q_layout = self.scaling_mode.get_quantize_layout(usage)
colwise_usage_valid = q_layout == QuantizeLayout.COLWISE and self.is_colwise
rowwise_usage_valid = q_layout == QuantizeLayout.ROWWISE and not self.is_colwise
def get_colwise_tensor(self):
"""Returns the tensor if it's column-wise quantized.
Returns:
The column-wise tensor
Raises:
ValueError: If called on a row-wise quantized tensor
"""
if self.is_colwise:
if colwise_usage_valid or rowwise_usage_valid:
return self
raise ValueError("Calling get_colwise_tensor() from a rowwise ScaledTensor1x!")
raise ValueError(
f"Calling get_tensor() with usage {usage} is not valid for this tensor as"
f" self.is_colwise={self.is_colwise}!"
)
def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
"""Applies sharding constraints to a tensor based on logical axis names.
......@@ -370,6 +350,11 @@ class ScaledTensor2x(ScaledTensor):
aux_data = ()
return (children, aux_data)
@property
def ndim(self):
"""Number of dimensions of the underlying row-wise tensor."""
return self.rowwise_tensor.ndim
def dequantize(self):
"""Dequantizes the tensor using the row-wise component's dequantization.
......@@ -378,22 +363,22 @@ class ScaledTensor2x(ScaledTensor):
"""
return self.rowwise_tensor.dequantize()
def get_rowwise_tensor(self):
"""Returns the row-wise quantized component.
def get_tensor(self, usage: TensorUsage):
"""Returns the tensor based on the tensor usage."""
q_layout_rowwise = self.rowwise_tensor.scaling_mode.get_quantize_layout(usage)
q_layout_colwise = self.colwise_tensor.scaling_mode.get_quantize_layout(usage)
Returns:
The row-wise tensor component
"""
if q_layout_rowwise == QuantizeLayout.ROWWISE:
return self.rowwise_tensor
def get_colwise_tensor(self):
"""Returns the column-wise quantized component.
Returns:
The column-wise tensor component
"""
if q_layout_colwise == QuantizeLayout.COLWISE:
return self.colwise_tensor
raise ValueError(
f"Calling get_tensor() with usage {usage} is not valid for this tensor as"
f" q_layout_rowwise={q_layout_rowwise} and q_layout_colwise={q_layout_colwise}!"
)
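Similarly, for the 2x container the same call now selects the row-wise or column-wise component from the usage, so call sites no longer pick an accessor themselves. The names below (tensor_2x, the TensorUsage members) are illustrative assumptions.

# `tensor_2x` is assumed to be a ScaledTensor2x holding both layouts.
x_fwd = tensor_2x.get_tensor(TensorUsage.LHS)         # e.g. resolves to rowwise_tensor
x_bwd = tensor_2x.get_tensor(TensorUsage.LHS_TRANS)   # e.g. resolves to colwise_tensor
# A usage whose layout is provided by neither component raises ValueError.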
def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
"""Applies sharding constraints to a tensor based on logical axis names.
......
......@@ -14,6 +14,7 @@ from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional
import warnings
from jax.interpreters import pxla
import jax
import jax.numpy as jnp
......@@ -117,7 +118,9 @@ def with_sharding_constraint_by_logical_axes(
x: jnp.array, logical_axis_names: Optional[tuple | list]
):
"""
A wrapper function to jax.lax.with_sharding_constraint to accept logical axes.
A wrapper around flax.linen.with_logical_constraint that accepts logical axis names.
DEPRECATED USE CASE: If no Flax logical axis rules are active, this function falls back to jax.lax.with_sharding_constraint using TE's hardcoded logical axis rule table (e.g. BATCH_AXES). This fallback will be removed in a future release.
If logical_axis_names is None, no sharding constraint is applied.
......@@ -133,6 +136,28 @@ def with_sharding_constraint_by_logical_axes(
if not logical_axis_names:
return x
try:
# Check if Flax logical axis rules are available, if so use them
import flax
flax_rules = flax.linen.get_logical_axis_rules()
if len(flax_rules) > 0:
return flax.linen.with_logical_constraint(
x, logical_axis_names, fallback=flax.linen.spmd.RulesFallback.NO_CONSTRAINT
)
except ImportError:
pass
warnings.warn(
"TransformerEngine logical axes, such as BATCH_AXES, SEQLEN_AXES, etc. are deprecated and"
" will be removed in a future version. Please use Flax logical axes with a"
" flax.linen.logical_axis_rules context and optionally use"
" transformer_engine.jax.flax.extend_logical_axis_rules to add BATCH_AXES, etc. to your"
" rules.",
DeprecationWarning,
)
# If no logical axis rules are available from Flax, fall back to TE's hardcoded logical axis rule table
assert len(x.shape) == len(logical_axis_names)
pspec = generate_pspec(logical_axis_names)
return with_sharding_constraint(x, pspec)
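A minimal usage sketch of the preferred Flax path described above; the rule table, logical axis names, and toy array are illustrative assumptions, and an active JAX device mesh is assumed for the constraint to take effect.

import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((8, 16))                            # toy activation
rules = (("batch", "data"), ("hidden", None))    # hypothetical logical-to-mesh axis rules
with nn.logical_axis_rules(rules):
    y = with_sharding_constraint_by_logical_axes(x, ("batch", "hidden"))
# Outside any logical_axis_rules context the call emits a DeprecationWarning and
# falls back to TE's hardcoded axis table via generate_pspec().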
......
......@@ -53,6 +53,7 @@ from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context
from transformer_engine.pytorch import ops
from transformer_engine.pytorch import optimizers
from transformer_engine.pytorch.export import onnx_export
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
try:
......
......@@ -57,6 +57,8 @@ from transformer_engine.pytorch.attention.dot_product_attention.utils import (
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
AttentionLogging as attn_log,
)
from transformer_engine.pytorch import export
from transformer_engine.pytorch.export import is_in_onnx_export_mode
# Global vars for flash attn v2 and v3 imports
flash_attn_cuda_bwd = None
......@@ -150,7 +152,14 @@ class UnfusedDotProductAttention(torch.nn.Module):
self.attention_dropout_ctx = attention_dropout_ctx
self.layer_number = layer_number
self.scale_mask_softmax = FusedScaleMaskSoftmax(attention_mask_func)
def mask_func(x, y):
return (
export.onnx_attention_mask_func(x, y)
if is_in_onnx_export_mode()
else attention_mask_func(x, y)
)
self.scale_mask_softmax = FusedScaleMaskSoftmax(mask_func)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
......
......@@ -2559,8 +2559,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if ctx.enable_mla:
# [cp, b, 2, sk//2, np, hn] or [cp, 2, sk//2, b, np, hn]
dk_fp8 = dkv_fp8[: ctx.k_numel].view(cp_size, *ctx.k_shape)
dv_fp8 = dkv_fp8[ctx.k_numel :].view(cp_size, *ctx.v_shape)
dk_fp8 = dkv_fp8[:, : ctx.k_numel].view(cp_size, *ctx.k_shape)
dv_fp8 = dkv_fp8[:, ctx.k_numel :].view(cp_size, *ctx.v_shape)
dk = ctx.dQKV_CP_quantizer.create_tensor_from_data(
dk_fp8, fake_dtype=torch.float32, internal=True
)
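For context on the index fix above (slicing along dim 1 instead of dim 0): the combined dKV buffer carries a leading context-parallel dimension, so the key/value split must happen per CP rank. A toy, shape-only sketch with made-up sizes:

import torch

cp_size, k_numel, v_numel = 2, 24, 24
k_shape = v_shape = (1, 2, 2, 3, 2)                  # toy [b, 2, sk//2, np, hn]
dkv_fp8 = torch.empty(cp_size, k_numel + v_numel)

dk = dkv_fp8[:, :k_numel].view(cp_size, *k_shape)    # split per rank along dim 1
dv = dkv_fp8[:, k_numel:].view(cp_size, *v_shape)
# Slicing dim 0 (dkv_fp8[:k_numel]) would keep all k_numel + v_numel elements on
# every rank, and the subsequent .view() would fail.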
......@@ -2586,8 +2586,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dq = dq.view(dq.shape[0], -1, *dq.shape[-2:])
if ctx.enable_mla:
# [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
dk = dk.view(*dk.shape[0], -1, *dk.shape[-2:])
dv = dv.view(*dv.shape[0], -1, *dv.shape[-2:])
dk = dk.view(dk.shape[0], -1, *dk.shape[-2:])
dv = dv.view(dv.shape[0], -1, *dv.shape[-2:])
else:
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:])
......
......@@ -17,6 +17,7 @@ from transformer_engine.pytorch.utils import get_cudnn_version
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.constants import (
AttnMaskTypes,
AttnTypes,
......@@ -963,6 +964,13 @@ class DotProductAttention(TransformerEngineBaseModule):
inference_params=inference_params,
)
global _attention_backends
if is_in_onnx_export_mode():
# We do not want to call get_attention_backend() in ONNX mode
# and we want to avoid using any global variables like _attention_backends.
use_flash_attention = False
use_fused_attention = False
use_unfused_attention = True
else:
if (
_attention_backends["attention_params"] is None
or attention_params != _attention_backends["attention_params"]
......
......@@ -8,6 +8,7 @@ from typing import Callable, Tuple, Union, Optional
import torch
from torch import nn
import transformer_engine_torch as tex
from transformer_engine.pytorch.export import is_in_onnx_export_mode
THREADS_PER_WARP = 32
......@@ -19,12 +20,18 @@ _default_causal_mask = {}
def _get_default_causal_mask(mask_type: str, sq: int, sk: int) -> torch.Tensor:
"""Return the causal upper triangular mask for softmax input"""
matrix_identifiers = (mask_type, sq, sk)
if matrix_identifiers not in _default_causal_mask:
def _get_mask():
diagonal_offset = sk - sq + 1 if "bottom_right" in mask_type else 1
_default_causal_mask[matrix_identifiers] = torch.triu(
return torch.triu(
torch.ones(sq, sk, dtype=torch.bool, device="cuda"), diagonal=diagonal_offset
)
if is_in_onnx_export_mode():
return _get_mask()
matrix_identifiers = (mask_type, sq, sk)
if matrix_identifiers not in _default_causal_mask:
_default_causal_mask[matrix_identifiers] = _get_mask()
return _default_causal_mask[matrix_identifiers]
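A small worked example of the masks this helper builds for the two causal alignments; the toy shapes are assumptions and the CPU device is used only for illustration (the real helper allocates on "cuda").

import torch

sq, sk = 2, 4
top_left = torch.triu(torch.ones(sq, sk, dtype=torch.bool), diagonal=1)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True]])
bottom_right = torch.triu(torch.ones(sq, sk, dtype=torch.bool), diagonal=sk - sq + 1)
# tensor([[False, False, False,  True],
#         [False, False, False, False]])
# In ONNX export mode the mask is rebuilt on each call, so the module-level
# _default_causal_mask cache is never touched during export.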
......@@ -169,7 +176,11 @@ class FusedScaleMaskSoftmax(nn.Module):
self.attn_mask_type = attn_mask_type
assert scale is None or self.softmax_in_fp32, "softmax should be in fp32 when scaled"
if is_in_onnx_export_mode():
return self.forward_torch_softmax(inp, mask, scale)
# Keep this as a separate `if` rather than merging it with the one above:
# we want to avoid calling is_kernel_available() at all in ONNX export mode.
if self.is_kernel_available(mask, *inp.size()):
return self.forward_fused_softmax(inp, mask, scale)
return self.forward_torch_softmax(inp, mask, scale)
......@@ -245,15 +256,15 @@ class FusedScaleMaskSoftmax(nn.Module):
if self.attn_mask_type in ["causal", "causal_bottom_right"]:
seq_len_q, seq_len_k = inp.size(2), inp.size(3)
causal_mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k)
if mask is None:
mask = causal_mask
else:
mask = torch.logical_or(mask, causal_mask)
mask_output = inp
if mask is not None and self.attn_mask_type != "no_mask":
mask_output = self.mask_func(inp, mask)
probs = torch.nn.Softmax(dim=-1)(mask_output)
probs = torch.nn.functional.softmax(mask_output, dim=-1)
if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
......
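A toy illustration of the mask merge in the hunk above, assuming TE's convention that True marks positions to be masked out; the masked_fill here is a simplified stand-in for the module's mask_func, and all shapes are made up.

import torch

sq, sk = 3, 3
causal = torch.triu(torch.ones(sq, sk, dtype=torch.bool), diagonal=1)
padding = torch.tensor([False, False, True]).expand(sq, sk)   # last key is padding
combined = torch.logical_or(padding, causal)                  # masked if either mask says so
scores = torch.randn(sq, sk).masked_fill(combined, float("-inf"))
probs = torch.nn.functional.softmax(scores, dim=-1)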
......@@ -44,6 +44,7 @@ from transformer_engine.pytorch.utils import (
get_device_compute_capability,
get_cudnn_version,
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.jit import jit_fuser
......@@ -105,7 +106,7 @@ class FlashAttentionUtils:
version = PkgVersion("0")
version_required = PkgVersion("2.1.1")
version_required_blackwell = PkgVersion("2.7.3")
max_version = PkgVersion("2.7.4.post1")
max_version = PkgVersion("2.8.1")
v2_plus = False
v2_1_plus = False
v2_3_plus = False
......@@ -437,8 +438,8 @@ def get_attention_backend(
# | FP8 | non-paged/paged | sm90 | thd | >= 1
# Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1
if inference_params is not None:
if device_compute_capability == (8, 9) and cudnn_version < (9, 11, 0):
logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN < 9.11")
if device_compute_capability == (8, 9) and cudnn_version < (9, 12, 0):
logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN < 9.12")
use_fused_attention = False
if context_parallel:
logger.debug("Disabling all backends for KV caching with context parallelism")
......@@ -624,6 +625,12 @@ def get_attention_backend(
" bias for THD format"
)
use_fused_attention = False
elif fp8 and head_dim_qk != head_dim_v:
logger.debug(
"Disabling FusedAttention as it does not support context parallelism with FP8"
" MLA attention"
)
use_fused_attention = False
# Filter: Attention mask
# attn_mask_type | attention_mask | supported backends
......@@ -1150,9 +1157,7 @@ def get_full_mask(
swa_right = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + (
actual_seqlens_kv - actual_seqlens_q + window_size[1]
).view(batch_size, 1, 1, 1)
swa_mask = torch.logical_not(
torch.where(swa_left <= 0, 1, 0) - torch.where(swa_right < 0, 1, 0)
)
swa_mask = torch.logical_not((swa_left <= 0) & ~(swa_right < 0))
if attention_mask is not None:
attention_mask = torch.logical_or(swa_mask, attention_mask)
else:
......@@ -1343,14 +1348,22 @@ def get_full_cu_seqlens(
"""
global _cu_seqlens_cache
if (batch_size, max_seqlen) not in _cu_seqlens_cache:
_cu_seqlens_cache[(batch_size, max_seqlen)] = torch.arange(
def _get_cu_seqlens(batch_size, max_seqlen, device):
return torch.arange(
0,
(batch_size + 1) * max_seqlen,
step=max_seqlen,
dtype=torch.int32,
device=device,
)
if is_in_onnx_export_mode():
return _get_cu_seqlens(batch_size, max_seqlen, device)
if (batch_size, max_seqlen) not in _cu_seqlens_cache:
_cu_seqlens_cache[(batch_size, max_seqlen)] = _get_cu_seqlens(
batch_size, max_seqlen, device
)
return _cu_seqlens_cache[(batch_size, max_seqlen)]
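A quick worked example of the values this helper produces; batch size and sequence length are purely illustrative.

import torch

batch_size, max_seqlen = 4, 128
cu_seqlens = torch.arange(
    0, (batch_size + 1) * max_seqlen, step=max_seqlen, dtype=torch.int32
)
# tensor([  0, 128, 256, 384, 512], dtype=torch.int32)
# In ONNX export mode the tensor is recomputed on every call, so the exported
# graph does not capture the module-level _cu_seqlens_cache.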
......@@ -1626,6 +1639,11 @@ def get_qkv_layout(
def run_iteratively(q, k, v):
# check data pointers
if is_in_onnx_export_mode():
check_ptrs_qkv = False
check_ptrs_qk = False
check_ptrs_kv = False
else:
data_ptr = q.untyped_storage().data_ptr()
check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
check_ptrs_qk = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k])
......@@ -1718,7 +1736,10 @@ def get_qkv_layout(
return qkv_layout
if not is_in_onnx_export_mode():
qkv_layout = run_iteratively(q, k, v)
else:
qkv_layout = "not_supported"
if qkv_layout == "not_supported":
# force q,k,v to be contiguous and run get_layout again
q, k, v = [x.contiguous() for x in [q, k, v]]
......