Commit 44740c6c authored by yuguo's avatar yuguo
Browse files

Merge commit '7a9a0825' of...

Merge commit '7a9a0825' of https://github.com/NVIDIA/TransformerEngine
parents 8113d9e0 7a9a0825
...@@ -261,6 +261,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty ...@@ -261,6 +261,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
bool is_delayed_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING; bool is_delayed_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING;
bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING; scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
bool const is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
size_t input_dtype_bytes = te_dtype_bytes(in_dtype); size_t input_dtype_bytes = te_dtype_bytes(in_dtype);
size_t output_dtype_bytes = te_dtype_bytes(out_dtype); size_t output_dtype_bytes = te_dtype_bytes(out_dtype);
...@@ -314,6 +315,8 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty ...@@ -314,6 +315,8 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
size_t colwise_sinv_size = 0; size_t colwise_sinv_size = 0;
size_t non_group_m = flatten_axis > 1 ? product(input_dims, 1, flatten_axis) : 1; size_t non_group_m = flatten_axis > 1 ? product(input_dims, 1, flatten_axis) : 1;
size_t num_non_empty_groups = 0; size_t num_non_empty_groups = 0;
size_t total_rowwise_sinv_size = 0;
size_t total_colwise_sinv_size = 0;
for (size_t i = 0; i < num_groups; i++) { for (size_t i = 0; i < num_groups; i++) {
size_t m_i = dim_list_host[i] * non_group_m; size_t m_i = dim_list_host[i] * non_group_m;
// Skip for zero-size input + shiff the scale ptr // Skip for zero-size input + shiff the scale ptr
...@@ -379,6 +382,12 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty ...@@ -379,6 +382,12 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty
sinv_ptr += sinv_size * sinv_dtype_bytes; sinv_ptr += sinv_size * sinv_dtype_bytes;
colwise_sinv_ptr += colwise_sinv_size * colwise_sinv_dtype_bytes; colwise_sinv_ptr += colwise_sinv_size * colwise_sinv_dtype_bytes;
amax_ptr += amax_dtype_bytes; amax_ptr += amax_dtype_bytes;
total_rowwise_sinv_size += sinv_size;
total_colwise_sinv_size += colwise_sinv_size;
}
if (is_mxfp8_scaling) {
nvte_memset(scale_invs->untyped_data(), 0, total_rowwise_sinv_size, stream);
nvte_memset(colwise_scale_invs->untyped_data(), 0, total_colwise_sinv_size, stream);
} }
QuantizationConfigWrapper quant_config; QuantizationConfigWrapper quant_config;
......
...@@ -8,7 +8,7 @@ architectures, including support for quantization and automatic differentiation. ...@@ -8,7 +8,7 @@ architectures, including support for quantization and automatic differentiation.
It implements matrix multiplication with optional bias addition and supports It implements matrix multiplication with optional bias addition and supports
customizable contracting dimensions for flexible tensor operations. customizable contracting dimensions for flexible tensor operations.
""" """
import warnings
from typing import Tuple, Sequence from typing import Tuple, Sequence
from functools import partial from functools import partial
import jax import jax
...@@ -19,9 +19,20 @@ from .quantize import ( ...@@ -19,9 +19,20 @@ from .quantize import (
QuantizerSet, QuantizerSet,
noop_quantizer_set, noop_quantizer_set,
with_sharding_constraint_by_logical_axes, with_sharding_constraint_by_logical_axes,
TensorUsage,
) )
DENSE_BATCH_FIRST_WARNING_ISSUED = False
def _issue_batch_first_warning(msg):
global DENSE_BATCH_FIRST_WARNING_ISSUED
if not DENSE_BATCH_FIRST_WARNING_ISSUED:
warnings.warn(msg, UserWarning)
DENSE_BATCH_FIRST_WARNING_ISSUED = True
def dense( def dense(
x: jnp.ndarray, x: jnp.ndarray,
kernel: jnp.ndarray, kernel: jnp.ndarray,
...@@ -29,6 +40,7 @@ def dense( ...@@ -29,6 +40,7 @@ def dense(
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
input_axes: Tuple[str, ...] = None, input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None,
batch_first: bool = True,
quantizer_set: QuantizerSet = noop_quantizer_set, quantizer_set: QuantizerSet = noop_quantizer_set,
): ):
"""Perform dense layer transformation with optional quantization. """Perform dense layer transformation with optional quantization.
...@@ -42,25 +54,28 @@ def dense( ...@@ -42,25 +54,28 @@ def dense(
kernel: Weight matrix for the dense layer transformation kernel: Weight matrix for the dense layer transformation
bias: Optional bias tensor to add after the transformation bias: Optional bias tensor to add after the transformation
contracting_dims: Tuple of sequences specifying which dimensions to contract contracting_dims: Tuple of sequences specifying which dimensions to contract
batch_first: Assume that X is batched in the first dimension.
quantizer_set: QuantizerSet which contains quantizers for different tensor types quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns: Returns:
Transformed output tensor Transformed output tensor
""" """
# Remove when tex.quantize() can handle quantizer=None # Remove when tex.quantize() can handle quantizer=None
if quantizer_set == noop_quantizer_set: if quantizer_set == noop_quantizer_set and tex.gemm_uses_jax_dot():
x = with_sharding_constraint_by_logical_axes(x, input_axes) x = with_sharding_constraint_by_logical_axes(x, input_axes)
output = tex.gemm(x, kernel, contracting_dims) output = tex.gemm(x, kernel, contracting_dims=contracting_dims)
if bias is not None: if bias is not None:
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape) output += jnp.reshape(bias, bias_new_shape)
else: else:
output = _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set) output = _dense(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set
)
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5)) @partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6))
def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set): def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set):
"""Internal implementation of dense layer transformation with custom VJP. """Internal implementation of dense layer transformation with custom VJP.
This function implements the core dense layer transformation logic with support This function implements the core dense layer transformation logic with support
...@@ -74,81 +89,126 @@ def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer ...@@ -74,81 +89,126 @@ def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer
input_axes: Logical axes for sharding the activation input input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: QuantizerSet which contains quantizers for different tensor types quantizer_set: QuantizerSet which contains quantizers for different tensor types
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
Returns: Returns:
Transformed output tensor Transformed output tensor
""" """
output, _ = _dense_fwd_rule( output, _ = _dense_fwd_rule(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set
) )
return output return output
def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set): def _dense_fwd_rule(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set
):
"""Forward pass rule for dense layer transformation. """Forward pass rule for dense layer transformation.
Returns: Returns:
Tuple of (output, context) for backward pass Tuple of (output, context) for backward pass
""" """
x_contracting_dims, k_contracting_dims = contracting_dims x_contracting_dims, k_contracting_dims = map(
tex.sanitize_dims, (x.ndim, kernel.ndim), contracting_dims
)
# Check supported input layout
x_is_transposed = x.ndim - 1 not in x_contracting_dims
k_is_transposed = kernel.ndim - 1 in k_contracting_dims
assert (
not x_is_transposed and not k_is_transposed
), "Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel."
# Determine X batch dimension
# - If `batch_first=True` -> (batch, leading..., contracting...)
# - Otherwise -> (leading..., batch, contracting...)
# NOTE: Always assume a single batch dimension
x_bdim = None
num_cdims = len(x_contracting_dims)
if x.ndim >= num_cdims + 2:
# Assume X is batched if it has at least +2 dimensions more than the number of contracting
# dimensions.
if not batch_first:
_issue_batch_first_warning(
"TE/JAX `dense()` layer implementation does not officially support sequence-first "
"inputs and may produce incorrect results when `batch_first=False`. Use "
"sequence-first inputs at your own discretion.",
)
x_bdim = 0 if batch_first else x.ndim - num_cdims - 1
flatten_axis_x = -len(x_contracting_dims) flatten_axis_x = -len(x_contracting_dims)
flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) flatten_axis_k = len(k_contracting_dims) - len(kernel.shape)
casted_x = tex.quantize(x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x) casted_x = tex.quantize(
x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, noop_scaled_tensor=True
)
casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)
casted_kernel = tex.quantize( casted_kernel = tex.quantize(
kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel kernel,
flatten_axis=flatten_axis_k,
quantizer=quantizer_set.kernel,
noop_scaled_tensor=True,
) )
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
# GEMM NN # GEMM NN
use_bias = bias is not None
output = tex.gemm( output = tex.gemm(
casted_x.get_rowwise_tensor(), casted_x.get_tensor(usage=TensorUsage.LHS),
casted_kernel.get_colwise_tensor(), casted_kernel.get_tensor(usage=TensorUsage.RHS),
(x_contracting_dims, k_contracting_dims), contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
) )
use_bias = bias is not None if use_bias and tex.gemm_uses_jax_dot():
if use_bias:
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape) output += jnp.reshape(bias, bias_new_shape)
ctx = ( ctx = (
casted_x.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None, casted_x.get_tensor(usage=TensorUsage.LHS_TRANS),
casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None, casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS),
x.shape, x.shape,
kernel.shape, kernel.shape,
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis_k, flatten_axis_k,
x_bdim,
) )
return output, ctx return output, ctx
def _dense_bwd_rule( def _dense_bwd_rule(
contracting_dims, input_axes, kernel_axes, ctx, grad contracting_dims, input_axes, kernel_axes, batch_first, ctx, grad
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
"""Backward pass rule for dense layer transformation. """Backward pass rule for dense layer transformation.
Returns: Returns:
Tuple of gradients with respect to inputs Tuple of gradients with respect to inputs
""" """
fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims
( (
colwise_casted_x, casted_x_lhs,
rowwise_casted_kernel, casted_kernel_rhs,
x_shape, x_shape,
kernel_shape, kernel_shape,
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis_k, flatten_axis_k,
x_bdim,
) = ctx ) = ctx
fwd_x_contracting_dims, fwd_k_contracting_dims = map(
tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims
)
casted_grad, dbias = tex.quantize_dbias( casted_grad, dbias = tex.quantize_dbias(
grad, is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad grad,
is_dbias=use_bias,
flatten_axis=flatten_axis_k,
quantizer=quantizer_set.dgrad,
noop_scaled_tensor=True,
) )
# GEMM NT # GEMM NT
...@@ -161,9 +221,10 @@ def _dense_bwd_rule( ...@@ -161,9 +221,10 @@ def _dense_bwd_rule(
dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
) )
dgrad = tex.gemm( dgrad = tex.gemm(
casted_grad.get_rowwise_tensor(), casted_grad.get_tensor(usage=TensorUsage.LHS),
rowwise_casted_kernel, casted_kernel_rhs,
(g_contracting_dim, k_contracting_dim), contracting_dims=(g_contracting_dim, k_contracting_dim),
batched_dims=((x_bdim,), ()),
) )
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
...@@ -174,7 +235,10 @@ def _dense_bwd_rule( ...@@ -174,7 +235,10 @@ def _dense_bwd_rule(
) )
wgrad = tex.gemm( wgrad = tex.gemm(
colwise_casted_x, casted_grad.get_colwise_tensor(), (x_contracting_dim, g_contracting_dim) casted_x_lhs,
casted_grad.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dim, g_contracting_dim),
batched_dims=((x_bdim,), (x_bdim,)),
) )
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
...@@ -287,7 +351,6 @@ def _grouped_dense_fwd_rule( ...@@ -287,7 +351,6 @@ def _grouped_dense_fwd_rule(
"and k_contracting_dims=(1,) for now, " "and k_contracting_dims=(1,) for now, "
f"got {x_contracting_dims=} and {k_contracting_dims=}" f"got {x_contracting_dims=} and {k_contracting_dims=}"
) )
k_contracting_dims = (0,)
casted_x = tex.grouped_quantize( casted_x = tex.grouped_quantize(
x, quantizer_set.x, group_sizes, flatten_axis=flatten_axis_x x, quantizer_set.x, group_sizes, flatten_axis=flatten_axis_x
...@@ -300,11 +363,10 @@ def _grouped_dense_fwd_rule( ...@@ -300,11 +363,10 @@ def _grouped_dense_fwd_rule(
# For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have
# rowwise_casted_x.original_shape == (M, K) # rowwise_casted_x.original_shape == (M, K)
# colwise_casted_kernel.original_shape == (G, N, K) # colwise_casted_kernel.original_shape == (G, N, K)
grouped_gemm_x = casted_x.get_rowwise_tensor() grouped_gemm_x = casted_x.get_tensor(usage=TensorUsage.LHS)
grouped_gemm_kernel = casted_kernel.get_colwise_tensor() grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS)
# TODO(Hua): Shall we give warning/error if not quantizer_set.x.is_2x2x()? ctx_x = casted_x.get_tensor(usage=TensorUsage.LHS_TRANS)
ctx_x = casted_x.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None ctx_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS)
ctx_kernel = casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None
output = tex.grouped_gemm( output = tex.grouped_gemm(
grouped_gemm_x, grouped_gemm_x,
...@@ -382,17 +444,17 @@ def _grouped_dense_bwd_rule( ...@@ -382,17 +444,17 @@ def _grouped_dense_bwd_rule(
g_contracting_dim = (1,) g_contracting_dim = (1,)
k_contracting_dim = (2,) k_contracting_dim = (2,)
dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim) dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
dgrad_grad = casted_grad.get_rowwise_tensor() dgrad_grad = casted_grad.get_tensor(usage=TensorUsage.LHS)
dgrad_kernel_T = ctx_kernel dgrad_kernel_T = ctx_kernel
# We need to use g_contracting_dim = (0,) and x_contracting_dim = (1,) to make it work # We need to use g_contracting_dim = (0,) and x_contracting_dim = (0,) to make it work
# after the extra transpose for FP8 in grouped_gemm # after the extra transpose for FP8 in grouped_gemm
# TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()? # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
g_contracting_dim = (0,) g_contracting_dim = (0,)
x_contracting_dim = (0,) x_contracting_dim = (0,)
wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim) wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
wgrad_x_T = ctx_x wgrad_x_T = ctx_x
wgrad_grad = casted_grad.get_colwise_tensor() wgrad_grad = casted_grad.get_tensor(usage=TensorUsage.RHS)
dgrad = tex.grouped_gemm( dgrad = tex.grouped_gemm(
dgrad_grad, dgrad_grad,
......
...@@ -6,7 +6,7 @@ Wrapper module for Transformer related layers with FP8 support. ...@@ -6,7 +6,7 @@ Wrapper module for Transformer related layers with FP8 support.
""" """
from functools import reduce from functools import reduce
import operator import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType
import numpy as np import numpy as np
import jax.numpy as jnp import jax.numpy as jnp
...@@ -15,12 +15,12 @@ from jax import lax ...@@ -15,12 +15,12 @@ from jax import lax
from jax import random as jax_random from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name from jax.ad_checkpoint import checkpoint_name
from ..dense import dense from ..dense import dense, _issue_batch_first_warning as _dense_warning
from ..layernorm import canonicalize_norm_type from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense from ..layernorm_dense import layernorm_dense, _issue_batch_first_warning as _ln_dense_warning
from ..layernorm_mlp import layernorm_mlp from ..layernorm_mlp import layernorm_mlp, _issue_batch_first_warning as _ln_mlp_warning
from ..activation import activation from ..activation import activation
from ..softmax import softmax, SoftmaxType from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes from ..sharding import with_sharding_constraint_by_logical_axes
...@@ -35,8 +35,8 @@ from ..sharding import get_non_contracting_logical_axes ...@@ -35,8 +35,8 @@ from ..sharding import get_non_contracting_logical_axes
PRNGKey = Any PRNGKey = Any
Shape = Tuple[int, ...] Shape = Tuple[int, ...]
DType = jnp.dtype DType = NewType("DType", jnp.dtype)
Array = jnp.ndarray Array = NewType("Array", jnp.ndarray)
PrecisionLike = Union[ PrecisionLike = Union[
None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision] None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
] ]
...@@ -441,6 +441,12 @@ class DenseGeneral(TransformerEngineBase): ...@@ -441,6 +441,12 @@ class DenseGeneral(TransformerEngineBase):
input_axes: Tuple[str, ...] = () input_axes: Tuple[str, ...] = ()
def __post_init__(self): def __post_init__(self):
if self.transpose_batch_sequence:
_dense_warning(
"TE/JAX DenseGeneral() module does not officially support sequence-first inputs "
"and may produce incorrect results when `transpose_batch_sequence=True`. Use "
"sequence-first inputs at your own discretion."
)
if self.kernel_init is None: if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling( self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype 1.0, "fan_in", "truncated_normal", dtype=self.dtype
...@@ -657,6 +663,12 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -657,6 +663,12 @@ class LayerNormDenseGeneral(TransformerEngineBase):
depth_scaling: float = None depth_scaling: float = None
def __post_init__(self): def __post_init__(self):
if self.transpose_batch_sequence:
_ln_dense_warning(
"TE/JAX LayerNormDenseGeneral() module does not officially support sequence-first "
"inputs and may produce incorrect results when `transpose_batch_sequence=True`. "
"Use sequence-first inputs at your own discretion."
)
if self.kernel_init is None: if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling( self.kernel_init = nn.initializers.variance_scaling(
1.0, 1.0,
...@@ -967,6 +979,12 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -967,6 +979,12 @@ class LayerNormMLP(TransformerEngineBase):
dot_2_input_axes: Tuple[str, ...] = None dot_2_input_axes: Tuple[str, ...] = None
def __post_init__(self): def __post_init__(self):
if self.transpose_batch_sequence:
_ln_mlp_warning(
"TE/JAX LayerNormMLP() module does not officially support sequence-first inputs "
"and may produce incorrect results when `transpose_batch_sequence=True`. Use "
"sequence-first inputs at your own discretion."
)
if self.kernel_init is None: if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling( self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype 1.0, "fan_in", "truncated_normal", dtype=self.dtype
......
...@@ -180,8 +180,9 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public- ...@@ -180,8 +180,9 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
attn_weights_without_groups_shape = (b, h * g, q, k) attn_weights_without_groups_shape = (b, h * g, q, k)
attn_weights = attn_weights.reshape(attn_weights_without_groups_shape) attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)
# (b, h, q, k): Last two axes are always replicated
attn_weights = with_sharding_constraint_by_logical_axes( attn_weights = with_sharding_constraint_by_logical_axes(
attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES) attn_weights, (BATCH_AXES, HEAD_AXES, None, None)
) )
# When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias) # When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias)
......
...@@ -9,6 +9,7 @@ architectures. It supports various normalization types, quantization, and ...@@ -9,6 +9,7 @@ architectures. It supports various normalization types, quantization, and
distributed training through sharding constraints. distributed training through sharding constraints.
""" """
import warnings
from functools import partial from functools import partial
from typing import Tuple from typing import Tuple
...@@ -21,9 +22,20 @@ from .quantize import ( ...@@ -21,9 +22,20 @@ from .quantize import (
QuantizerSet, QuantizerSet,
noop_quantizer_set, noop_quantizer_set,
with_sharding_constraint_by_logical_axes, with_sharding_constraint_by_logical_axes,
TensorUsage,
) )
LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = False
def _issue_batch_first_warning(msg):
global LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED
if not LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED:
warnings.warn(msg, UserWarning)
LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = True
def layernorm_dense( def layernorm_dense(
x: jnp.ndarray, x: jnp.ndarray,
kernel: jnp.ndarray, kernel: jnp.ndarray,
...@@ -36,6 +48,7 @@ def layernorm_dense( ...@@ -36,6 +48,7 @@ def layernorm_dense(
layernorm_input_axes: Tuple[str, ...] = None, layernorm_input_axes: Tuple[str, ...] = None,
dot_input_axes: Tuple[str, ...] = None, dot_input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None,
batch_first: bool = True,
quantizer_set: QuantizerSet = noop_quantizer_set, quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray: ) -> jnp.ndarray:
"""Apply layer normalization followed by dense layer transformation. """Apply layer normalization followed by dense layer transformation.
...@@ -56,6 +69,7 @@ def layernorm_dense( ...@@ -56,6 +69,7 @@ def layernorm_dense(
layernorm_input_axes: Logical axes for sharding the layernorm input layernorm_input_axes: Logical axes for sharding the layernorm input
dot_input_axes: Logical axes for sharding the matrix multiplication input dot_input_axes: Logical axes for sharding the matrix multiplication input
kernel_axes: Logical axes for sharding the weight matrix kernel_axes: Logical axes for sharding the weight matrix
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
quantizer_set: Set of quantizers for different tensor types quantizer_set: Set of quantizers for different tensor types
Returns: Returns:
...@@ -79,6 +93,7 @@ def layernorm_dense( ...@@ -79,6 +93,7 @@ def layernorm_dense(
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes, kernel_axes,
batch_first,
quantizer_set, quantizer_set,
) )
return output return output
...@@ -93,6 +108,7 @@ def layernorm_dense( ...@@ -93,6 +108,7 @@ def layernorm_dense(
8, 8,
9, 9,
10, 10,
11,
), ),
) )
def _layernorm_dense( def _layernorm_dense(
...@@ -107,6 +123,7 @@ def _layernorm_dense( ...@@ -107,6 +123,7 @@ def _layernorm_dense(
layernorm_input_axes: Tuple[str, ...], layernorm_input_axes: Tuple[str, ...],
dot_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...],
kernel_axes: Tuple[str, ...], kernel_axes: Tuple[str, ...],
batch_first: bool,
quantizer_set, quantizer_set,
): ):
"""Internal implementation of layernorm_dense with custom VJP. """Internal implementation of layernorm_dense with custom VJP.
...@@ -126,6 +143,7 @@ def _layernorm_dense( ...@@ -126,6 +143,7 @@ def _layernorm_dense(
epsilon: Small constant for numerical stability epsilon: Small constant for numerical stability
layernorm_input_axes: Logical axes for layernorm sharding layernorm_input_axes: Logical axes for layernorm sharding
dot_input_axes: Logical axes for matrix multiplication sharding dot_input_axes: Logical axes for matrix multiplication sharding
batch_first: Assume that X is batched in the first dimension.
quantizer_set: Set of quantizers quantizer_set: Set of quantizers
Returns: Returns:
...@@ -143,6 +161,7 @@ def _layernorm_dense( ...@@ -143,6 +161,7 @@ def _layernorm_dense(
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes, kernel_axes,
batch_first,
quantizer_set, quantizer_set,
) )
return output return output
...@@ -160,6 +179,7 @@ def _layernorm_dense_fwd_rule( ...@@ -160,6 +179,7 @@ def _layernorm_dense_fwd_rule(
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes, kernel_axes,
batch_first,
quantizer_set, quantizer_set,
): ):
"""Forward pass rule for layernorm_dense. """Forward pass rule for layernorm_dense.
...@@ -177,6 +197,17 @@ def _layernorm_dense_fwd_rule( ...@@ -177,6 +197,17 @@ def _layernorm_dense_fwd_rule(
k_contracting_dims = (0,) k_contracting_dims = (0,)
assert x.shape[-1] == kernel.shape[0] assert x.shape[-1] == kernel.shape[0]
x_bdim = None
if x.ndim > 2:
if not batch_first:
_issue_batch_first_warning(
"TE/JAX `layernorm_dense()` fused-layer implementation does not officially "
"support sequence-first inputs and may produce incorrect results when "
"`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first "
"inputs at your own discretion."
)
x_bdim = 0 if batch_first else x.ndim - 2
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes) x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
casted_ln_out, mu, rsigma = tex.normalization_fwd( casted_ln_out, mu, rsigma = tex.normalization_fwd(
...@@ -186,31 +217,37 @@ def _layernorm_dense_fwd_rule( ...@@ -186,31 +217,37 @@ def _layernorm_dense_fwd_rule(
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
norm_type, norm_type,
quantizer_set.x, quantizer=quantizer_set.x,
noop_scaled_tensor=True,
) )
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)
# Kernel in (hidden_in, hidden_out...) # Kernel in (hidden_in, hidden_out...)
flatten_axis = 1 - len(kernel.shape) flatten_axis = 1 - len(kernel.shape)
casted_kernel = tex.quantize(kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel) casted_kernel = tex.quantize(
kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, noop_scaled_tensor=True
)
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
# NN GEMM # NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out...) # (batch..., hidden_in) x (hidden_in, hidden_out...)
use_bias = bias is not None
output = tex.gemm( output = tex.gemm(
casted_ln_out.get_rowwise_tensor(), casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel.get_colwise_tensor(), casted_kernel.get_tensor(TensorUsage.RHS),
(x_contracting_dims, k_contracting_dims), contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
) )
use_bias = bias is not None if use_bias and tex.gemm_uses_jax_dot():
if use_bias:
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape) output += jnp.reshape(bias, bias_new_shape)
ctx = ( ctx = (
casted_ln_out.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None, casted_ln_out.get_tensor(TensorUsage.LHS_TRANS),
casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None, casted_kernel.get_tensor(TensorUsage.RHS_TRANS),
x.shape, x.shape,
kernel.shape, kernel.shape,
mu, mu,
...@@ -223,6 +260,7 @@ def _layernorm_dense_fwd_rule( ...@@ -223,6 +260,7 @@ def _layernorm_dense_fwd_rule(
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis, flatten_axis,
x_bdim,
) )
return output, ctx return output, ctx
...@@ -235,6 +273,7 @@ def _layernorm_dense_bwd_rule( ...@@ -235,6 +273,7 @@ def _layernorm_dense_bwd_rule(
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, # pylint: disable=unused-argument dot_input_axes, # pylint: disable=unused-argument
kernel_axes, kernel_axes,
batch_first, # pylint: disable=unused-argument
ctx, ctx,
grad, grad,
): ):
...@@ -250,8 +289,8 @@ def _layernorm_dense_bwd_rule( ...@@ -250,8 +289,8 @@ def _layernorm_dense_bwd_rule(
Tuple of gradients for all input parameters Tuple of gradients for all input parameters
""" """
( (
colwise_casted_ln_out, casted_ln_out,
rowwise_casted_kernel, casted_kernel,
x_shape, x_shape,
kernel_shape, kernel_shape,
mu, mu,
...@@ -264,10 +303,15 @@ def _layernorm_dense_bwd_rule( ...@@ -264,10 +303,15 @@ def _layernorm_dense_bwd_rule(
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis, flatten_axis,
x_bdim,
) = ctx ) = ctx
casted_grad, dbias = tex.quantize_dbias( casted_grad, dbias = tex.quantize_dbias(
grad, is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad grad,
is_dbias=use_bias,
flatten_axis=flatten_axis,
quantizer=quantizer_set.dgrad,
noop_scaled_tensor=True,
) )
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
...@@ -281,9 +325,10 @@ def _layernorm_dense_bwd_rule( ...@@ -281,9 +325,10 @@ def _layernorm_dense_bwd_rule(
# NT GEMM # NT GEMM
dgrad = tex.gemm( dgrad = tex.gemm(
casted_grad.get_rowwise_tensor(), casted_grad.get_tensor(TensorUsage.LHS),
rowwise_casted_kernel, casted_kernel,
(g_constracting_dim, k_constracting_dim), contracting_dims=(g_constracting_dim, k_constracting_dim),
batched_dims=((x_bdim,), ()),
) )
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
...@@ -294,9 +339,10 @@ def _layernorm_dense_bwd_rule( ...@@ -294,9 +339,10 @@ def _layernorm_dense_bwd_rule(
# TN GEMM # TN GEMM
wgrad = tex.gemm( wgrad = tex.gemm(
colwise_casted_ln_out, casted_ln_out,
casted_grad.get_colwise_tensor(), casted_grad.get_tensor(TensorUsage.RHS),
(x_constracting_dim, g_constracting_dim), contracting_dims=(x_constracting_dim, g_constracting_dim),
batched_dims=((x_bdim,), (x_bdim,)),
) )
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
......
...@@ -13,6 +13,7 @@ The implementation supports various normalization types, activation functions, ...@@ -13,6 +13,7 @@ The implementation supports various normalization types, activation functions,
quantization, and distributed training through sharding constraints. quantization, and distributed training through sharding constraints.
""" """
import warnings
from typing import List, Tuple, Sequence, Union, Callable from typing import List, Tuple, Sequence, Union, Callable
from functools import partial from functools import partial
...@@ -22,10 +23,25 @@ from jax.ad_checkpoint import checkpoint_name ...@@ -22,10 +23,25 @@ from jax.ad_checkpoint import checkpoint_name
from . import cpp_extensions as tex from . import cpp_extensions as tex
from .layernorm import canonicalize_norm_type from .layernorm import canonicalize_norm_type
from .quantize import with_sharding_constraint_by_logical_axes, QuantizerSet, noop_quantizer_set from .quantize import (
with_sharding_constraint_by_logical_axes,
QuantizerSet,
noop_quantizer_set,
TensorUsage,
)
from .sharding import get_non_contracting_logical_axes from .sharding import get_non_contracting_logical_axes
LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = False
def _issue_batch_first_warning(msg):
global LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED
if not LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED:
warnings.warn(msg, UserWarning)
LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = True
def layernorm_mlp( def layernorm_mlp(
x: jnp.ndarray, x: jnp.ndarray,
gamma: jnp.ndarray, gamma: jnp.ndarray,
...@@ -43,6 +59,7 @@ def layernorm_mlp( ...@@ -43,6 +59,7 @@ def layernorm_mlp(
ffn1_ckpt_name: str = "ffn1", ffn1_ckpt_name: str = "ffn1",
ffn2_ckpt_name: str = "ffn2", ffn2_ckpt_name: str = "ffn2",
activation_type: Sequence[Union[str, Callable]] = ("gelu",), activation_type: Sequence[Union[str, Callable]] = ("gelu",),
batch_first: bool = True,
quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set), quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
) -> jnp.ndarray: ) -> jnp.ndarray:
"""Apply layer normalization followed by MLP block. """Apply layer normalization followed by MLP block.
...@@ -74,6 +91,7 @@ def layernorm_mlp( ...@@ -74,6 +91,7 @@ def layernorm_mlp(
ffn1_ckpt_name: Name for checkpointing the first feed-forward network ffn1_ckpt_name: Name for checkpointing the first feed-forward network
ffn2_ckpt_name: Name for checkpointing the second feed-forward network ffn2_ckpt_name: Name for checkpointing the second feed-forward network
activation_type: Activation function(s) to apply after the first dense layer transformation activation_type: Activation function(s) to apply after the first dense layer transformation
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations
Returns: Returns:
...@@ -119,12 +137,13 @@ def layernorm_mlp( ...@@ -119,12 +137,13 @@ def layernorm_mlp(
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
batch_first,
quantizer_sets, quantizer_sets,
) )
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)) @partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18))
def _layernorm_mlp( def _layernorm_mlp(
x: jnp.ndarray, x: jnp.ndarray,
gamma: jnp.ndarray, gamma: jnp.ndarray,
...@@ -144,6 +163,7 @@ def _layernorm_mlp( ...@@ -144,6 +163,7 @@ def _layernorm_mlp(
ffn1_ckpt_name: str, ffn1_ckpt_name: str,
ffn2_ckpt_name: str, ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]], activation_type: Sequence[Union[str, Callable]],
batch_first: bool,
quantizer_sets, quantizer_sets,
): ):
"""Internal implementation of layernorm_mlp with custom VJP. """Internal implementation of layernorm_mlp with custom VJP.
...@@ -169,6 +189,7 @@ def _layernorm_mlp( ...@@ -169,6 +189,7 @@ def _layernorm_mlp(
ffn1_ckpt_name: Name for first feed-forward network checkpointing ffn1_ckpt_name: Name for first feed-forward network checkpointing
ffn2_ckpt_name: Name for second feed-forward network checkpointing ffn2_ckpt_name: Name for second feed-forward network checkpointing
activation_type: Activation function(s) activation_type: Activation function(s)
batch_first: Assume that X is batched in the first dimension.
quantizer_sets: Tuple of quantizer sets quantizer_sets: Tuple of quantizer sets
Returns: Returns:
...@@ -193,6 +214,7 @@ def _layernorm_mlp( ...@@ -193,6 +214,7 @@ def _layernorm_mlp(
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
batch_first,
quantizer_sets, quantizer_sets,
) )
return output return output
...@@ -217,6 +239,7 @@ def _layernorm_mlp_fwd_rule( ...@@ -217,6 +239,7 @@ def _layernorm_mlp_fwd_rule(
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
batch_first,
quantizer_sets, quantizer_sets,
): ):
"""Forward pass rule for layernorm_mlp. """Forward pass rule for layernorm_mlp.
...@@ -249,6 +272,17 @@ def _layernorm_mlp_fwd_rule( ...@@ -249,6 +272,17 @@ def _layernorm_mlp_fwd_rule(
assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]] assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]
x_bdim = None
if x.ndim > 2:
if not batch_first:
_issue_batch_first_warning(
"TE/JAX `layernorm_mlp()` fused-layer implementation does not officially "
"support sequence-first inputs and may produce incorrect results when "
"`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first "
"inputs at your own discretion."
)
x_bdim = 0 if batch_first else x.ndim - 2
use_bias_1 = bias_1 is not None use_bias_1 = bias_1 is not None
use_bias_2 = bias_1 is not None use_bias_2 = bias_1 is not None
...@@ -262,17 +296,23 @@ def _layernorm_mlp_fwd_rule( ...@@ -262,17 +296,23 @@ def _layernorm_mlp_fwd_rule(
epsilon, epsilon,
norm_type, norm_type,
quantizer=ffn1_quantizer_set.x, quantizer=ffn1_quantizer_set.x,
noop_scaled_tensor=True,
) )
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)
casted_kernel_1 = tex.quantize(kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel) casted_kernel_1 = tex.quantize(
kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, noop_scaled_tensor=True
)
# NN GEMM # NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out) # (batch..., hidden_in) x (hidden_in, hidden_out)
dot_1_output = tex.gemm( dot_1_output = tex.gemm(
casted_ln_out.get_rowwise_tensor(), casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel_1.get_colwise_tensor(), casted_kernel_1.get_tensor(TensorUsage.RHS),
(x_contracting_dims, k_contracting_dims), contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias_1 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
) )
if dot_1_input_axes is not None and kernel_1_axes is not None: if dot_1_input_axes is not None and kernel_1_axes is not None:
...@@ -282,7 +322,7 @@ def _layernorm_mlp_fwd_rule( ...@@ -282,7 +322,7 @@ def _layernorm_mlp_fwd_rule(
) )
dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes) dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes)
if use_bias_1: if use_bias_1 and tex.gemm_uses_jax_dot():
bias_1_shape = bias_1.shape bias_1_shape = bias_1.shape
bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
dot_1_output += jnp.reshape(bias_1, bias_1_new_shape) dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)
...@@ -290,21 +330,28 @@ def _layernorm_mlp_fwd_rule( ...@@ -290,21 +330,28 @@ def _layernorm_mlp_fwd_rule(
dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)
# (batch..., hidden_in) -> (batch..., hidden) # (batch..., hidden_in) -> (batch..., hidden)
casted_act_out = tex.act_lu(dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x) casted_act_out = tex.act_lu(
dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, noop_scaled_tensor=True
)
casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)
casted_kernel_2 = tex.quantize(kernel_2, quantizer=ffn2_quantizer_set.kernel) casted_kernel_2 = tex.quantize(
kernel_2, quantizer=ffn2_quantizer_set.kernel, noop_scaled_tensor=True
)
# NN GEMM # NN GEMM
# (batch..., hidden_in) x (hidden_out, hidden_in) # (batch..., hidden_in) x (hidden_out, hidden_in)
dot_2_output = tex.gemm( dot_2_output = tex.gemm(
casted_act_out.get_rowwise_tensor(), casted_act_out.get_tensor(TensorUsage.LHS),
casted_kernel_2.get_colwise_tensor(), casted_kernel_2.get_tensor(TensorUsage.RHS),
(x_contracting_dims, k_contracting_dims), contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
) )
if use_bias_2: if use_bias_2 and tex.gemm_uses_jax_dot():
bias_2_shape = bias_2.shape bias_2_shape = bias_2.shape
bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
dot_2_output += jnp.reshape(bias_2, bias_2_new_shape) dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)
...@@ -317,11 +364,11 @@ def _layernorm_mlp_fwd_rule( ...@@ -317,11 +364,11 @@ def _layernorm_mlp_fwd_rule(
rsigma, rsigma,
gamma, gamma,
beta, beta,
casted_ln_out.get_colwise_tensor(), casted_ln_out.get_tensor(TensorUsage.LHS_TRANS),
casted_kernel_1.get_rowwise_tensor(), casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS),
dot_1_output, dot_1_output,
casted_act_out.get_colwise_tensor(), casted_act_out.get_tensor(TensorUsage.LHS_TRANS),
casted_kernel_2.get_rowwise_tensor(), casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS),
x_contracting_dims, x_contracting_dims,
k_contracting_dims, k_contracting_dims,
kernel_1.shape, kernel_1.shape,
...@@ -329,6 +376,7 @@ def _layernorm_mlp_fwd_rule( ...@@ -329,6 +376,7 @@ def _layernorm_mlp_fwd_rule(
use_bias_1, use_bias_1,
use_bias_2, use_bias_2,
quantizer_sets, quantizer_sets,
x_bdim,
) )
return dot_2_output, ctx return dot_2_output, ctx
...@@ -346,6 +394,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -346,6 +394,7 @@ def _layernorm_mlp_bwd_rule(
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
batch_first,
ctx, ctx,
grad, grad,
): ):
...@@ -362,18 +411,18 @@ def _layernorm_mlp_bwd_rule( ...@@ -362,18 +411,18 @@ def _layernorm_mlp_bwd_rule(
Returns: Returns:
Tuple of gradients for all input parameters Tuple of gradients for all input parameters
""" """
del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, batch_first
( (
x, x,
mu, mu,
rsigma, rsigma,
gamma, gamma,
beta, beta,
colwise_casted_ln_out, casted_ln_out,
rowwise_casted_kernel_1, casted_kernel_1,
dot_1_output, dot_1_output,
colwise_casted_act_out, casted_act_out,
rowwise_casted_kernel_2, casted_kernel_2,
x_contracting_dims_in_fwd, x_contracting_dims_in_fwd,
k_contracting_dims_in_fwd, k_contracting_dims_in_fwd,
kernel_1_shape, kernel_1_shape,
...@@ -381,6 +430,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -381,6 +430,7 @@ def _layernorm_mlp_bwd_rule(
use_bias_1, use_bias_1,
use_bias_2, use_bias_2,
quantizer_sets, quantizer_sets,
x_bdim,
) = ctx ) = ctx
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
...@@ -389,7 +439,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -389,7 +439,7 @@ def _layernorm_mlp_bwd_rule(
grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
casted_grad, dbias_2 = tex.quantize_dbias( casted_grad, dbias_2 = tex.quantize_dbias(
grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, noop_scaled_tensor=True
) )
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
...@@ -404,9 +454,10 @@ def _layernorm_mlp_bwd_rule( ...@@ -404,9 +454,10 @@ def _layernorm_mlp_bwd_rule(
# NT GEMM # NT GEMM
# (batch..., hidden_out) x (hidden_in, hidden_out) # (batch..., hidden_out) x (hidden_in, hidden_out)
dgrad_2 = tex.gemm( dgrad_2 = tex.gemm(
casted_grad.get_rowwise_tensor(), casted_grad.get_tensor(TensorUsage.LHS),
rowwise_casted_kernel_2, casted_kernel_2,
(g_contracting_dims_2, k_contracting_dims_2), contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
batched_dims=((x_bdim,), ()),
) )
dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
...@@ -418,9 +469,10 @@ def _layernorm_mlp_bwd_rule( ...@@ -418,9 +469,10 @@ def _layernorm_mlp_bwd_rule(
# TN GEMM # TN GEMM
# (hidden, batch...,) x (hidden, batch...) # (hidden, batch...,) x (hidden, batch...)
wgrad_2 = tex.gemm( wgrad_2 = tex.gemm(
colwise_casted_act_out, casted_act_out,
casted_grad.get_colwise_tensor(), casted_grad.get_tensor(TensorUsage.RHS),
(x_contracting_dims, g_contracting_dims), contracting_dims=(x_contracting_dims, g_contracting_dims),
batched_dims=((x_bdim,), (x_bdim,)),
) )
wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)
...@@ -430,10 +482,11 @@ def _layernorm_mlp_bwd_rule( ...@@ -430,10 +482,11 @@ def _layernorm_mlp_bwd_rule(
activation_type=activation_type, activation_type=activation_type,
is_dbias=use_bias_1, is_dbias=use_bias_1,
quantizer=ffn2_quantizer_set.dgrad, quantizer=ffn2_quantizer_set.dgrad,
noop_scaled_tensor=True,
) )
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
dact_out_ndim = casted_dact_out.get_rowwise_tensor().data.ndim dact_out_ndim = casted_dact_out.get_tensor(TensorUsage.LHS).data.ndim
g_contracting_dims_1 = tuple( g_contracting_dims_1 = tuple(
range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim) range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim)
) )
...@@ -444,9 +497,10 @@ def _layernorm_mlp_bwd_rule( ...@@ -444,9 +497,10 @@ def _layernorm_mlp_bwd_rule(
# NT GEMM # NT GEMM
dgrad_1 = tex.gemm( dgrad_1 = tex.gemm(
casted_dact_out.get_rowwise_tensor(), casted_dact_out.get_tensor(TensorUsage.LHS),
rowwise_casted_kernel_1, casted_kernel_1,
(g_contracting_dims_1, k_contracting_dims_1), contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
batched_dims=((x_bdim,), ()),
) )
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)
...@@ -454,9 +508,10 @@ def _layernorm_mlp_bwd_rule( ...@@ -454,9 +508,10 @@ def _layernorm_mlp_bwd_rule(
# TN GEMM # TN GEMM
# (hidden, batch...) x (hidden, batch...) # (hidden, batch...) x (hidden, batch...)
wgrad_1 = tex.gemm( wgrad_1 = tex.gemm(
colwise_casted_ln_out, casted_ln_out,
casted_dact_out.get_colwise_tensor(), casted_dact_out.get_tensor(TensorUsage.RHS),
(x_contracting_dims, g_contracting_dims), contracting_dims=(x_contracting_dims, g_contracting_dims),
batched_dims=((x_bdim,), (x_bdim,)),
) )
wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)
......
...@@ -15,3 +15,4 @@ from .dequantizer import * ...@@ -15,3 +15,4 @@ from .dequantizer import *
from .scaling_modes import * from .scaling_modes import *
from .metadata import * from .metadata import *
from .helper import * from .helper import *
from .device_utils import *
...@@ -36,6 +36,22 @@ class Dequantizer(ABC): ...@@ -36,6 +36,22 @@ class Dequantizer(ABC):
"""Dequantizing given tensor to higher precision.""" """Dequantizing given tensor to higher precision."""
@dataclass
class NoopDequantizer(Dequantizer):
    """Pass-through Dequantizer used when no scaling was applied (NO_SCALING)."""

    @staticmethod
    def _dequantize_func(data, *unused_args, **unused_kwargs):
        """Return *data* unchanged; any extra arguments are accepted and ignored."""
        del unused_args, unused_kwargs
        return data

    @staticmethod
    def dequantize(scaled_tensor):
        """Return the raw data array held by *scaled_tensor*, with no rescaling."""
        return scaled_tensor.data
class TensorScaleDequantizer(Dequantizer): class TensorScaleDequantizer(Dequantizer):
""" """
TensorScaling Dequantizer Class TensorScaling Dequantizer Class
...@@ -105,9 +121,6 @@ class BlockScaleDequantizer(Dequantizer): ...@@ -105,9 +121,6 @@ class BlockScaleDequantizer(Dequantizer):
scale_shape = scaling_mode.get_scale_shape( scale_shape = scaling_mode.get_scale_shape(
data_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis data_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis
) )
scale_inv = jax.lax.slice(
scale_inv, [0] * len(scale_shape), scale_shape
) # slice out the padding
data = data.reshape( data = data.reshape(
*data_shape[: flatten_axis - 1], *data_shape[: flatten_axis - 1],
...@@ -152,6 +165,7 @@ ScalingModeToDequantizerMap = { ...@@ -152,6 +165,7 @@ ScalingModeToDequantizerMap = {
ScalingMode.DELAYED_TENSOR_SCALING: TensorScaleDequantizer, ScalingMode.DELAYED_TENSOR_SCALING: TensorScaleDequantizer,
ScalingMode.CURRENT_TENSOR_SCALING: TensorScaleDequantizer, ScalingMode.CURRENT_TENSOR_SCALING: TensorScaleDequantizer,
ScalingMode.MXFP8_1D_SCALING: BlockScaleDequantizer, ScalingMode.MXFP8_1D_SCALING: BlockScaleDequantizer,
ScalingMode.NO_SCALING: NoopDequantizer,
} }
...@@ -194,28 +208,38 @@ def _grouped_dequantize(grouped_scaled_tensor): ...@@ -194,28 +208,38 @@ def _grouped_dequantize(grouped_scaled_tensor):
f"math.prod({data_shape_i}) = {math.prod(data_shape_i)} which is not equal to" f"math.prod({data_shape_i}) = {math.prod(data_shape_i)} which is not equal to"
f" {data_i.size}" f" {data_i.size}"
) )
scale_shape_i = scaling_mode.get_scale_shape( padded_scale_shape_i = scaling_mode.get_scale_shape(
data_shape_i, data_shape_i,
grouped_scaled_tensor.is_colwise, grouped_scaled_tensor.is_colwise,
is_padded=True, is_padded=True,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
) )
scale_shape_i_size = math.prod(scale_shape_i) unpadded_scale_shape_i = scaling_mode.get_scale_shape(
scale_inv_i = scale_inv[scale_inv_ptr : scale_inv_ptr + scale_shape_i_size] data_shape_i,
grouped_scaled_tensor.is_colwise,
is_padded=False,
flatten_axis=flatten_axis,
)
scale_inv_i = scale_inv[
scale_inv_ptr : scale_inv_ptr + math.prod(padded_scale_shape_i)
].reshape(padded_scale_shape_i)
scale_inv_i = jax.lax.slice(
scale_inv_i, [0] * len(unpadded_scale_shape_i), unpadded_scale_shape_i
)
dequantizer_type = ScalingModeToDequantizerMap.get(grouped_scaled_tensor.scaling_mode) dequantizer_type = ScalingModeToDequantizerMap.get(grouped_scaled_tensor.scaling_mode)
if len(data_i) == 0: if len(data_i) == 0:
out_i = [] out_i = []
else: else:
out_i = dequantizer_type._dequantize_func( out_i = dequantizer_type._dequantize_func(
data_i.reshape(data_shape_i), data_i.reshape(data_shape_i),
scale_inv_i.reshape(scale_shape_i), scale_inv_i,
grouped_scaled_tensor.dq_dtype, grouped_scaled_tensor.dq_dtype,
scaling_mode=grouped_scaled_tensor.scaling_mode, scaling_mode=grouped_scaled_tensor.scaling_mode,
is_colwise=grouped_scaled_tensor.is_colwise, is_colwise=grouped_scaled_tensor.is_colwise,
flatten_axis=grouped_scaled_tensor.flatten_axis, flatten_axis=grouped_scaled_tensor.flatten_axis,
) )
output.append(out_i) output.append(out_i)
scale_inv_ptr += scale_shape_i_size scale_inv_ptr += math.prod(padded_scale_shape_i)
return output return output
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Device utility functions for JAX quantization.
This module provides utility functions for checking device capabilities and compatibility
for quantization operations in JAX.
"""
import functools
import transformer_engine_jax
__all__ = [
"get_device_compute_capability",
"is_fp8_gemm_with_all_layouts_supported",
]
@functools.lru_cache(maxsize=None)
def get_device_compute_capability(gpu_id: int = 0) -> int:
    """Return the compute capability of GPU *gpu_id*, memoized per process.

    Delegates to the transformer_engine_jax extension; the result is cached
    because a device's compute capability cannot change while the process runs.
    """
    capability = transformer_engine_jax.get_device_compute_capability(gpu_id)
    return capability
@functools.lru_cache(maxsize=None)
def is_fp8_gemm_with_all_layouts_supported() -> bool:
    """Return True on Blackwell-architecture GPUs (compute capability in
    [100, 120)), False otherwise."""
    capability = get_device_compute_capability()
    return capability in range(100, 120)
...@@ -9,23 +9,21 @@ in JAX, including support for different scaling modes and datatypes. ...@@ -9,23 +9,21 @@ in JAX, including support for different scaling modes and datatypes.
""" """
from contextlib import contextmanager from contextlib import contextmanager
from enum import Enum from enum import Enum
from typing import Optional, Tuple, Dict, Union from typing import Optional, Tuple, Dict, Union, Sequence
from functools import reduce
import operator
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict from flax.core.frozen_dict import FrozenDict
from transformer_engine_jax import DType from transformer_engine_jax import DType, get_cublasLt_version, get_cuda_version
from transformer_engine_jax import get_cublasLt_version
from transformer_engine_jax import (
get_cuda_version,
get_device_compute_capability,
)
from transformer_engine.common import recipe from transformer_engine.common import recipe
from transformer_engine.jax.sharding import global_shard_guard, MeshResource from transformer_engine.jax.sharding import global_shard_guard, MeshResource
from .scaling_modes import ScalingMode from .scaling_modes import ScalingMode
from .. import cpp_extensions as tex from .. import cpp_extensions as tex
from .device_utils import get_device_compute_capability
__all__ = [ __all__ = [
"QuantizeConfig", "QuantizeConfig",
...@@ -33,6 +31,8 @@ __all__ = [ ...@@ -33,6 +31,8 @@ __all__ = [
"is_fp8_available", "is_fp8_available",
"update_collections", "update_collections",
"get_delayed_scaling", "get_delayed_scaling",
"apply_padding_to_scale_inv",
"remove_padding_from_scale_inv",
"NVTE_FP8_COLLECTION_NAME", "NVTE_FP8_COLLECTION_NAME",
] ]
...@@ -203,7 +203,7 @@ class QuantizeConfig: ...@@ -203,7 +203,7 @@ class QuantizeConfig:
FP8_2X_ACC_FPROP: Whether to use 2x accumulation for forward pass FP8_2X_ACC_FPROP: Whether to use 2x accumulation for forward pass
FP8_2X_ACC_DGRAD: Whether to use 2x accumulation for data gradients FP8_2X_ACC_DGRAD: Whether to use 2x accumulation for data gradients
FP8_2X_ACC_WGRAD: Whether to use 2x accumulation for weight gradients FP8_2X_ACC_WGRAD: Whether to use 2x accumulation for weight gradients
IF_QUANTIZE_2X: Whether 2x quantization is enabled INFERENCE_MODE: Whether to enable optimization for inference
SCALING_MODE: Scaling mode SCALING_MODE: Scaling mode
AMAX_HISTORY_LEN: Length of AMAX history for delayed scaling AMAX_HISTORY_LEN: Length of AMAX history for delayed scaling
AMAX_COMPUTE_ALGO: Algorithm for AMAX computation AMAX_COMPUTE_ALGO: Algorithm for AMAX computation
...@@ -218,7 +218,7 @@ class QuantizeConfig: ...@@ -218,7 +218,7 @@ class QuantizeConfig:
FP8_2X_ACC_FPROP: bool = False FP8_2X_ACC_FPROP: bool = False
FP8_2X_ACC_DGRAD: bool = False FP8_2X_ACC_DGRAD: bool = False
FP8_2X_ACC_WGRAD: bool = False FP8_2X_ACC_WGRAD: bool = False
IF_QUANTIZE_2X: bool = False INFERENCE_MODE: bool = False
SCALING_MODE: ScalingMode = ScalingMode.NO_SCALING SCALING_MODE: ScalingMode = ScalingMode.NO_SCALING
# DelayedScaling # DelayedScaling
...@@ -246,7 +246,6 @@ class QuantizeConfig: ...@@ -246,7 +246,6 @@ class QuantizeConfig:
cls.FP8_FORMAT = fp8_recipe.fp8_format cls.FP8_FORMAT = fp8_recipe.fp8_format
cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT) cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT)
cls.SCALING_MODE = _get_scaling_mode(fp8_recipe) cls.SCALING_MODE = _get_scaling_mode(fp8_recipe)
cls.IF_QUANTIZE_2X = True
@classmethod @classmethod
def finalize(cls) -> None: def finalize(cls) -> None:
...@@ -260,7 +259,7 @@ class QuantizeConfig: ...@@ -260,7 +259,7 @@ class QuantizeConfig:
cls.FP8_2X_ACC_DGRAD = False cls.FP8_2X_ACC_DGRAD = False
cls.FP8_2X_ACC_WGRAD = False cls.FP8_2X_ACC_WGRAD = False
cls.SCALING_MODE = ScalingMode.NO_SCALING cls.SCALING_MODE = ScalingMode.NO_SCALING
cls.IF_QUANTIZE_2X = False cls.INFERENCE_MODE = False
# DelayedScaling # DelayedScaling
cls.AMAX_HISTORY_LEN = 1024 cls.AMAX_HISTORY_LEN = 1024
cls.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX cls.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX
...@@ -476,4 +475,115 @@ def update_collections(new: Collection, original: Collection) -> Collection: ...@@ -476,4 +475,115 @@ def update_collections(new: Collection, original: Collection) -> Collection:
return new_coll return new_coll
def remove_padding_from_scale_inv(
    scale_inv: jax.Array,
    scaling_mode: ScalingMode,
    data_shape: Sequence[int],
    is_colwise: bool = False,
    flatten_axis: int = -1,
):
    """
    Slice padding out of padded inverse scale factors.

    Args:
        scale_inv: Padded inverse scale factor.
        scaling_mode: ScalingMode representing the quantization method.
        data_shape: Shape of the quantized data the inverse scale belongs to.
        is_colwise: Whether the data was quantized column-wise.
        flatten_axis: The axis along which the data could be flattened to 2D.

    Returns:
        Inverse scale factor without padding.
    """
    # Get expected unpadded scale shape and check if inverse scale already matches
    unpadded_scale_shape = scaling_mode.get_scale_shape(
        data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis
    )
    if scaling_mode == ScalingMode.NO_SCALING or scale_inv.shape == unpadded_scale_shape:
        return scale_inv

    # Get the padded scale shape and make sure inverse scale matches
    padded_scale_shape = scaling_mode.get_scale_shape(
        data_shape,
        is_colwise=is_colwise,
        is_padded=True,
        flatten_axis=flatten_axis,
    )
    assert scale_inv.shape == padded_scale_shape, (
        f"Padded inverse scale factor has wrong shape, expected {padded_scale_shape} but got "
        f"{scale_inv.shape} instead."
    )

    # Reshape scale inverse to 2D in two stages to preserve the flatten axis.
    # The explicit initial value 1 keeps reduce() well-defined when a slice of
    # the shape tuple is empty (e.g. 1-D scale shapes with flatten_axis=-1).
    padded_scale_shape_2d = (
        reduce(operator.mul, padded_scale_shape[:flatten_axis], 1),
        reduce(operator.mul, padded_scale_shape[flatten_axis:], 1),
    )
    scale_inv_2d = jnp.reshape(
        jnp.reshape(scale_inv, (padded_scale_shape_2d[0], *scale_inv.shape[flatten_axis:])),
        padded_scale_shape_2d,
    )

    # Slice reshaped 2D scale inverse using collapsed 2D unpadded_scale_shape
    unpadded_scale_shape_2d = (
        reduce(operator.mul, unpadded_scale_shape[:flatten_axis], 1),
        reduce(operator.mul, unpadded_scale_shape[flatten_axis:], 1),
    )
    scale_inv_2d_unpadded = jnp.asarray(
        scale_inv_2d[: unpadded_scale_shape_2d[0], : unpadded_scale_shape_2d[1]]
    )

    # Reshape 2D scale inverse back in two stages in order to preserve the flatten axis
    scale_inv_unpadded = jnp.reshape(
        jnp.reshape(
            scale_inv_2d_unpadded,
            (*unpadded_scale_shape[:flatten_axis], scale_inv_2d_unpadded.shape[1]),
        ),
        unpadded_scale_shape,
    )
    return scale_inv_unpadded
def apply_padding_to_scale_inv(
    scale_inv: jax.Array,
    scaling_mode: ScalingMode,
    data_shape: Sequence[int],
    is_colwise: bool = False,
    flatten_axis: int = -1,
):
    """
    Pad the scale inverse to the padded shape required by this scaling mode.

    Padding entries are filled with 2^-127 — the lowest representable scale
    value — rather than zeros.

    Args:
        scale_inv: Unpadded inverse scale factor.
        scaling_mode: ScalingMode representing the quantization method.
        data_shape: Shape of the quantized data the inverse scale belongs to.
        is_colwise: Whether the data was quantized column-wise.
        flatten_axis: The axis along which the data could be flattened to 2D.

    Returns:
        Padded inverse scale factor.
    """
    # Get the expected padded scale shape and check if inverse scale already matches
    padded_scale_shape = scaling_mode.get_scale_shape(
        data_shape, is_colwise=is_colwise, is_padded=True, flatten_axis=flatten_axis
    )
    if scaling_mode == ScalingMode.NO_SCALING or scale_inv.shape == padded_scale_shape:
        return scale_inv

    # Get the expected unpadded scale shape and make sure inverse scales match
    unpadded_scale_shape = scaling_mode.get_scale_shape(
        data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis
    )
    assert scale_inv.shape == unpadded_scale_shape, (
        f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} but got "
        f"{scale_inv.shape}."
    )

    # Pad the scales with the lowest representable value (2^-127) and return
    pad_width = tuple((0, a - b) for a, b in zip(padded_scale_shape, unpadded_scale_shape))
    return jnp.pad(scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127)
NVTE_FP8_COLLECTION_NAME = QuantizeConfig.COLLECTION_NAME NVTE_FP8_COLLECTION_NAME = QuantizeConfig.COLLECTION_NAME
...@@ -23,6 +23,7 @@ from .helper import ( ...@@ -23,6 +23,7 @@ from .helper import (
QuantizeConfig, QuantizeConfig,
AmaxComputeAlgo, AmaxComputeAlgo,
) )
from .device_utils import is_fp8_gemm_with_all_layouts_supported
__all__ = [ __all__ = [
"QuantizeLayout", "QuantizeLayout",
...@@ -607,9 +608,10 @@ class GroupedQuantizer(Quantizer): ...@@ -607,9 +608,10 @@ class GroupedQuantizer(Quantizer):
def __post_init__(self): def __post_init__(self):
if self.quantizers[0] is None: if self.quantizers[0] is None:
self.quantizers = QuantizerFactory.create( quantizers = QuantizerFactory.create(
self.n_groups, self.scaling_mode, self.q_dtype, self.q_layout self.n_groups, self.scaling_mode, self.q_dtype, self.q_layout
) )
self.quantizers = (quantizers,) if not isinstance(quantizers, tuple) else quantizers
self.data_layout = self.quantizers[0].data_layout self.data_layout = self.quantizers[0].data_layout
def _create_grouped_tensor_from_tensor_list( def _create_grouped_tensor_from_tensor_list(
...@@ -841,9 +843,11 @@ class QuantizerFactory: ...@@ -841,9 +843,11 @@ class QuantizerFactory:
if is_2x2x: if is_2x2x:
q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE_COLWISE q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE_COLWISE
else: else:
q_layout_x = QuantizeLayout.ROWWISE q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE
q_layout_kernel = QuantizeLayout.COLWISE if scaling_mode.is_1d_block_scaling():
q_layout_dgrad = None q_layout_kernel = QuantizeLayout.COLWISE
if QuantizeConfig.INFERENCE_MODE:
q_layout_dgrad = None
if "quantize_meta_set" in kwargs: if "quantize_meta_set" in kwargs:
quantize_meta_set = kwargs.get("quantize_meta_set") quantize_meta_set = kwargs.get("quantize_meta_set")
...@@ -898,7 +902,15 @@ class QuantizerFactory: ...@@ -898,7 +902,15 @@ class QuantizerFactory:
scaling_mode = scaling_mode or QuantizeConfig.SCALING_MODE scaling_mode = scaling_mode or QuantizeConfig.SCALING_MODE
fwd_dtype = fwd_dtype or QuantizeConfig.FWD_DTYPE fwd_dtype = fwd_dtype or QuantizeConfig.FWD_DTYPE
bwd_dtype = bwd_dtype or QuantizeConfig.BWD_DTYPE bwd_dtype = bwd_dtype or QuantizeConfig.BWD_DTYPE
is_2x2x = is_2x2x or QuantizeConfig.IF_QUANTIZE_2X if is_2x2x is None:
if scaling_mode.is_1d_block_scaling():
is_2x2x = True
elif scaling_mode.is_tensor_scaling():
is_2x2x = not is_fp8_gemm_with_all_layouts_supported()
else: # NO_SCALING ignores is_2x2x for now
is_2x2x = False
is_inference_mode = QuantizeConfig.INFERENCE_MODE
assert not is_inference_mode, "Inference mode is not supported yet!"
q_set = [] q_set = []
for _ in range(n_quantizer_sets): for _ in range(n_quantizer_sets):
...@@ -911,4 +923,4 @@ class QuantizerFactory: ...@@ -911,4 +923,4 @@ class QuantizerFactory:
return q_set[0] if len(q_set) == 1 else tuple(q_set) return q_set[0] if len(q_set) == 1 else tuple(q_set)
noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NO_SCALING) noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NO_SCALING, is_2x2x=False)
...@@ -13,18 +13,52 @@ from abc import ABC, abstractmethod ...@@ -13,18 +13,52 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Tuple, Dict from typing import Tuple, Dict
from functools import reduce from functools import reduce, lru_cache
import operator import operator
import numpy as np import numpy as np
from jax.experimental.custom_partitioning import CompoundFactor from jax.experimental.custom_partitioning import BATCHING
from jax.tree_util import register_pytree_node_class from jax.tree_util import register_pytree_node_class
import jax.numpy as jnp import jax.numpy as jnp
from transformer_engine_jax import JAXX_Scaling_Mode from transformer_engine_jax import JAXX_Scaling_Mode, QuantizeLayout
from .device_utils import is_fp8_gemm_with_all_layouts_supported
__all__ = ["QuantizeShardyRules", "ScalingMode"] __all__ = [
"QuantizeShardyRules",
"ScalingMode",
"TensorUsage",
]
class TensorUsage(Enum):
"""Enum indicating tensor usage in GEMM operations.
Given a GEMM operation: C = A * B in which A and B can be in the normal or transposed form.
The tensor usage can be:
- LHS: A is in the normal form
- LHS_TRANS: A is in the transposed form
- RHS: B is in the normal form
- RHS_TRANS: B is in the transposed form
The tensor usage is used in the ScaledTensor.get_tensor() method.
"""
# LHS: Left-hand side, RHS: Right-hand side
# LHS_TRANS: Left-hand side transposed, RHS_TRANS: Right-hand side transposed
LHS = 0
LHS_TRANS = 1
RHS = 2
RHS_TRANS = 3
def __eq__(self, other):
if not isinstance(other, TensorUsage):
return False
return self.value == other.value
def __hash__(self):
return hash(self.value)
def DIVUP(a, b): def DIVUP(a, b):
...@@ -104,6 +138,18 @@ class ScalingModeMetadataImpl(ABC): ...@@ -104,6 +138,18 @@ class ScalingModeMetadataImpl(ABC):
The shape for scale tensors The shape for scale tensors
""" """
@lru_cache(maxsize=4)
@abstractmethod
def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
"""Get the quantize layout for the tensor usage.
Args:
usage: The usage of the tensor
Returns:
The quantize layout for the tensor usage
"""
@abstractmethod @abstractmethod
def get_shardy_sharding_rules( def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis self, input_rank, unique_var, flatten_axis
...@@ -157,6 +203,23 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -157,6 +203,23 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
return (0,) return (0,)
return (1,) return (1,)
@lru_cache(maxsize=4)
def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
"""Get the quantize layout for the tensor usage.
Args:
usage: The usage of the tensor
Returns:
The quantize layout for the tensor usage
"""
if is_fp8_gemm_with_all_layouts_supported():
return QuantizeLayout.ROWWISE
if usage in (TensorUsage.LHS, TensorUsage.RHS_TRANS):
return QuantizeLayout.ROWWISE
return QuantizeLayout.COLWISE
def get_grouped_scale_shape( def get_grouped_scale_shape(
self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1 self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
) -> Tuple[int]: ) -> Tuple[int]:
...@@ -189,8 +252,9 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -189,8 +252,9 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
The Shardy rules for the scaling mode The Shardy rules for the scaling mode
""" """
del flatten_axis del flatten_axis
input_spec = tuple(f"x{i}" for i in range(input_rank)) input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank))
return QuantizeShardyRules(input_spec, (unique_var,), (unique_var,), {}) scale_var = BATCHING + unique_var + "_scale_inv"
return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {})
class DelayedScalingModeMetadataImpl(CurrentScalingModeMetadataImpl): class DelayedScalingModeMetadataImpl(CurrentScalingModeMetadataImpl):
...@@ -321,6 +385,27 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -321,6 +385,27 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
return (*first_dim_scale_shape, *last_dim_scale_shape) return (*first_dim_scale_shape, *last_dim_scale_shape)
@lru_cache(maxsize=4)
def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
"""Get the quantize layout for the tensor usage.
Args:
usage: The usage of the tensor
Returns:
The quantize layout for the tensor usage
"""
# If we need to support 1x1x for inference in the future
# if QuantizeConfig.INFERENCE_MODE:
# assert usage not in (TensorUsage.LHS_TRANS, TensorUsage.RHS_TRANS), (f"Invalid usage {usage} as we are in MXFP8_1D_SCALING 1x1x (FWD only) mode so no transposed usage is needed!")
# if usage == TensorUsage.LHS:
# return QuantizeLayout.ROWWISE
# return QuantizeLayout.COLWISE
if usage in (TensorUsage.LHS, TensorUsage.RHS_TRANS):
return QuantizeLayout.ROWWISE
return QuantizeLayout.COLWISE
def get_grouped_scale_shape( def get_grouped_scale_shape(
self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1 self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
) -> Tuple[int]: ) -> Tuple[int]:
...@@ -404,31 +489,41 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -404,31 +489,41 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
Returns: Returns:
The Shardy rules for the scaling mode The Shardy rules for the scaling mode
""" """
input_spec = [f"x{i}" for i in range(input_rank)] del flatten_axis
input_spec = [f"{unique_var}{i}" for i in range(input_rank)]
# We have to use two different factors in the two CompoundFactors because of Shardy rowwise = [f"{unique_var}scale_inv_rowwise{i}" for i in range(input_rank)]
# verifier requirements, even though they are the same. colwise = [f"{unique_var}scale_inv_colwise{i}" for i in range(input_rank)]
rowwise_var = unique_var
colwise_var = f"{unique_var}_" # NOTE (Alp): Padding the scales breaks the size relationship in CompoundFactors.
input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise") # Unfortunately, because Shardy rules are applied to the inner primitive, the
input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise") # only way to preserve the relationship is to lower unpadded scales to the
# underlying custom call and pad them in C++. Until that's implemented, the
# The rowwise and colwise scale tensors should be sharded the same way as the input. # Shardy rules for block scales have to be completely disconnected from the
# However, we need to adjust the dimensions where the block scaling factor applies. # Shardy rules for the tensor they belong to.
rowwise = input_spec.copy()
rowwise[-1] = rowwise_var # # We have to use two different factors in the two CompoundFactors because of Shardy
# # verifier requirements, even though they are the same.
colwise = input_spec.copy() # rowwise_var = unique_var
colwise[flatten_axis - 1] = colwise_var # colwise_var = f"{unique_var}_"
# input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise")
# This implementation needs to be updated for different block dims. # input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise")
assert self._block_dims == (1, 32)
# # The rowwise and colwise scale tensors should be sharded the same way as the input.
# # However, we need to adjust the dimensions where the block scaling factor applies.
# rowwise = input_spec.copy()
# rowwise[-1] = rowwise_var
# colwise = input_spec.copy()
# colwise[flatten_axis - 1] = colwise_var
# # This implementation needs to be updated for different block dims.
# assert self._block_dims == (1, 32)
return QuantizeShardyRules( return QuantizeShardyRules(
tuple(input_spec), tuple(input_spec),
tuple(rowwise), tuple(rowwise),
tuple(colwise), tuple(colwise),
{"block_size_rowwise": 32, "block_size_colwise": 32}, {}, # {"block_size_rowwise": 32, "block_size_colwise": 32},
) )
...@@ -506,6 +601,17 @@ class ScalingMode(Enum): ...@@ -506,6 +601,17 @@ class ScalingMode(Enum):
""" """
return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded, flatten_axis) return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded, flatten_axis)
def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
"""Get the quantize layout for the tensor usage.
Args:
usage: The usage of the tensor
Returns:
The quantize layout for the tensor usage
"""
return self._get_impl().get_quantize_layout(usage)
def get_shardy_sharding_rules( def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis=-1 self, input_rank, unique_var, flatten_axis=-1
) -> Tuple[Tuple[str]]: ) -> Tuple[Tuple[str]]:
......
...@@ -17,13 +17,14 @@ from jax.tree_util import register_pytree_node_class ...@@ -17,13 +17,14 @@ from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeLayout from transformer_engine_jax import QuantizeLayout
from .scaling_modes import ScalingMode from .scaling_modes import ScalingMode, TensorUsage
from .dequantizer import ScalingModeToDequantizerMap from .dequantizer import ScalingModeToDequantizerMap
from ..sharding import ( from ..sharding import (
with_sharding_constraint_by_logical_axes as original_with_sharding_constraint_by_logical_axes, with_sharding_constraint_by_logical_axes as original_with_sharding_constraint_by_logical_axes,
) )
__all__ = [ __all__ = [
"TensorUsage",
"ScaledTensor", "ScaledTensor",
"ScaledTensor1x", "ScaledTensor1x",
"ScaledTensor2x", "ScaledTensor2x",
...@@ -55,6 +56,11 @@ class ScaledTensor(ABC): ...@@ -55,6 +56,11 @@ class ScaledTensor(ABC):
""" """
return cls(*children, *aux_data) return cls(*children, *aux_data)
@property
@abstractmethod
def ndim(self):
"""Number of dimensions of the underlying quantized array."""
@abstractmethod @abstractmethod
def dequantize(self): def dequantize(self):
"""Dequantizes the tensor back to its original precision. """Dequantizes the tensor back to its original precision.
...@@ -64,25 +70,15 @@ class ScaledTensor(ABC): ...@@ -64,25 +70,15 @@ class ScaledTensor(ABC):
""" """
@abstractmethod @abstractmethod
def get_rowwise_tensor(self): def get_tensor(self, usage: TensorUsage):
"""Returns the row-wise component of the tensor. """Returns the appropriate tensor based on the tensor usage and the scaling mode.
If the tensor usage is not valid for the scaling mode, an error is raised.
Returns:
The row-wise tensor component
Raises: Args:
ValueError: If called on a tensor that doesn't support row-wise access usage: The usage of the tensor
"""
@abstractmethod
def get_colwise_tensor(self):
"""Returns the column-wise component of the tensor.
Returns: Returns:
The column-wise tensor component The tensor based on the usage
Raises:
ValueError: If called on a tensor that doesn't support column-wise access
""" """
@abstractmethod @abstractmethod
...@@ -136,24 +132,18 @@ class ScaledTensor1x(ScaledTensor): ...@@ -136,24 +132,18 @@ class ScaledTensor1x(ScaledTensor):
0 < self.flatten_axis < len(self.data.shape) 0 < self.flatten_axis < len(self.data.shape)
), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}" ), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}"
expected_scale_shape = self.scaling_mode.get_scale_shape( if self.scaling_mode == ScalingMode.NO_SCALING:
self.data.shape, self.is_colwise, is_padded=True, flatten_axis=self.flatten_axis self.scale_inv = jnp.empty((0,), dtype=jnp.float32)
) else:
expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape( unpadded_scale_shape = self.scaling_mode.get_scale_shape(
self.data.shape, self.is_colwise, is_padded=False, flatten_axis=self.flatten_axis self.data.shape,
) is_colwise=self.is_colwise,
if self.scale_inv.shape != expected_scale_shape: is_padded=False,
assert self.scale_inv.shape == expected_unpadded_scale_shape, ( flatten_axis=self.flatten_axis,
f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded"
f" scale_inv or {expected_unpadded_scale_shape} for unpadded scale_inv, got"
f" {self.scale_inv.shape}"
)
pad_width = tuple(
(0, a - b) for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape)
) )
# This actually pad scale_inv with nan, should we pad it with 127 directly instead? assert self.scale_inv.shape == unpadded_scale_shape, (
self.scale_inv = jnp.pad( "Unpadded inverse scale factor has wrong shape, expected"
self.scale_inv, pad_width=pad_width, mode="constant", constant_values=0 f" {unpadded_scale_shape} but got {self.scale_inv.shape}."
) )
def tree_flatten(self): def tree_flatten(self):
...@@ -173,6 +163,10 @@ class ScaledTensor1x(ScaledTensor): ...@@ -173,6 +163,10 @@ class ScaledTensor1x(ScaledTensor):
) )
return (children, aux_data) return (children, aux_data)
@property
def ndim(self):
return self.data.ndim
def dequantize(self): def dequantize(self):
"""Dequantizes the tensor using the stored dequantization function. """Dequantizes the tensor using the stored dequantization function.
...@@ -181,33 +175,19 @@ class ScaledTensor1x(ScaledTensor): ...@@ -181,33 +175,19 @@ class ScaledTensor1x(ScaledTensor):
""" """
return self._dq_func(self) return self._dq_func(self)
def get_rowwise_tensor(self): def get_tensor(self, usage: TensorUsage):
"""Returns the tensor if it's row-wise quantized. """Returns the tensor based on the tensor usage."""
q_layout = self.scaling_mode.get_quantize_layout(usage)
Returns: colwise_usage_valid = q_layout == QuantizeLayout.COLWISE and self.is_colwise
The row-wise tensor rowwise_usage_valid = q_layout == QuantizeLayout.ROWWISE and not self.is_colwise
Raises: if colwise_usage_valid or rowwise_usage_valid:
ValueError: If called on a column-wise quantized tensor
"""
if not self.is_colwise:
return self
raise ValueError("Calling get_rowwise_tensor() from a colwise ScaledTensor1x!")
def get_colwise_tensor(self):
"""Returns the tensor if it's column-wise quantized.
Returns:
The column-wise tensor
Raises:
ValueError: If called on a row-wise quantized tensor
"""
if self.is_colwise:
return self return self
raise ValueError("Calling get_colwise_tensor() from a rowwise ScaledTensor1x!") raise ValueError(
f"Calling get_tensor() with usage {usage} is not valid for this tensor as"
f" self.is_colwise={self.is_colwise}!"
)
def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
"""Applies sharding constraints to a tensor based on logical axis names. """Applies sharding constraints to a tensor based on logical axis names.
...@@ -370,6 +350,11 @@ class ScaledTensor2x(ScaledTensor): ...@@ -370,6 +350,11 @@ class ScaledTensor2x(ScaledTensor):
aux_data = () aux_data = ()
return (children, aux_data) return (children, aux_data)
@property
def ndim(self):
"""Number of dimensions of the underlying row-wise tensor."""
return self.rowwise_tensor.ndim
def dequantize(self): def dequantize(self):
"""Dequantizes the tensor using the row-wise component's dequantization. """Dequantizes the tensor using the row-wise component's dequantization.
...@@ -378,21 +363,21 @@ class ScaledTensor2x(ScaledTensor): ...@@ -378,21 +363,21 @@ class ScaledTensor2x(ScaledTensor):
""" """
return self.rowwise_tensor.dequantize() return self.rowwise_tensor.dequantize()
def get_rowwise_tensor(self): def get_tensor(self, usage: TensorUsage):
"""Returns the row-wise quantized component. """Returns the tensor based on the tensor usage."""
q_layout_rowwise = self.rowwise_tensor.scaling_mode.get_quantize_layout(usage)
q_layout_colwise = self.colwise_tensor.scaling_mode.get_quantize_layout(usage)
Returns: if q_layout_rowwise == QuantizeLayout.ROWWISE:
The row-wise tensor component return self.rowwise_tensor
"""
return self.rowwise_tensor
def get_colwise_tensor(self): if q_layout_colwise == QuantizeLayout.COLWISE:
"""Returns the column-wise quantized component. return self.colwise_tensor
Returns: raise ValueError(
The column-wise tensor component f"Calling get_tensor() with usage {usage} is not valid for this tensor as"
""" f" q_layout_rowwise={q_layout_rowwise} and q_layout_colwise={q_layout_colwise}!"
return self.colwise_tensor )
def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
"""Applies sharding constraints to a tensor based on logical axis names. """Applies sharding constraints to a tensor based on logical axis names.
......
...@@ -14,6 +14,7 @@ from contextlib import contextmanager ...@@ -14,6 +14,7 @@ from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Callable, Optional from typing import Callable, Optional
import warnings
from jax.interpreters import pxla from jax.interpreters import pxla
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
...@@ -117,7 +118,9 @@ def with_sharding_constraint_by_logical_axes( ...@@ -117,7 +118,9 @@ def with_sharding_constraint_by_logical_axes(
x: jnp.array, logical_axis_names: Optional[tuple | list] x: jnp.array, logical_axis_names: Optional[tuple | list]
): ):
""" """
A wrapper function to jax.lax.with_sharding_constraint to accept logical axes. A wrapper function to flax.linen.with_logical_constraint.
DEPRECATED USE CASE: If no Flax logical axis rules are available, this function falls back to jax.lax.with_sharding_constraint using a hardcoded logical axis rule table from TE rules, such as BATCH_AXES. This functionality will be removed in the future.
If logical_axis_names = None, this means no sharding constraint is applied. If logical_axis_names = None, this means no sharding constraint is applied.
...@@ -133,6 +136,28 @@ def with_sharding_constraint_by_logical_axes( ...@@ -133,6 +136,28 @@ def with_sharding_constraint_by_logical_axes(
if not logical_axis_names: if not logical_axis_names:
return x return x
try:
# Check if Flax logical axis rules are available, if so use them
import flax
flax_rules = flax.linen.get_logical_axis_rules()
if len(flax_rules) > 0:
return flax.linen.with_logical_constraint(
x, logical_axis_names, fallback=flax.linen.spmd.RulesFallback.NO_CONSTRAINT
)
except ImportError:
pass
warnings.warn(
"TransformerEngine logical axes, such as BATCH_AXES, SEQLEN_AXES, etc. are deprecated and"
" will be removed in a future version. Please use Flax logical axes with a"
" flax.linen.logical_axis_rules context and optionally use"
" transformer_engine.jax.flax.extend_logical_axis_rules to add BATCH_AXES, etc. to your"
" rules.",
DeprecationWarning,
)
# If no logical axis rules are available from Flax, fallback to TE's hardcoded logical axis rule table
assert len(x.shape) == len(logical_axis_names) assert len(x.shape) == len(logical_axis_names)
pspec = generate_pspec(logical_axis_names) pspec = generate_pspec(logical_axis_names)
return with_sharding_constraint(x, pspec) return with_sharding_constraint(x, pspec)
......
...@@ -53,6 +53,7 @@ from transformer_engine.pytorch.distributed import CudaRNGStatesTracker ...@@ -53,6 +53,7 @@ from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context
from transformer_engine.pytorch import ops from transformer_engine.pytorch import ops
from transformer_engine.pytorch import optimizers from transformer_engine.pytorch import optimizers
from transformer_engine.pytorch.export import onnx_export
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
try: try:
......
...@@ -57,6 +57,8 @@ from transformer_engine.pytorch.attention.dot_product_attention.utils import ( ...@@ -57,6 +57,8 @@ from transformer_engine.pytorch.attention.dot_product_attention.utils import (
from transformer_engine.pytorch.attention.dot_product_attention.utils import ( from transformer_engine.pytorch.attention.dot_product_attention.utils import (
AttentionLogging as attn_log, AttentionLogging as attn_log,
) )
from transformer_engine.pytorch import export
from transformer_engine.pytorch.export import is_in_onnx_export_mode
# Global vars for flash attn v2 and v3 imports # Global vars for flash attn v2 and v3 imports
flash_attn_cuda_bwd = None flash_attn_cuda_bwd = None
...@@ -150,7 +152,14 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -150,7 +152,14 @@ class UnfusedDotProductAttention(torch.nn.Module):
self.attention_dropout_ctx = attention_dropout_ctx self.attention_dropout_ctx = attention_dropout_ctx
self.layer_number = layer_number self.layer_number = layer_number
self.scale_mask_softmax = FusedScaleMaskSoftmax(attention_mask_func) def mask_func(x, y):
return (
export.onnx_attention_mask_func(x, y)
if is_in_onnx_export_mode()
else attention_mask_func(x, y)
)
self.scale_mask_softmax = FusedScaleMaskSoftmax(mask_func)
# Dropout. Note that for a single iteration, this layer will generate # Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but # different outputs on different number of parallel partitions but
......
...@@ -2559,8 +2559,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2559,8 +2559,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if ctx.enable_mla: if ctx.enable_mla:
# [cp, b, 2, sk//2, np, hn] or [cp, 2, sk//2, b, np, hn] # [cp, b, 2, sk//2, np, hn] or [cp, 2, sk//2, b, np, hn]
dk_fp8 = dkv_fp8[: ctx.k_numel].view(cp_size, *ctx.k_shape) dk_fp8 = dkv_fp8[:, : ctx.k_numel].view(cp_size, *ctx.k_shape)
dv_fp8 = dkv_fp8[ctx.k_numel :].view(cp_size, *ctx.v_shape) dv_fp8 = dkv_fp8[:, ctx.k_numel :].view(cp_size, *ctx.v_shape)
dk = ctx.dQKV_CP_quantizer.create_tensor_from_data( dk = ctx.dQKV_CP_quantizer.create_tensor_from_data(
dk_fp8, fake_dtype=torch.float32, internal=True dk_fp8, fake_dtype=torch.float32, internal=True
) )
...@@ -2586,8 +2586,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2586,8 +2586,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dq = dq.view(dq.shape[0], -1, *dq.shape[-2:]) dq = dq.view(dq.shape[0], -1, *dq.shape[-2:])
if ctx.enable_mla: if ctx.enable_mla:
# [b, 2, sk//2, np, hn] -> [b, sk, np, hn] # [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
dk = dk.view(*dk.shape[0], -1, *dk.shape[-2:]) dk = dk.view(dk.shape[0], -1, *dk.shape[-2:])
dv = dv.view(*dv.shape[0], -1, *dv.shape[-2:]) dv = dv.view(dv.shape[0], -1, *dv.shape[-2:])
else: else:
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:]) dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:])
......
...@@ -17,6 +17,7 @@ from transformer_engine.pytorch.utils import get_cudnn_version ...@@ -17,6 +17,7 @@ from transformer_engine.pytorch.utils import get_cudnn_version
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.constants import ( from transformer_engine.pytorch.constants import (
AttnMaskTypes, AttnMaskTypes,
AttnTypes, AttnTypes,
...@@ -963,47 +964,54 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -963,47 +964,54 @@ class DotProductAttention(TransformerEngineBaseModule):
inference_params=inference_params, inference_params=inference_params,
) )
global _attention_backends global _attention_backends
if ( if is_in_onnx_export_mode():
_attention_backends["attention_params"] is None # We do not want to call get_attention_backend() in ONNX mode
or attention_params != _attention_backends["attention_params"] # and we want to avoid using any global variables like _attention_backends.
): use_flash_attention = False
_attention_backends["attention_params"] = attention_params use_fused_attention = False
_attention_backends["backend_selection_requires_update"] = True use_unfused_attention = True
if _attention_backends["backend_selection_requires_update"]:
(
use_flash_attention,
flash_attention_backend,
use_fused_attention,
fused_attention_backend,
use_unfused_attention,
_,
) = dpa_utils.get_attention_backend(attention_params)
# Set global _attention_backends var using return value
# from get_attention_backend()
_attention_backends["use_flash_attention"] = use_flash_attention
_attention_backends["flash_attention_backend"] = flash_attention_backend
_attention_backends["use_fused_attention"] = use_fused_attention
_attention_backends["fused_attention_backend"] = fused_attention_backend
_attention_backends["use_unfused_attention"] = use_unfused_attention
_attention_backends["backend_selection_requires_update"] = False
if use_flash_attention:
self.logger.info(
"Running with FlashAttention backend (version %s)",
flash_attention_backend,
)
elif use_fused_attention:
self.logger.info(
"Running with FusedAttention backend (sub-backend %s)",
int(fused_attention_backend),
)
elif use_unfused_attention:
self.logger.info("Running with UnfusedDotProductAttention backend")
else: else:
use_flash_attention = _attention_backends["use_flash_attention"] if (
flash_attention_backend = _attention_backends["flash_attention_backend"] _attention_backends["attention_params"] is None
use_fused_attention = _attention_backends["use_fused_attention"] or attention_params != _attention_backends["attention_params"]
fused_attention_backend = _attention_backends["fused_attention_backend"] ):
use_unfused_attention = _attention_backends["use_unfused_attention"] _attention_backends["attention_params"] = attention_params
_attention_backends["backend_selection_requires_update"] = True
if _attention_backends["backend_selection_requires_update"]:
(
use_flash_attention,
flash_attention_backend,
use_fused_attention,
fused_attention_backend,
use_unfused_attention,
_,
) = dpa_utils.get_attention_backend(attention_params)
# Set global _attention_backends var using return value
# from get_attention_backend()
_attention_backends["use_flash_attention"] = use_flash_attention
_attention_backends["flash_attention_backend"] = flash_attention_backend
_attention_backends["use_fused_attention"] = use_fused_attention
_attention_backends["fused_attention_backend"] = fused_attention_backend
_attention_backends["use_unfused_attention"] = use_unfused_attention
_attention_backends["backend_selection_requires_update"] = False
if use_flash_attention:
self.logger.info(
"Running with FlashAttention backend (version %s)",
flash_attention_backend,
)
elif use_fused_attention:
self.logger.info(
"Running with FusedAttention backend (sub-backend %s)",
int(fused_attention_backend),
)
elif use_unfused_attention:
self.logger.info("Running with UnfusedDotProductAttention backend")
else:
use_flash_attention = _attention_backends["use_flash_attention"]
flash_attention_backend = _attention_backends["flash_attention_backend"]
use_fused_attention = _attention_backends["use_fused_attention"]
fused_attention_backend = _attention_backends["fused_attention_backend"]
use_unfused_attention = _attention_backends["use_unfused_attention"]
# raise exception if no backend is available # raise exception if no backend is available
if sum([use_flash_attention, use_fused_attention, use_unfused_attention]) == 0: if sum([use_flash_attention, use_fused_attention, use_unfused_attention]) == 0:
......
...@@ -8,6 +8,7 @@ from typing import Callable, Tuple, Union, Optional ...@@ -8,6 +8,7 @@ from typing import Callable, Tuple, Union, Optional
import torch import torch
from torch import nn from torch import nn
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.export import is_in_onnx_export_mode
THREADS_PER_WARP = 32 THREADS_PER_WARP = 32
...@@ -19,12 +20,18 @@ _default_causal_mask = {} ...@@ -19,12 +20,18 @@ _default_causal_mask = {}
def _get_default_causal_mask(mask_type: str, sq: int, sk: int) -> torch.Tensor: def _get_default_causal_mask(mask_type: str, sq: int, sk: int) -> torch.Tensor:
"""Return the causal upper triangular mask for softmax input""" """Return the causal upper triangular mask for softmax input"""
matrix_identifiers = (mask_type, sq, sk)
if matrix_identifiers not in _default_causal_mask: def _get_mask():
diagonal_offset = sk - sq + 1 if "bottom_right" in mask_type else 1 diagonal_offset = sk - sq + 1 if "bottom_right" in mask_type else 1
_default_causal_mask[matrix_identifiers] = torch.triu( return torch.triu(
torch.ones(sq, sk, dtype=torch.bool, device="cuda"), diagonal=diagonal_offset torch.ones(sq, sk, dtype=torch.bool, device="cuda"), diagonal=diagonal_offset
) )
if is_in_onnx_export_mode():
return _get_mask()
matrix_identifiers = (mask_type, sq, sk)
if matrix_identifiers not in _default_causal_mask:
_default_causal_mask[matrix_identifiers] = _get_mask()
return _default_causal_mask[matrix_identifiers] return _default_causal_mask[matrix_identifiers]
...@@ -169,7 +176,11 @@ class FusedScaleMaskSoftmax(nn.Module): ...@@ -169,7 +176,11 @@ class FusedScaleMaskSoftmax(nn.Module):
self.attn_mask_type = attn_mask_type self.attn_mask_type = attn_mask_type
assert scale is None or self.softmax_in_fp32, "softmax should be in fp32 when scaled" assert scale is None or self.softmax_in_fp32, "softmax should be in fp32 when scaled"
if is_in_onnx_export_mode():
return self.forward_torch_softmax(inp, mask, scale)
# We do not want to connect this if with previous if,
# because we want to avoid calling is_kernel_available() in ONNX mode.
if self.is_kernel_available(mask, *inp.size()): if self.is_kernel_available(mask, *inp.size()):
return self.forward_fused_softmax(inp, mask, scale) return self.forward_fused_softmax(inp, mask, scale)
return self.forward_torch_softmax(inp, mask, scale) return self.forward_torch_softmax(inp, mask, scale)
...@@ -245,15 +256,15 @@ class FusedScaleMaskSoftmax(nn.Module): ...@@ -245,15 +256,15 @@ class FusedScaleMaskSoftmax(nn.Module):
if self.attn_mask_type in ["causal", "causal_bottom_right"]: if self.attn_mask_type in ["causal", "causal_bottom_right"]:
seq_len_q, seq_len_k = inp.size(2), inp.size(3) seq_len_q, seq_len_k = inp.size(2), inp.size(3)
causal_mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k) causal_mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k)
if mask is None: if mask is None:
mask = causal_mask mask = causal_mask
else: else:
mask = torch.logical_or(mask, causal_mask) mask = torch.logical_or(mask, causal_mask)
mask_output = inp mask_output = inp
if mask is not None and self.attn_mask_type != "no_mask": if mask is not None and self.attn_mask_type != "no_mask":
mask_output = self.mask_func(inp, mask) mask_output = self.mask_func(inp, mask)
probs = torch.nn.Softmax(dim=-1)(mask_output) probs = torch.nn.functional.softmax(mask_output, dim=-1)
if self.input_in_float16 and self.softmax_in_fp32: if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16: if self.input_in_fp16:
......
...@@ -44,6 +44,7 @@ from transformer_engine.pytorch.utils import ( ...@@ -44,6 +44,7 @@ from transformer_engine.pytorch.utils import (
get_device_compute_capability, get_device_compute_capability,
get_cudnn_version, get_cudnn_version,
) )
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.jit import jit_fuser from transformer_engine.pytorch.jit import jit_fuser
...@@ -105,7 +106,7 @@ class FlashAttentionUtils: ...@@ -105,7 +106,7 @@ class FlashAttentionUtils:
version = PkgVersion("0") version = PkgVersion("0")
version_required = PkgVersion("2.1.1") version_required = PkgVersion("2.1.1")
version_required_blackwell = PkgVersion("2.7.3") version_required_blackwell = PkgVersion("2.7.3")
max_version = PkgVersion("2.7.4.post1") max_version = PkgVersion("2.8.1")
v2_plus = False v2_plus = False
v2_1_plus = False v2_1_plus = False
v2_3_plus = False v2_3_plus = False
...@@ -437,8 +438,8 @@ def get_attention_backend( ...@@ -437,8 +438,8 @@ def get_attention_backend(
# | FP8 | non-paged/paged | sm90 | thd | >= 1 # | FP8 | non-paged/paged | sm90 | thd | >= 1
# Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1 # Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1
if inference_params is not None: if inference_params is not None:
if device_compute_capability == (8, 9) and cudnn_version < (9, 11, 0): if device_compute_capability == (8, 9) and cudnn_version < (9, 12, 0):
logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN < 9.11") logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN < 9.12")
use_fused_attention = False use_fused_attention = False
if context_parallel: if context_parallel:
logger.debug("Disabling all backends for KV caching with context parallelism") logger.debug("Disabling all backends for KV caching with context parallelism")
...@@ -624,6 +625,12 @@ def get_attention_backend( ...@@ -624,6 +625,12 @@ def get_attention_backend(
" bias for THD format" " bias for THD format"
) )
use_fused_attention = False use_fused_attention = False
elif fp8 and head_dim_qk != head_dim_v:
logger.debug(
"Disabling FusedAttention as it does not support context parallelism with FP8"
" MLA attention"
)
use_fused_attention = False
# Filter: Attention mask # Filter: Attention mask
# attn_mask_type | attention_mask | supported backends # attn_mask_type | attention_mask | supported backends
...@@ -1150,9 +1157,7 @@ def get_full_mask( ...@@ -1150,9 +1157,7 @@ def get_full_mask(
swa_right = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + ( swa_right = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + (
actual_seqlens_kv - actual_seqlens_q + window_size[1] actual_seqlens_kv - actual_seqlens_q + window_size[1]
).view(batch_size, 1, 1, 1) ).view(batch_size, 1, 1, 1)
swa_mask = torch.logical_not( swa_mask = torch.logical_not((swa_left <= 0) & ~(swa_right < 0))
torch.where(swa_left <= 0, 1, 0) - torch.where(swa_right < 0, 1, 0)
)
if attention_mask is not None: if attention_mask is not None:
attention_mask = torch.logical_or(swa_mask, attention_mask) attention_mask = torch.logical_or(swa_mask, attention_mask)
else: else:
...@@ -1343,14 +1348,22 @@ def get_full_cu_seqlens( ...@@ -1343,14 +1348,22 @@ def get_full_cu_seqlens(
""" """
global _cu_seqlens_cache global _cu_seqlens_cache
if (batch_size, max_seqlen) not in _cu_seqlens_cache:
_cu_seqlens_cache[(batch_size, max_seqlen)] = torch.arange( def _get_cu_seqlens(batch_size, max_seqlen, device):
return torch.arange(
0, 0,
(batch_size + 1) * max_seqlen, (batch_size + 1) * max_seqlen,
step=max_seqlen, step=max_seqlen,
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
) )
if is_in_onnx_export_mode():
return _get_cu_seqlens(batch_size, max_seqlen, device)
if (batch_size, max_seqlen) not in _cu_seqlens_cache:
_cu_seqlens_cache[(batch_size, max_seqlen)] = _get_cu_seqlens(
batch_size, max_seqlen, device
)
return _cu_seqlens_cache[(batch_size, max_seqlen)] return _cu_seqlens_cache[(batch_size, max_seqlen)]
...@@ -1626,11 +1639,16 @@ def get_qkv_layout( ...@@ -1626,11 +1639,16 @@ def get_qkv_layout(
def run_iteratively(q, k, v): def run_iteratively(q, k, v):
# check data pointers # check data pointers
data_ptr = q.untyped_storage().data_ptr() if is_in_onnx_export_mode():
check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v]) check_ptrs_qkv = False
check_ptrs_qk = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k]) check_ptrs_qk = False
data_ptr = k.untyped_storage().data_ptr() check_ptrs_kv = False
check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v]) else:
data_ptr = q.untyped_storage().data_ptr()
check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
check_ptrs_qk = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k])
data_ptr = k.untyped_storage().data_ptr()
check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v])
# check tensor shapes # check tensor shapes
shape = q.shape shape = q.shape
...@@ -1718,7 +1736,10 @@ def get_qkv_layout( ...@@ -1718,7 +1736,10 @@ def get_qkv_layout(
return qkv_layout return qkv_layout
qkv_layout = run_iteratively(q, k, v) if not is_in_onnx_export_mode():
qkv_layout = run_iteratively(q, k, v)
else:
qkv_layout = "not_supported"
if qkv_layout == "not_supported": if qkv_layout == "not_supported":
# force q,k,v to be contiguous and run get_layout again # force q,k,v to be contiguous and run get_layout again
q, k, v = [x.contiguous() for x in [q, k, v]] q, k, v = [x.contiguous() for x in [q, k, v]]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment