Commit 0d874a4e authored by wenjh

Merge branch 'nv_main' of v2.12

parents a68e5f87 dfdd3820
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
...
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
...
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
...
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Dense layer transformation operations for Transformer Engine in JAX.

@@ -21,12 +21,12 @@ from .quantize import (
    ScaledTensorFactory,
    ScaledTensor,
    ScalingMode,
-    QuantizeLayout,
    QuantizerSet,
    noop_quantizer_set,
    with_sharding_constraint_by_logical_axes,
    is_fp8_gemm_with_all_layouts_supported,
    TensorUsage,
+    QuantizeLayout,
)
...
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer Engine bindings for JAX"""
from .module import DenseGeneral, LayerNorm
from .module import LayerNormDenseGeneral, LayerNormMLP
+from .module import wrap_function_in_te_state_module, make_dot_general_cls
from .transformer import extend_logical_axis_rules
from .transformer import DotProductAttention, MultiHeadAttention, RelativePositionBiases
from .transformer import TransformerLayer, TransformerLayerType

@@ -13,6 +14,8 @@ __all__ = [
    "LayerNorm",
    "LayerNormDenseGeneral",
    "LayerNormMLP",
+    "wrap_function_in_te_state_module",
+    "make_dot_general_cls",
    "extend_logical_axis_rules",
    "DotProductAttention",
    "MultiHeadAttention",
...
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""

@@ -7,6 +7,7 @@ Wrapper module for Transformer related layers with FP8 support.
from functools import reduce
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType, Optional
+import warnings

import numpy as np
import jax.numpy as jnp

@@ -23,8 +24,9 @@ from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense
from ..layernorm_mlp import layernorm_mlp
from ..activation import activation
-from ..softmax import softmax, SoftmaxType
+from ..softmax import softmax, SoftmaxFusionType
from ..sharding import with_sharding_constraint_by_logical_axes
+from ..attention import AttnSoftmaxType
from ..cpp_extensions import (
    is_softmax_kernel_available,
    jax_scaled_softmax,
@@ -171,15 +173,20 @@ class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
    ----------
    scale_factor : float, default = 1.0
        Scalar for the input to softmax.
-    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
+    softmax_fusion_type : SoftmaxFusionType, default = SoftmaxFusionType.SCALED
+        Indicate the type of softmax fusion.
+    softmax_type : AttnSoftmaxType, default = AttnSoftmaxType.VANILLA_SOFTMAX
        Indicate the type of softmax.
    """

    scale_factor: float = 1.0
-    softmax_type: SoftmaxType = SoftmaxType.SCALED
+    softmax_fusion_type: SoftmaxFusionType = SoftmaxFusionType.SCALED
+    softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX

    @nn.compact
-    def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
+    def __call__(
+        self, inputs: Array, mask: Array = None, bias: Array = None, softmax_offset: Array = None
+    ) -> jnp.ndarray:
        batch = inputs.shape[0]
        heads = inputs.shape[1]
        q_seqlen = inputs.shape[2]
@@ -187,33 +194,52 @@ class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
        input_dtype = inputs.dtype
        logits = inputs

+        if softmax_offset is not None:
+            assert self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX
+        if self.softmax_type == AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
+            softmax_offset = 0.0

        # use primitives
        if is_softmax_kernel_available(
-            self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype
+            self.softmax_fusion_type,
+            self.softmax_type,
+            batch,
+            heads,
+            q_seqlen,
+            k_seqlen,
+            input_dtype,
        ):
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

            mask_ = mask
-            if self.softmax_type is not SoftmaxType.SCALED_MASKED:
+            if self.softmax_fusion_type is not SoftmaxFusionType.SCALED_MASKED:
                mask_ = None
-            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
+            outputs = softmax(logits, mask_, self.scale_factor, self.softmax_fusion_type)
        # use default jax based implementation
        else:
+            warnings.warn(
+                "Using unfused JAX softmax implementation instead of TE fused primitives.",
+                UserWarning,
+                stacklevel=2,
+            )
            if bias is not None:
                logits = logits + bias.astype(input_dtype)

-            if self.softmax_type is SoftmaxType.SCALED:
-                outputs = jax_scaled_softmax(logits, self.scale_factor)
-            elif self.softmax_type is SoftmaxType.SCALED_MASKED:
-                outputs = jax_scaled_masked_softmax(logits, mask, self.scale_factor)
-            elif self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
-                outputs = jax_scaled_upper_triang_masked_softmax(logits, self.scale_factor)
+            if self.softmax_fusion_type is SoftmaxFusionType.SCALED:
+                outputs = jax_scaled_softmax(logits, self.scale_factor, softmax_offset)
+            elif self.softmax_fusion_type is SoftmaxFusionType.SCALED_MASKED:
+                outputs = jax_scaled_masked_softmax(logits, mask, self.scale_factor, softmax_offset)
+            elif self.softmax_fusion_type is SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED:
+                outputs = jax_scaled_upper_triang_masked_softmax(
+                    logits, self.scale_factor, softmax_offset
+                )
            else:
                raise ValueError(
-                    f"Unsupported softmax type: {self.softmax_type}. softmax_type must be [SCALED,"
-                    " SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
+                    f"Unsupported softmax fusion: {self.softmax_fusion_type}. softmax_fusion_type"
+                    " must be [SCALED, SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
                )

        assert input_dtype == outputs.dtype
        return outputs
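For orientation, a minimal usage sketch of the updated module follows (illustrative only, not part of this change; shapes and the scale factor are assumptions, and vanilla softmax creates no Flax parameters, so an empty variable dict suffices):

```python
import jax
import jax.numpy as jnp

# Attention logits: [batch, heads, q_seqlen, k_seqlen]
logits = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 16, 16), dtype=jnp.bfloat16)

probs = Softmax(
    scale_factor=0.125,  # e.g. 1/sqrt(head_dim) for head_dim = 64
    softmax_fusion_type=SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED,
    softmax_type=AttnSoftmaxType.VANILLA_SOFTMAX,
).apply({}, logits)  # fused kernel when available, unfused JAX fallback otherwise
```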
@@ -253,26 +279,26 @@ class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
    layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
        Indicate the type of layer normalization.
    zero_centered_gamma : bool, default = False
-        If set to `True`, the LayerNorm formula changes to
+        If set to ``True``, the LayerNorm formula changes to

        .. math::
-            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
+            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} \cdot
                (1 + \gamma) + \beta

-        This parameter is only applicable for 'layernorm'.
-        The default of `scale_init` will also be changed. See `scale_init`.
+        This parameter is only applicable for ``'layernorm'``.
+        The default of ``scale_init`` will also be changed. See ``scale_init``.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
-        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
-        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
-        Otherwise, scale_init is `flax.linen.initializers.ones`.
-        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+        If ``None`` is provided, scale_init is set according to the value of zero_centered_gamma.
+        If zero_centered_gamma is set to ``True``, then scale_init is ``flax.linen.initializers.zeros``.
+        Otherwise, scale_init is ``flax.linen.initializers.ones``.
+        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
    bias_init : Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`layernorm_type='layernorm'`.
-        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
    bias_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        only used when :attr:`layernorm_type='layernorm'`.
@@ -398,15 +424,15 @@ class DenseGeneral(TransformerEngineBase):
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
-        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
-        If set to False, the layer will not learn an additive bias.
+        If set to ``False``, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
-        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.

@@ -417,12 +443,12 @@ class DenseGeneral(TransformerEngineBase):
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
-        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
+        :math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input, like
-        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
+        ``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
        sharding constraint.

    Optimization parameters
@@ -571,48 +597,48 @@ class LayerNormDenseGeneral(TransformerEngineBase):
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
-        If set to `True`, the LayerNorm formula changes to
+        If set to ``True``, the LayerNorm formula changes to

        .. math::
-            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
+            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} \cdot
                (1 + \gamma) + \beta

-        This parameter is only applicable for 'layernorm'.
-        The default of `scale_init` will also be changed. See `scale_init`
+        This parameter is only applicable for ``'layernorm'``.
+        The default of ``scale_init`` will also be changed. See ``scale_init``
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
-        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
-        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
-        Otherwise, scale_init is `flax.linen.initializers.ones`.
-        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+        If ``None`` is provided, scale_init is set according to the value of zero_centered_gamma.
+        If zero_centered_gamma is set to ``True``, then scale_init is ``flax.linen.initializers.zeros``.
+        Otherwise, scale_init is ``flax.linen.initializers.ones``.
+        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
-        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing weights.
-        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
    kernel_axes : Tuple[str, ...], default = ()
        The name of axes used to shard the weights with a corresponding mesh.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
-        If set to False, the layer will not learn an additive bias.
+        If set to ``False``, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
-        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = False
        Indicate whether to return the output of layer normalization.
-        If set False, return None as the second tensor in outputs.
+        If set ``False``, return ``None`` as the second tensor in outputs.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each dense layer.
    low_rank_adaptation_dim: int, default = 32

@@ -620,16 +646,16 @@ class LayerNormDenseGeneral(TransformerEngineBase):
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
-        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
+        :math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
-        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
+        ``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
        sharding constraint.
    dot_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of dot, like
-        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
+        ``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
        sharding constraint.

    Optimization parameters
@@ -861,34 +887,34 @@ class LayerNormMLP(TransformerEngineBase):
    epsilon : float, default = 1e-6
        A value added to the denominator of layer normalization for numerical stability.
    zero_centered_gamma : bool, default = False
-        If set to `True`, the LayerNorm formula changes to
+        If set to ``True``, the LayerNorm formula changes to

        .. math::
-            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
+            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} \cdot
                (1 + \gamma) + \beta

-        This parameter is only applicable for 'layernorm'.
-        The default of `scale_init` will also be changed. See `scale_init`.
+        This parameter is only applicable for ``'layernorm'``.
+        The default of ``scale_init`` will also be changed. See ``scale_init``.
    scale_init : Initializer, default = None
        Used for initializing scale factors :math:`\gamma`.
-        If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
-        If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
-        Otherwise, scale_init is `flax.linen.initializers.ones`.
-        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+        If ``None`` is provided, scale_init is set according to the value of zero_centered_gamma.
+        If zero_centered_gamma is set to ``True``, then scale_init is ``flax.linen.initializers.zeros``.
+        Otherwise, scale_init is ``flax.linen.initializers.ones``.
+        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
    scale_axes : Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
        only used when :attr:`enable_layernorm=True`.
    ln_bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing shift factors :math:`\beta`,
        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
-        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
    ln_bias_axes: Tuple[str, ...], default = ('embed', )
        The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
    kernel_init : Initializer, default =
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        Used for initializing the weights of both dense layer transformations.
-        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
    kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
        The name of axes used to shard the weights with a corresponding mesh for
        the weight of the first dense layer transformation.
@@ -897,10 +923,10 @@ class LayerNormMLP(TransformerEngineBase):
        the weight of the second dense layer transformation.
    use_bias: bool, default = False
        Indicate whether to enable bias shifting.
-        If set to False, the layer will not learn an additive bias.
+        If set to ``False``, the layer will not learn an additive bias.
    bias_init: Initializer, default = flax.linen.initializers.zeros
        Used for initializing bias, only used when :attr:`use_bias=True`.
-        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+        It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
    bias_axes_1: Tuple[str, ...], default = ('mlp',)
        The name of axes used to shard bias with a corresponding mesh for
        the weight of the first dense layer transformation.
@@ -911,7 +937,7 @@ class LayerNormMLP(TransformerEngineBase):
        Only used when :attr:`use_bias=True`.
    return_layernorm_output: bool, default = False
        Indicate whether to return the output of layer normalization.
-        If set False, return None as the second tensor in outputs.
+        If set ``False``, return ``None`` as the second tensor in outputs.
    activations: Sequence[Union[str, Callable]], default = ('gelu',)
        The sequence of activation functions to apply after the first dense layer transformation.
        Each activation has its own transformation layer.
@@ -932,20 +958,20 @@ class LayerNormMLP(TransformerEngineBase):
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
-        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
+        :math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of layernorm, like
-        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
+        ``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
        sharding constraint.
    dot_1_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 1st dot, like
-        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
+        ``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
        sharding constraint.
    dot_2_input_axes: Tuple[str, ...], default = None
        Indicate the logical axes of sharding constraint to the input of 2nd dot, like
-        (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
+        ``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
        sharding constraint.
    ffn1_ckpt_name: str = "ffn1"
        Checkpoint name for the output of the first fully-connected layer in the MLP block.
@@ -1328,3 +1354,87 @@ class LayerNormMLP(TransformerEngineBase):
        assert out.dtype == input_dtype
        return out, ln_output  # Output, layer_norm_output

+
+def wrap_function_in_te_state_module(f, quantization_recipe, name: Optional[str] = None):
+    """Wraps the given function `f` to support TransformerEngine quantization.
+
+    This function does two things:
+
+    1. Wraps the given function in a Flax linen module. This module does not store any
+       Flax parameters but can store Flax variables for quantizers if required by the recipe.
+    2. When the wrapper is called, it passes 'generate_quantizer_set' to the given function
+       `f` as the first argument. 'generate_quantizer_set' is a function that can be called
+       to generate a TransformerEngine/JAX quantizer set object used in TransformerEngine/JAX
+       APIs. The quantizers are generated based on the given `quantization_recipe`.
+
+    Args:
+        f: The function to wrap. The first argument must be 'generate_quantizer_set'.
+        quantization_recipe: The quantization recipe used when generating quantizer sets.
+        name: The name of this wrapped operation. If unspecified, will use `f.__name__`.
+
+    Returns:
+        A Flax linen module that wraps the given function.
+    """
+    import transformer_engine.jax as te
+
+    class TEWrapper(te.flax.module.TransformerEngineBase):
+        """Wrapper Flax module for TransformerEngine quantization support."""
+
+        def generate_quantizer_set(self, postfix: str = ""):
+            OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient"
+            return super().generate_quantizer_set(
+                postfix=postfix,
+                variable_collection=OVERWRITE_WITH_GRADIENT,
+                fp8_recipe=quantization_recipe,
+            )
+
+        @nn.compact
+        def __call__(self, *args, **kwargs):
+            return f(self.generate_quantizer_set, *args, **kwargs)
+
+    TEWrapper.__name__ = f"TEWrapper_{name if name else f.__name__}"
+    return TEWrapper
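As a minimal, hypothetical usage sketch (the function `my_scaled_dense`, the `DelayedScaling` recipe choice, and the contracting dimensions are assumptions for illustration, not part of this commit):

```python
import transformer_engine.jax as te
from transformer_engine.common.recipe import DelayedScaling

def my_scaled_dense(generate_quantizer_set, x, kernel):
    # The wrapper supplies `generate_quantizer_set` as the first argument.
    quantizer_set = generate_quantizer_set()
    return te.dense.dense(x, kernel, contracting_dims=((1,), (0,)), quantizer_set=quantizer_set)

MyScaledDense = wrap_function_in_te_state_module(my_scaled_dense, DelayedScaling())
# MyScaledDense is a Flax module class; instantiate, init, and apply as usual.
# Quantizer state, if any, lives in the "_overwrite_with_gradient" collection.
```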
+
+def make_dot_general_cls(quantization_recipe):
+    """Creates a Flax module class that performs a dot_general operation on the arguments
+    x and kernel using the given quantization recipe.
+
+    This is intended for use when you already have model parameters initialized and sharded
+    for the kernel weights and you want to replace the GEMM implementation with TE's
+    quantized GEMM using a given recipe. For example,
+
+    ```
+    te_dot_general_cls = make_dot_general_cls(DelayedScaling())
+    dense = nn.Dense(..., dot_general=te_dot_general_cls())
+    ```
+
+    If you would like a drop-in replacement for nn.Dense that manages the model weights
+    itself, please use TE's DenseGeneral module.
+
+    Args:
+        quantization_recipe: The quantization recipe to use for the dot_general operation.
+
+    Returns:
+        A Flax module class that performs a dot_general operation with the given
+        quantization recipe.
+    """
+    import transformer_engine.jax as te
+    from transformer_engine.common.recipe import NVFP4BlockScaling
+
+    def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs):
+        """Performs a dot_general operation using TransformerEngine with quantization."""
+        del kwargs  # Unused
+        contracting_dims, batch_dims = dims
+        assert batch_dims == ((), ()), "Batch dimensions must be empty for TransformerEngine dot."
+        quantizer_set = generate_quantizer_set()
+        if isinstance(quantization_recipe, NVFP4BlockScaling):
+            # NVFP4 RHT requires inputs to be in bfloat16
+            x = x.astype(jnp.bfloat16)
+            kernel = kernel.astype(jnp.bfloat16)
+        return te.dense.dense(
+            x,
+            kernel,
+            contracting_dims=contracting_dims,
+            quantizer_set=quantizer_set,
+        )
+
+    return wrap_function_in_te_state_module(te_dot_general, quantization_recipe, "dot_general")
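A sketch of the drop-in flow from the docstring above, with assumed shapes and recipe. Depending on the recipe, applying the layer may need the quantizer-state collection (created by the wrapper above) marked mutable:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from transformer_engine.common.recipe import DelayedScaling

te_dot_general_cls = make_dot_general_cls(DelayedScaling())
dense = nn.Dense(features=128, dot_general=te_dot_general_cls())

x = jnp.ones((8, 64), dtype=jnp.bfloat16)
variables = dense.init(jax.random.PRNGKey(0), x)
y, state = dense.apply(variables, x, mutable=["_overwrite_with_gradient"])
```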
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
@@ -23,11 +23,17 @@ from jax.ad_checkpoint import checkpoint_name
from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
-from ..attention import AttnBiasType, AttnMaskType, QKVLayout, SequenceDescriptor
+from ..attention import (
+    AttnBiasType,
+    AttnMaskType,
+    AttnSoftmaxType,
+    QKVLayout,
+    SequenceDescriptor,
+)
from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type
from ..attention import fused_attn
from ..attention import CPStrategy
-from ..softmax import SoftmaxType
+from ..softmax import SoftmaxFusionType
from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
from ..sharding import with_sharding_constraint_by_logical_axes
@@ -115,11 +121,11 @@ class _UnfusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    attn_bias_type: Optional[AttnBiasType] = None
-    dtype: DType = jnp.float32
    float32_logits: bool = False
    scale_factor: Optional[float] = None
-    transpose_batch_sequence: bool = True
+    transpose_batch_sequence: bool = False
    window_size: Optional[Tuple[int, int]] = None
+    softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX

    @nn.compact
    def __call__(
@@ -145,6 +151,22 @@ class _UnfusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
        input_dtype = query.dtype

+        # Infer number of attention heads from query shape
+        # query shape: [..., h, d] where h is num_attention_heads
+        num_attention_heads = query.shape[-2]
+
+        # Initialize softmax_offset for learnable softmax
+        # Note: OFF_BY_ONE_SOFTMAX is handled internally by the Softmax module
+        softmax_offset = None
+        if self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
+            # For learnable softmax, create a learnable parameter with proper sharding
+            # and shape (1, h, 1, 1)
+            softmax_offset = self.param(
+                "softmax_offset",
+                nn.with_logical_partitioning(nn.initializers.zeros, (None, HEAD_AXES, None, None)),
+                (1, num_attention_heads, 1, 1),
+                jnp.float32,
+            )
+
        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(query.shape[-1])
        else:
@@ -213,8 +235,8 @@ class _UnfusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
            new_mask = jnp.where(original_mask == 0, swa_mask, original_mask)
            return new_mask
-        def convert_to_softmax_type(attn_mask_type, mask):
-            """Convert the attn_mask_type to SoftmaxType"""
+        def convert_to_softmax_fusion_type(attn_mask_type, mask):
+            """Convert the attn_mask_type to SoftmaxFusionType"""
            # mask is ignored for no_mask and causal_mask without sliding window
            if attn_mask_type == AttnMaskType.NO_MASK:
                mask = None

@@ -224,21 +246,23 @@ class _UnfusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
            mask = apply_swa_mask(mask)

            # Currently cuDNN backend only supports SWA for causal/padding_causal, follow this
            if mask is not None:
-                return SoftmaxType.SCALED_MASKED, mask
+                return SoftmaxFusionType.SCALED_MASKED, mask
            if attn_mask_type is AttnMaskType.CAUSAL_MASK:
-                return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask
+                return SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED, mask
            if attn_mask_type is AttnMaskType.NO_MASK:
-                return SoftmaxType.SCALED, mask
+                return SoftmaxFusionType.SCALED, mask
            raise ValueError(
                f"Unsupported {attn_mask_type=}, supported attn_mask_type="
                "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}"
            )

-        softmax_type, mask = convert_to_softmax_type(self.attn_mask_type, mask)
-        attn_weights = Softmax(softmax_type=softmax_type, scale_factor=fused_scale_factor)(
-            attn_weights, mask, bias
-        ).astype(input_dtype)
+        softmax_fusion_type, mask = convert_to_softmax_fusion_type(self.attn_mask_type, mask)
+        attn_weights = Softmax(
+            softmax_fusion_type=softmax_fusion_type,
+            softmax_type=self.softmax_type,
+            scale_factor=fused_scale_factor,
+        )(attn_weights, mask, bias, softmax_offset=softmax_offset).astype(input_dtype)

        if is_gqa:
            attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
@@ -269,7 +293,6 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    attn_bias_type: Optional[AttnBiasType] = None
-    dtype: DType = jnp.float32
    qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = False

@@ -279,6 +302,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    context_parallel_axis: str = ""
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT
    context_checkpoint_name: str = "context"
+    softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX

    @nn.compact
    def __call__(
@@ -303,6 +327,17 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
            scale_factor = self.scale_factor
        del self.scale_factor

+        num_attention_heads = query.shape[-2]
+        softmax_offset = None
+        if self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
+            # For learnable softmax, create a learnable parameter with proper sharding
+            # and shape (1, h, 1, 1)
+            softmax_offset = self.param(
+                "softmax_offset",
+                nn.with_logical_partitioning(nn.initializers.zeros, (None, HEAD_AXES, None, None)),
+                (1, num_attention_heads, 1, 1),
+                jnp.float32,
+            )
+
        if self.qkv_layout.is_qkvpacked():
            """qkvpacked format, treat
            query: qkvpacked tensor, shape = [..., 3, h, d]
@@ -320,6 +355,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
                attn_mask_type=self.attn_mask_type,
                attn_bias_type=self.attn_bias_type,
                qkv_layout=self.qkv_layout,
+                softmax_type=self.softmax_type,
                scaling_factor=scale_factor,
                dropout_probability=self.attention_dropout,
                is_training=not deterministic,

@@ -329,6 +365,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
                context_parallel_axis=self.context_parallel_axis,
                context_parallel_strategy=self.context_parallel_strategy,
                context_checkpoint_name=self.context_checkpoint_name,
+                softmax_offset=softmax_offset,
            )
        elif self.qkv_layout.is_kvpacked():
            """kvpacked format, treat
@@ -348,6 +385,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
                attn_mask_type=self.attn_mask_type,
                attn_bias_type=self.attn_bias_type,
                qkv_layout=self.qkv_layout,
+                softmax_type=self.softmax_type,
                scaling_factor=scale_factor,
                dropout_probability=self.attention_dropout,
                is_training=not deterministic,

@@ -357,6 +395,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
                context_parallel_axis=self.context_parallel_axis,
                context_parallel_strategy=self.context_parallel_strategy,
                context_checkpoint_name=self.context_checkpoint_name,
+                softmax_offset=softmax_offset,
            )
        elif self.qkv_layout.is_separate():
            if self.transpose_batch_sequence:
@@ -371,6 +410,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
                attn_mask_type=self.attn_mask_type,
                attn_bias_type=self.attn_bias_type,
                qkv_layout=self.qkv_layout,
+                softmax_type=self.softmax_type,
                scaling_factor=scale_factor,
                dropout_probability=self.attention_dropout,
                is_training=not deterministic,

@@ -380,6 +420,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
                context_parallel_axis=self.context_parallel_axis,
                context_parallel_strategy=self.context_parallel_strategy,
                context_checkpoint_name=self.context_checkpoint_name,
+                softmax_offset=softmax_offset,
            )
        else:
            raise ValueError(f"Unsupported {self.qkv_layout=}.")
@@ -426,7 +467,7 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
        The hidden dimension of each attention head.
    num_attention_heads: int
        The number of attention heads.
-    num_gqa_groups: int, default = `None`
+    num_gqa_groups: int, default = None
        Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
        Grouped Query Attention is described in
        `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
@@ -439,32 +480,45 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    attn_mask_type: str, default = 'causal'
        This parameter specifies the type of attention mask to be applied during the softmax
        operation.
-        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}
+        Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}.
        Each described below:

-        * no_mask: No attention mask is applied. This means the attention will consider the
+        * ``no_mask``: No attention mask is applied. This means the attention will consider the
          full sequence without any restrictions.
-        * padding: Indicates the presence of padding at the end of each sequence.
-          Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
+        * ``padding``: Indicates the presence of padding at the end of each sequence.
+          Users must provide a mask with the shape ``[batch, 1, max_seqlen_q, max_seqlen_kv]`` in the
          :attr:`__call__` method to specify the padding positions.
-        * causal: An upper triangular mask is applied to the softmax inputs,
+        * ``causal``: An upper triangular mask is applied to the softmax inputs,
          ensuring that the prediction for a certain position is only dependent on known outputs
          from positions before it.
-        * causal_padding / padding_causal: A combination of both causal and padding masks.
-          Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.
+        * ``causal_padding`` / ``padding_causal``: A combination of both causal and padding masks.
+          Both ``'causal_padding'`` and ``'padding_causal'`` are acceptable and have the same effect.

-        .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.
-        .. note:: THD format only supports 'padding' or 'causal_padding' mask type.
+        |
+        .. note:: :attr:`mask` in :attr:`__call__` is ignored for ``'no_mask'`` and ``'causal'``.
+        |
+        .. note:: THD format only supports ``'padding'`` or ``'causal_padding'`` mask type.

-           attn_mask_type     mask/sequence_descriptor     SWA        softmax type
-           --------------------------------------------------------------------------
+        .. table::
+           :widths: auto
+
+           ================== ============ ========== ==============================
+           attn_mask_type     mask/sd      SWA        softmax type
+           ================== ============ ========== ==============================
           no_mask            None         None       SCALED
           causal             None         None       SCALED_UPPER_TRIANG_MASKED
           causal             None         Yes        SCALED_MASKED
           padding            Required     Yes/No     SCALED_MASKED
           padding_causal     Required     Yes/No     SCALED_MASKED
+           ================== ============ ========== ==============================
+
+           where sd stands for sequence_descriptor.

    attn_bias_type: Optional[str], default = None
        Type of the attention bias passed in the attention.

@@ -501,24 +555,54 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
        Scale factor to apply on query. When :attr:`None` is present, the scale factor is equal
        to :math:`\frac{1}{\sqrt{head\_dim}}`. This is useful for models like T5X, which don't
        need to apply scale on query, which is to set :attr:`scale_factor=1.`.
-    transpose_batch_sequence: bool, default = True
+    TODO(KshitijLakhani): Reset this to bool only with default False arg in TransformerEngine v2.12
+    transpose_batch_sequence: bool | None, default = None (however, default is forced to False in __post_init__)
        Indicate whether the input tensors were switched axis of batch
-        and sequence length dimension. if set to True, the input tensors
+        and sequence length dimension. If set to True, the input tensors
        should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
    window_size: Optional[Tuple[int, int]], default = None
        Sliding window size. The default value is no sliding window.
    max_segments_per_seq: Optional[int], default = 1
        The maximum number of segments per sequence, also used for THD format (sequence packing).
-    context_parallel_causal_load_balanced (bool):
+    context_parallel_causal_load_balanced: bool
        Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
-    context_parallel_axis (str): The name of the context parallel axis.
-    context_parallel_strategy (CPStrategy): The strategy of context parallel. 0: DEFAULT, 1: ALL_GATHER, 2: RING.
-    context_checkpoint_name (str): The name of the context checkpoint in the forward pass of fused attention.
+    context_parallel_axis: str
+        The name of the context parallel axis.
+    context_parallel_strategy: CPStrategy
+        The strategy of context parallel. 0: DEFAULT, 1: ALL_GATHER, 2: RING.
+    context_checkpoint_name: str
+        The name of the context checkpoint in the forward pass of fused attention.
+    softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
+        Softmax type as described in the paper
+        `Efficient Streaming Language Models with Attention Sinks
+        <https://arxiv.org/pdf/2309.17453v3>`_.
+        For a given attention score :math:`S = Q \cdot K^T` of shape ``[b, h, s_q, s_kv]``:
+
+        * ``'vanilla'``:
+
+          .. math::
+              Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}
+
+        * ``'off-by-one'``:
+
+          .. math::
+              Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}
+
+        * ``'learnable'``:
+
+          .. math::
+              Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}
+
+          where :math:`\alpha` is a learnable parameter of shape ``[h]``.
+
+        ``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
+        (``'zero sink'`` and ``'learnable sink'``).

    Optimization parameters
    -----------------------
-    dtype: jax.numpy.dtype, default = jax.numpy.float32
-        The data type used to allocate the initial parameters.
+    dtype (deprecated): jax.numpy.dtype, default = None
+        This dtype is deprecated and will be removed in a future release. DPA will use the
+        dtype of the inputs instead, as this module does not have any parameters.
    """
    head_dim: int

@@ -527,18 +611,48 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = "causal"
    attn_bias_type: AttnBiasType = None
-    dtype: DType = jnp.float32
+    dtype: Optional[DType] = None  # Deprecated
    dropout_rng_name: str = "dropout"
    float32_logits: bool = False
    qkv_layout: str = "bshd_bshd_bshd"
    scale_factor: Optional[float] = None
-    transpose_batch_sequence: bool = True
+    transpose_batch_sequence: bool | None = None
    window_size: Optional[Tuple[int, int]] = None
    max_segments_per_seq: Optional[int] = 1
    context_parallel_causal_load_balanced: bool = False
    context_parallel_axis: str = ""
    context_parallel_strategy: str = "DEFAULT"
    context_checkpoint_name: str = "context"
+    softmax_type: str = "vanilla"
def __post_init__(self):
# TODO(KshitijLakhani): Remove warning in TransformerEngine v2.12
# None implies that the user is relying on defaults, hence warn the user and set the new defaults
if self.transpose_batch_sequence is None:
warnings.warn(
"transpose_batch_sequence defaults to False in DotProductAttention starting"
" TransformerEngine v2.10"
)
self.transpose_batch_sequence = False
super().__post_init__()
def _assert_dtypes(self, query: Array, key: Array, value: Array, qkv_layout: QKVLayout):
"""Asserts that the dtypes of query, key, and value dtypes are consistent."""
if qkv_layout.is_qkvpacked():
pass # No need to check dtypes for key and value since it is packed
elif qkv_layout.is_kvpacked():
assert (
key.dtype == query.dtype
), f"Expected kv {key.dtype=} to match query {query.dtype=}."
elif qkv_layout.is_separate():
assert (
key.dtype == query.dtype
), f"Expected key {key.dtype=} to match query {query.dtype=}."
assert (
value.dtype == query.dtype
), f"Expected value {value.dtype=} to match query {query.dtype=}."
else:
raise ValueError(f"Unsupported {qkv_layout=}.")
@nn.compact
def __call__(
...@@ -564,7 +678,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
mask: jax.numpy.ndarray, default = None
Boolean tensor used to mask out the attention softmax input.
:attr:`True` means to mask out the corresponding values.
Ignored when :attr:`self.attn_mask_type` is either ``'no_mask'`` or ``'causal'``.
bias: jax.numpy.ndarray, default = None
A tensor used to shift attention softmax input.
*:
...@@ -595,12 +709,32 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_bias_type = AttnBiasType[self.attn_bias_type.upper()]
attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
qkv_layout = QKVLayout[self.qkv_layout.upper()]
softmax_type = AttnSoftmaxType.from_str(self.softmax_type)
del self.attn_bias_type, self.attn_mask_type, self.qkv_layout
if attn_bias_type == AttnBiasType.NO_BIAS:
assert bias is None
else:
assert bias is not None
bias = bias.astype(input_dtype)
self._assert_dtypes(query, key, value, qkv_layout)
if self.dtype is not None:
if self.dtype == input_dtype:
warnings.warn(
"The dtype argument is deprecated and will be removed in a future release."
" DotProductAttention will use the dtype of the inputs instead as this module"
f" does not have any parameters. Module dtype specified {self.dtype=} matches"
" dtype of inputs so behavior is unchanged. Please remove the dtype argument"
" within the next few releases."
)
else:
raise ValueError(
"The DotProductAttention module dtype is deprecated and will be removed in a"
" future release. DotProductAttention will use the dtype of the inputs instead"
" as this module does not have any parameters. Module dtype specified"
f" {self.dtype=} does not match dtype of inputs {input_dtype=}."
)
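The net effect of the checks above, as an illustrative sketch (field names as in this class):

attn = DotProductAttention(head_dim=64, num_attention_heads=8, num_gqa_groups=8,
                           dtype=jnp.bfloat16)  # deprecated argument
# inputs in bf16 -> warning only, behavior unchanged
# inputs in fp16 -> ValueError: module dtype does not match input dtype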
# Use fused attn (if kernel check below passes) by default
enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "1"))
...@@ -621,11 +755,13 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
has_fused_attn_kernel = is_fused_attn_kernel_available(
# This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode.
not deterministic,
input_dtype,
# self._assert_dtypes enforces Q, K, V, bias to have the same dtype so using input_dtype as kv dtype is sufficient
input_dtype,
qkv_layout,
attn_bias_type,
attn_mask_type,
softmax_type,
self.attention_dropout,
self.num_attention_heads,
self.num_gqa_groups,
...@@ -643,7 +779,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
"Fused attention is not enabled because there is no available kernel.\n"
"Fall back to the unfused attention.\n"
"Please try to update the cuDNN and TE to the latest version.\n"
f"{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n"
f"{self.attention_dropout=}\n{self.num_attention_heads=}\n"
f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{head_dim_qk=}\n{head_dim_v=}\n"
)
...@@ -697,11 +833,11 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
attention_dropout=self.attention_dropout,
attn_mask_type=attn_mask_type,
attn_bias_type=attn_bias_type,
float32_logits=self.float32_logits,
scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence,
window_size=self.window_size,
softmax_type=softmax_type,
)(
query,
key,
...@@ -716,7 +852,6 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
attention_dropout=self.attention_dropout,
attn_mask_type=attn_mask_type,
attn_bias_type=attn_bias_type,
scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence,
qkv_layout=qkv_layout,
...@@ -726,6 +861,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
context_parallel_axis=self.context_parallel_axis,
context_parallel_strategy=context_parallel_strategy,
context_checkpoint_name=self.context_checkpoint_name,
softmax_type=softmax_type,
)(
query,
key,
...@@ -747,7 +883,7 @@ def rotary_pos_emb(
):
"""
Rotary Positional Embedding
x should be of shape
[Batch, Seqlen, ..., Heads, Hidden] if transpose_batch_sequence is False, or
[Seqlen, Batch, ..., Heads, Hidden] if transpose_batch_sequence is True.
"""
...@@ -885,7 +1021,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
The hidden dimension of each attention head.
num_attention_heads: int
The number of attention heads.
num_gqa_groups: int, default = None
Number of GQA groups. When `None`, it is equal to num_attention_heads.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
...@@ -898,28 +1034,28 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_mask_type: str, default = 'causal'
This parameter specifies the type of attention mask to be applied during the softmax
operation.
Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}.
Each described below:
* ``no_mask``: No attention mask is applied. This means the attention will consider the
full sequence without any restrictions.
* ``padding``: Indicates the presence of padding at the end of each sequence.
Users must provide a mask with the shape ``[batch, 1, max_seqlen_q, max_seqlen_kv]`` in the
:attr:`__call__` method to specify the padding positions.
* ``causal``: An upper triangular mask is applied to the softmax inputs,
ensuring that the prediction for a certain position is only dependent on known outputs
from positions before it.
* ``causal_padding`` / ``padding_causal``: A combination of both causal and padding masks.
Both ``'causal_padding'`` and ``'padding_causal'`` are acceptable and have the same effect.
.. note:: :attr:`mask` in :attr:`__call__` is ignored for ``'no_mask'`` and ``'causal'``.
attn_bias_type: Optional[str], default = None
Type of the attention bias passed in the attention.
Available options: ``{'no_bias', 'pre_scale_bias', 'post_scale_bias'}``.
When left as the default, the type is automatically decided by the MHA's bias parameter:
``'post_scale_bias'`` if a bias is given, otherwise ``'no_bias'``.
dropout_rng_name: str, default = 'dropout'
The key in given RNGs via flax.linen.Module.apply that is used
to generate Dropout masks in the core attention.
...@@ -928,27 +1064,27 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
layernorm_epsilon: float, default = 1e-6
A value added to the denominator of layer normalization for numerical stability.
zero_centered_gamma: bool, default = False
If set to ``True``, the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} \cdot
(1 + \gamma) + \beta
This parameter is only applicable for ``'layernorm'``.
kernel_init: Initializer, default =
``flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')``
Used for initializing the QKV and output projection weights.
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
use_bias: bool, default = False
Indicate whether or not to enable bias shifting for QKV and output projections.
If set to ``False``, the layer will not learn additive biases.
bias_init: Initializer, default = ``flax.linen.initializers.zeros``
Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`.
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
input_layernorm: bool, default = True
If set to ``False``, layer normalization is not applied to the input.
return_layernorm_output: bool, default = False
If set to ``True``, output of layernorm is returned from the forward together with the output
of the linear transformation.
Example use case: residual connection for transformer module is taken post layernorm.
enable_rotary_pos_emb: bool, default = False
...@@ -958,17 +1094,17 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
only used when :attr:`enable_rotary_pos_emb=True`
rotary_pos_emb_group_method: str, default = 'consecutive'
Indicate the method to couple the coordinates. It should be one of
``['consecutive', 'alternate']``. ``'alternate'`` is to pair index :math:`i` with :math:`i + d/2`,
where :math:`d` is the hidden dimension. ``'consecutive'`` pairs index :math:`i` with :math:`i + 1`.
low_rank_adaptation_scope: str, default = 'none'
Indicate the scope to apply low rank adaptation. It should be one of
``['none', 'all', 'qkv_proj', 'output_proj', 'exclude_qkv_proj', 'exclude_output_proj']``
low_rank_adaptation_dim: int, default = 32
The dimension for low rank adaptation, only used when
:attr:`enable_low_rank_adaptation=True`
low_rank_adaptation_alpha: float, default = None
The alpha for computing the scaling factor of LoRA output.
:math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
enable_sequence_parallel: bool, default = False
Whether to enable sequence parallelism for operations other than dot.
num_heads: int, default = None
...@@ -988,14 +1124,15 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
If set to True, this module exposes a single fused
parameter for query-key-value for self-attention and key-value for
cross-attention.
TODO(KshitijLakhani): Reset this to bool only with default False arg in TransformerEngine v2.12
transpose_batch_sequence: bool | None, default = None (however, default is forced to False in post_init)
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
scale_attn_logits: bool, default = False
Indicate whether to scale attention logits.
If set to True, :math:`\frac{Q \cdot K^T}{\sqrt{head\_dim}}`,
else :math:`Q \cdot K^T`
scaled_query_init: bool, default = True
Whether to scale WQ on initialization by :math:`\frac{1}{\sqrt{head\_dim}}`
float32_logits: bool, default = False
...@@ -1005,6 +1142,32 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
Deprecated. Please refer to `fuse_qkv_params`
window_size: Optional[Tuple[int, int]], default = None
Sliding window size. Default value is no sliding window.
softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
Softmax type as described in the paper
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``:
* ``'vanilla'``:
.. math::
Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}
* ``'off-by-one'``:
.. math::
Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}
* ``'learnable'``:
.. math::
Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}
where :math:`\alpha` is a learnable parameter of shape ``[h]``.
``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
(``'zero sink'`` and ``'learnable sink'``).
""" """
head_dim: int head_dim: int
...@@ -1030,12 +1193,13 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1030,12 +1193,13 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_alpha: float = None low_rank_adaptation_alpha: float = None
dtype: DType = jnp.float32 dtype: DType = jnp.float32
fuse_qkv_params: bool = True fuse_qkv_params: bool = True
transpose_batch_sequence: bool = True transpose_batch_sequence: bool | None = None
enable_sequence_parallel: bool = False enable_sequence_parallel: bool = False
scale_attn_logits: bool = False scale_attn_logits: bool = False
scaled_query_init: bool = True scaled_query_init: bool = True
float32_logits: bool = False float32_logits: bool = False
window_size: Optional[Tuple[int, int]] = None window_size: Optional[Tuple[int, int]] = None
softmax_type: str = "vanilla"
# Deprecated parameters
num_heads: Optional[int] = None
...@@ -1045,6 +1209,15 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
fuse_qkv: Optional[bool] = None
def __post_init__(self):
# Deal with changed defaults in API
# TODO(KshitijLakhani): Remove warning in TransformerEngine v2.12
# None implies that the user is relying on defaults, hence warn the user and set the new defaults
if self.transpose_batch_sequence is None:
warnings.warn(
"transpose_batch_sequence defaults to False in MultiHeadAttention starting"
" TransformerEngine v2.10"
)
self.transpose_batch_sequence = False
# Deal with the deprecated parameters
if self.num_heads is not None:
self.num_attention_heads = self.num_heads
...@@ -1109,7 +1282,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
mask: jax.numpy.ndarray, default = None
Boolean tensor used to mask out the attention softmax input.
:attr:`True` means mask out the corresponding values.
Ignored when :attr:`self.attn_mask_type` is either ``'no_mask'`` or ``'causal'``.
bias: jax.numpy.ndarray, default = None
A tensor used to shift the attention softmax input.
*
...@@ -1433,13 +1606,13 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_mask_type=self.attn_mask_type,
attn_bias_type=self.attn_bias_type,
attention_dropout=self.attention_dropout,
dropout_rng_name=self.dropout_rng_name,
float32_logits=self.float32_logits,
qkv_layout=qkv_layout.name,
scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence,
window_size=self.window_size,
softmax_type=self.softmax_type,
)(*dpa_args, mask, bias, deterministic=deterministic)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
...@@ -1594,7 +1767,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
Intermediate size to which input samples are projected.
num_attention_heads: int, default = 8
Number of attention heads in the transformer layer.
num_gqa_groups: int, default = None
Number of GQA groups. When `None`, it is equal to num_attention_heads.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
...@@ -1628,31 +1801,31 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
The key in given RNGs via flax.linen.Module.apply that is used for
generating Dropout masks in the Multi-Head Attention.
mha_kernel_init: Initializer, default =
``flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')``
Used for initializing the QKV and output projection weights.
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
mlp_kernel_init: Initializer, default =
``flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')``
Used for initializing weights of FC1 and FC2 layers.
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
mlp_activations: Sequence[str], default = ('gelu', )
The sequence of activation functions to apply after the first linear transformation.
Each activation has its own transformation layer.
mlp_activation_params: dict = None
This is only used when ``('clamped_silu', 'clamped_linear')`` is in :attr:`mlp_activations`. At the moment
``ClampedSwiglu`` is the only activation that requires parameters.
use_bias: bool, default = False
Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
If set to ``False``, the layer will not learn additive biases.
bias_init: Initializer, default = ``flax.linen.initializers.zeros``
Used for initializing bias of QKVO projections,
FC1 and FC2. It is only used when :attr:`use_bias=True`.
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
apply_residual_connection_post_layernorm: bool, default = False
If set to ``True``, residual connections are taken from the output
of layer norm (default is taken from input of layer norm)
output_layernorm: bool, default = False
If set to ``True``, layer normalization is applied on the output side,
after the final dropout-add. Default behavior is to apply layer
normalization on the input side, before the QKV transformation.
float32_attention_logits: bool, default = False
...@@ -1660,43 +1833,43 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
For fused attention backend, the accumulation is always float32 without the perf overhead.
layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER
If set to TransformerLayerType.DECODER, an additional cross-attention block
is added after self-attention. This can be used for structures like T5
Transformer in conjunction with the TransformerLayerType.ENCODER option.
self_attn_mask_type: str, default = 'causal'
This parameter specifies the type of attention mask to be applied during the softmax
operation in the self attention.
Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}.
Each described below:
* ``no_mask``: No attention mask is applied. This means the self attention will consider the
full sequence without any restrictions.
* ``padding``: Indicates the presence of padding at the end of each sequence.
Users must provide a mask with the shape ``[batch, 1, max_seqlen_q, max_seqlen_kv]`` in the
:attr:`__call__` method to specify the padding positions.
* ``causal``: An upper triangular mask is applied to the softmax inputs,
ensuring that the prediction for a certain position is only dependent on known outputs
from positions before it.
* ``causal_padding`` / ``padding_causal``: A combination of both causal and padding masks.
Both ``'causal_padding'`` and ``'padding_causal'`` are acceptable and have the same effect.
.. note:: :attr:`attention_mask` in :attr:`__call__` is ignored for ``'no_mask'`` and ``'causal'``.
self_attn_bias_type: Optional[str], default = None
Type of the attention bias passed into the self attention.
Available options: ``{'no_bias', 'pre_scale_bias', 'post_scale_bias'}``.
When left as the default, the type is automatically decided by the MHA's bias parameter:
``'post_scale_bias'`` if a bias is given, otherwise ``'no_bias'``.
enable_relative_embedding: bool, default = True
Whether to enable relative embedding as shifting of attention logits.
relative_embedding: flax.linen.Module, default = None
The module for relative embedding execution, only used when
:attr:`enable_relative_embedding=True`. Default is ``None``, which will create
an instance of RelativePositionBiases if :attr:`enable_relative_embedding=True`.
Default: ``RelativePositionBiases( num_buckets=32, max_distance=128,
num_attention_heads=self.num_attention_heads, dtype=self.dtype,
embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'),
name='relpos_bias')``
enable_rotary_pos_emb: bool, default = False
Whether to enable rotary position embedding to projected query and key in MHA.
rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
...@@ -1704,23 +1877,50 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
only used when :attr:`enable_rotary_pos_emb=True`
rotary_pos_emb_group_method: str, default = 'consecutive'
Indicate the method to couple the coordinates. It should be one of
``['consecutive', 'alternate']``. ``'alternate'`` is to pair index :math:`i` with :math:`i + d/2`,
where :math:`d` is the hidden dimension. ``'consecutive'`` pairs index :math:`i` with
:math:`i + 1`.
low_rank_adaptation_scope: str, default = 'none'
Indicate the scope to apply low rank adaptation. It should be one of
``['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
'exclude_output_proj', 'exclude_mlp']``
low_rank_adaptation_dim: int, default = 32
The dimension for low rank adaptation, only used when
:attr:`enable_low_rank_adaptation=True`
low_rank_adaptation_alpha: float, default = None
The alpha for computing the scaling factor of LoRA output.
:math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
enable_sequence_parallel: bool, default = False
Whether to enable sequence parallelism for operations other than dot.
window_size: Optional[Tuple[int, int]], default = None
Sliding window size. Default value is no sliding window.
softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
Softmax type as described in the paper
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``:
* ``'vanilla'``:
.. math::
Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}
* ``'off-by-one'``:
.. math::
Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}
* ``'learnable'``:
.. math::
Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}
where :math:`\alpha` is a learnable parameter of shape ``[h]``.
``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
(``'zero sink'`` and ``'learnable sink'``).
Only supported for fused attention backend.
Optimization parameters
-----------------------
...@@ -1730,19 +1930,19 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
When > 0.0, applies stochastic depth per sample in the main
path of the residual block.
fuse_qkv_params: bool, default = True
If set to ``True``, ``TransformerLayer`` module exposes a single fused
parameter for query-key-value for self-attention and key-value for
cross-attention.
transpose_batch_sequence: bool, default = False
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to ``True``, the input tensors
should be in ``(seqlen, batch, hidden)``, otherwise ``(batch, seqlen, hidden)``.
scale_attn_logits: bool, default = False
Indicate whether to scale attention logits.
If set to ``True``, :math:`\frac{Q \cdot K^T}{\sqrt{head\_dim}}`,
else :math:`Q \cdot K^T`
scaled_query_init: bool, default = True
Whether to scale WQ on initialization by :math:`\sqrt{head\_dim}`
"""
hidden_size: int = 512
...@@ -1786,6 +1986,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
scale_attn_logits: bool = False
scaled_query_init: bool = True
window_size: Optional[Tuple[int, int]] = None
softmax_type: str = "vanilla"
def __post_init__(self):
if self.mha_kernel_init is None:
...@@ -1824,7 +2025,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
attention_mask : jax.numpy.ndarray, default = None
Boolean tensor used to mask out self-attention softmax input.
:attr:`True` means mask out the corresponding values.
Ignored when :attr:`self.self_attn_mask_type` is either ``'no_mask'`` or ``'causal'``.
encoder_decoder_mask: jax.numpy.ndarray, default = None
Boolean tensor used to mask out cross-attention softmax input when
:attr:`layer_type=TransformerLayerType.DECODER`.
...@@ -1946,6 +2147,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
bias_init=self.bias_init,
name=mha_name,
window_size=self.window_size,
softmax_type=self.softmax_type,
)(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode)
def hidden_dropout(x, deterministic):
...@@ -2024,6 +2226,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
bias_init=self.bias_init,
name="encoder_decoder_attention",
window_size=self.window_size,
softmax_type=self.softmax_type,
)(x, encoded, encoder_decoder_mask, deterministic=deterministic)
y = with_sharding_constraint_by_logical_axes(
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Layer normalization operations for Transformer Engine in JAX.
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused Layer normalization and dense layer transformation operations for Transformer Engine in JAX.
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Multi-layer perceptron (MLP) operations with layer normalization for Transformer Engine in JAX.
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""MoE Permutation API for JAX.
This module provides high-level token dispatch and combine operations for
Mixture of Experts (MoE) models with proper automatic differentiation support.
Token Dispatch (Permute):
- Forward: Permute tokens according to routing map (scatter to experts)
- Backward: Unpermute gradients (gather from experts)
Token Combine (Unpermute):
- Forward: Unpermute tokens and merge with weights (gather from experts)
- Backward: Permute gradients (scatter to experts)
"""
from functools import partial
from typing import Optional, Tuple
import jax
import jax.numpy as jnp
from transformer_engine.jax.triton_extensions.permutation import (
make_row_id_map,
permute_with_mask_map,
permute_with_mask_map_and_pad,
unpermute_with_mask_map,
unpermute_with_mask_map_and_unpad,
unpermute_bwd_with_merging_probs,
unpermute_bwd_with_merging_probs_and_unpad,
make_chunk_sort_map,
sort_chunks_by_map,
)
__all__ = [
"token_dispatch",
"token_combine",
"sort_chunks_by_index",
]
def token_dispatch(
inp: jnp.ndarray,
routing_map: jnp.ndarray,
num_out_tokens: int,
probs: Optional[jnp.ndarray] = None,
align_size: Optional[int] = None,
) -> Tuple[
jnp.ndarray,
Optional[jnp.ndarray],
jnp.ndarray,
Optional[jnp.ndarray],
Optional[jnp.ndarray],
]:
"""
Dispatch tokens to experts based on routing map.
This is the forward pass of the MoE permutation. Tokens are scattered
to their designated experts according to the routing map. The row_id_map
is computed internally from the routing_map.
Optionally supports fused padding for alignment when `align_size` is provided.
This is useful for efficient matrix multiplications that require aligned tensor
dimensions. The padding is computed internally from the routing_map.
Parameters
----------
inp : jnp.ndarray
Input tensor of shape [batch, sequence, hidden_size] or [num_tokens, hidden_size].
routing_map : jnp.ndarray
Routing mask of shape [batch, sequence, num_experts] or [num_tokens, num_experts].
Values: 1 = routed, 0 = not routed.
num_out_tokens : int
The number of output tokens after permutation (before padding). For the dropless
case, this should be equal to the sum of routing_map. Must be provided explicitly
for JIT compatibility since output shape must be known at compile time.
probs : Optional[jnp.ndarray]
Optional routing probabilities of shape [batch, sequence, num_experts] or
[num_tokens, num_experts]. If provided, permuted_probs will be returned.
align_size : Optional[int]
Optional alignment size for padding. If provided, outputs will be padded to
align each expert's tokens to a multiple of this size. The output buffer is
allocated with worst-case size, rounded down to align_size:
((num_out_tokens + num_experts * (align_size - 1)) // align_size) * align_size
This enables full JIT compatibility.
Returns
-------
output : jnp.ndarray
Permuted output tensor of shape [num_out_tokens, hidden_size] without padding,
or [worst_case_padded_size, hidden_size] when using padding fusion.
With padding, the actual used portion may be smaller than the buffer; check
actual_num_out_tokens (sum of target_tokens_per_expert) for the actual size.
permuted_probs : Optional[jnp.ndarray]
Permuted probabilities of shape [num_out_tokens] or [worst_case_padded_size],
or None if probs was not provided.
row_id_map : jnp.ndarray
Row ID map for use in token_combine (shape [num_tokens, num_experts * 2 + 1]).
pad_offsets : Optional[jnp.ndarray]
Per-expert cumulative padding offsets of shape [num_experts] when using padding,
None otherwise. Pass this to token_combine when unpadding is needed.
target_tokens_per_expert : Optional[jnp.ndarray]
Aligned token counts per expert of shape [num_experts] when using padding,
None otherwise.
Note
----
**JIT Compatibility:**
This function is fully JIT-compatible. When using padding (align_size provided),
the output buffer is allocated with a fixed worst-case size that depends only on
compile-time constants (num_out_tokens, num_experts, align_size). The actual
padding offsets (pad_offsets) and aligned token counts (target_tokens_per_expert)
are computed internally from the routing_map and can be traced values.
The worst-case output size is:
((num_out_tokens + num_experts * (align_size - 1)) // align_size) * align_size
This accounts for the maximum possible padding when each expert needs (align_size - 1)
extra tokens to align, rounded down to a multiple of align_size for buffer alignment.
"""
use_padding = align_size is not None
num_experts = routing_map.shape[-1]
if use_padding:
# Compute worst-case output size (compile-time constant)
# This is the maximum possible size when each expert needs max padding
worst_case_out_tokens = (
(num_out_tokens + num_experts * (align_size - 1)) // align_size
) * align_size
else:
worst_case_out_tokens = num_out_tokens
return _token_dispatch(
inp, routing_map, probs, num_out_tokens, worst_case_out_tokens, align_size, use_padding
)
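A small worked example of the call and the worst-case arithmetic (values are illustrative; the underlying kernels are Triton-based, so this documents shapes rather than serving as a host-side test):

import jax.numpy as jnp

# 4 tokens, 2 experts; routing_map[i, j] == 1 means token i goes to expert j.
routing_map = jnp.array([[1, 0],
                         [1, 1],
                         [0, 1],
                         [1, 0]], dtype=jnp.int32)
inp = jnp.ones((4, 8), dtype=jnp.bfloat16)  # [num_tokens, hidden_size]
num_out_tokens = 5  # == int(routing_map.sum()); must be static for JIT

out, _, row_id_map, pad_offsets, target = token_dispatch(
    inp, routing_map, num_out_tokens, align_size=4
)
# Worst-case buffer: ((5 + 2 * (4 - 1)) // 4) * 4 = 8 rows in `out`.
# tokens_per_expert = [3, 2] -> target = [4, 4], so all 8 rows are used here.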
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6))
def _token_dispatch(
inp: jnp.ndarray,
routing_map: jnp.ndarray,
probs: Optional[jnp.ndarray],
num_out_tokens: int,
worst_case_out_tokens: int,
align_size: Optional[int],
use_padding: bool,
) -> Tuple[
jnp.ndarray,
Optional[jnp.ndarray],
jnp.ndarray,
Optional[jnp.ndarray],
Optional[jnp.ndarray],
]:
"""Internal token_dispatch with custom VJP."""
(output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert), _ = (
_token_dispatch_fwd_rule(
inp,
routing_map,
probs,
num_out_tokens,
worst_case_out_tokens,
align_size,
use_padding,
)
)
return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert
def _token_dispatch_fwd_rule(
inp: jnp.ndarray,
routing_map: jnp.ndarray,
probs: Optional[jnp.ndarray],
num_out_tokens: int,
worst_case_out_tokens: int,
align_size: Optional[int],
use_padding: bool,
) -> Tuple[
Tuple[
jnp.ndarray,
Optional[jnp.ndarray],
jnp.ndarray,
Optional[jnp.ndarray],
Optional[jnp.ndarray],
],
Tuple[jnp.ndarray, Optional[jnp.ndarray], int, int, int, bool],
]:
"""Forward pass rule for token_dispatch."""
# Validate input dimensions
assert inp.ndim in [2, 3], f"inp must be 2D or 3D, got {inp.ndim}D"
assert routing_map.ndim in [2, 3], f"routing_map must be 2D or 3D, got {routing_map.ndim}D"
# Infer dimensions from input shapes
num_tokens = inp.shape[0] * inp.shape[1] if inp.ndim == 3 else inp.shape[0]
hidden_size = inp.shape[-1]
num_experts = routing_map.shape[-1]
# Verify consistency between inp and routing_map
routing_num_tokens = (
routing_map.shape[0] * routing_map.shape[1]
if routing_map.ndim == 3
else routing_map.shape[0]
)
assert num_tokens == routing_num_tokens, (
f"Token count mismatch: inp has {num_tokens} tokens, "
f"routing_map has {routing_num_tokens} tokens"
)
# Always compute row_id_map internally from routing_map
row_id_map = make_row_id_map(routing_map, num_tokens, num_experts)
with_probs = probs is not None
if use_padding:
# Compute tokens_per_expert internally from routing_map
# This can be a traced value since output shape uses worst_case_out_tokens
tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32)
# Calculate aligned token counts per expert
target_tokens_per_expert = (jnp.ceil(tokens_per_expert / align_size) * align_size).astype(
jnp.int32
)
# Compute pad_offsets: cumulative padding for each expert
# pad_offsets[i] = sum of (target - actual) for experts 0..i-1
pad_lengths = target_tokens_per_expert - tokens_per_expert
cum_pad = jnp.cumsum(pad_lengths)
pad_offsets = jnp.concatenate([jnp.array([0], dtype=cum_pad.dtype), cum_pad[:-1]])
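# Worked example (illustrative): tokens_per_expert = [5, 3] with align_size = 4
#   target_tokens_per_expert = [8, 4]
#   pad_lengths = [3, 1], cum_pad = [3, 4]
#   pad_offsets = [0, 3]  (expert 1's block shifts by expert 0's padding)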
# Use worst_case_out_tokens as the output buffer size (compile-time constant)
# The actual used size is sum(target_tokens_per_expert), which may be smaller.
# Unused positions will be zero-initialized by the kernel.
output, permuted_probs = permute_with_mask_map_and_pad(
inp,
row_id_map,
probs,
pad_offsets,
num_tokens,
num_experts,
worst_case_out_tokens,
hidden_size,
align_size=align_size,
)
else:
# No padding
pad_offsets = None
target_tokens_per_expert = None
output, permuted_probs = permute_with_mask_map(
inp,
row_id_map,
probs,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
)
# Return (primals, residuals)
residuals = (row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs)
return (
output,
permuted_probs,
row_id_map,
pad_offsets,
target_tokens_per_expert,
), residuals
def _token_dispatch_bwd_rule(
_num_out_tokens: int,
_worst_case_out_tokens: int,
_align_size: Optional[int],
_use_padding: bool,
residuals: Tuple[jnp.ndarray, Optional[jnp.ndarray], int, int, int, bool],
g: Tuple[
jnp.ndarray,
Optional[jnp.ndarray],
jnp.ndarray,
Optional[jnp.ndarray],
Optional[jnp.ndarray],
],
) -> Tuple[jnp.ndarray, None, Optional[jnp.ndarray]]:
"""Backward pass rule for token_dispatch.
Returns gradients for (inp, routing_map, probs).
routing_map gradient is None since it's a discrete routing decision.
"""
row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs = residuals
output_grad, permuted_probs_grad, _, _, _ = g # Ignore row_id_map, pad_offsets, target grads
# Backward: unpermute gradients (gather from experts back to tokens)
if pad_offsets is not None:
inp_grad, probs_grad = unpermute_with_mask_map_and_unpad(
output_grad,
row_id_map,
None, # No merging probs
permuted_probs_grad if with_probs else None,
pad_offsets,
num_tokens,
num_experts,
hidden_size,
)
else:
inp_grad, probs_grad = unpermute_with_mask_map(
output_grad,
row_id_map,
None, # No merging probs
permuted_probs_grad if with_probs else None,
num_tokens,
num_experts,
hidden_size,
)
# Return gradients for (inp, routing_map, probs)
# routing_map is non-differentiable (discrete routing), so return None
return inp_grad, None, probs_grad if with_probs else None
_token_dispatch.defvjp(_token_dispatch_fwd_rule, _token_dispatch_bwd_rule)
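Since the VJP returns None for the discrete routing map, gradients flow only into the input (and probs, when given). A sketch, reusing the illustrative shapes from the token_dispatch example above:

import jax
import jax.numpy as jnp

def loss_fn(x):
    out, _, _, _, _ = token_dispatch(x, routing_map, num_out_tokens=5)
    return jnp.sum(out.astype(jnp.float32) ** 2)

inp_grad = jax.grad(loss_fn)(inp)  # same shape as inp; routing_map receives no gradient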
# =============================================================================
# Token Combine (Unpermute) with VJP
# =============================================================================
def token_combine(
inp: jnp.ndarray,
row_id_map: jnp.ndarray,
merging_probs: Optional[jnp.ndarray] = None,
pad_offsets: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
"""
Combine tokens from experts back to original token positions.
This is the forward pass of MoE unpermutation. Tokens are gathered from
experts and merged (optionally weighted by merging_probs).
Optionally supports fused unpadding when `pad_offsets` is provided (from
token_dispatch with padding enabled).
Parameters
----------
inp : jnp.ndarray
Input tensor from experts of shape [num_out_tokens, hidden_size]
(or [num_out_tokens_padded, hidden_size] when using unpadding).
row_id_map : jnp.ndarray
Row ID map from token_dispatch of shape [num_tokens, num_experts * 2 + 1].
merging_probs : Optional[jnp.ndarray]
Merging weights of shape [batch, sequence, num_experts] or [num_tokens, num_experts].
If provided, tokens from different experts are weighted-summed.
If None, tokens are summed directly.
pad_offsets : Optional[jnp.ndarray]
Per-expert cumulative padding offsets of shape [num_experts] from token_dispatch.
If provided, fused unpadding will be performed. This should be the pad_offsets
returned by token_dispatch when using padding.
Returns
-------
output : jnp.ndarray
Combined output tensor of shape [num_tokens, hidden_size].
"""
return _token_combine(inp, row_id_map, merging_probs, pad_offsets)
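

# A minimal usage sketch of the dispatch/combine round trip. Illustrative
# only: `run_experts` is a hypothetical expert computation, and the argument
# list of `token_dispatch` (defined earlier in this module) is abbreviated
# here; consult that wrapper for the authoritative signature.
#
#   dispatched, permuted_probs, row_id_map, pad_offsets, _ = token_dispatch(
#       hidden_states, routing_map, probs, ...
#   )                                       # tokens grouped by expert
#   expert_out = run_experts(dispatched)    # hypothetical per-expert MLPs
#   combined = token_combine(
#       expert_out, row_id_map, merging_probs=probs, pad_offsets=pad_offsets
#   )                                       # back to [num_tokens, hidden_size]

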
@jax.custom_vjp
def _token_combine(
inp: jnp.ndarray,
row_id_map: jnp.ndarray,
merging_probs: Optional[jnp.ndarray],
pad_offsets: Optional[jnp.ndarray],
) -> jnp.ndarray:
"""Internal token_combine with custom VJP."""
output, _ = _token_combine_fwd_rule(inp, row_id_map, merging_probs, pad_offsets)
return output


def _token_combine_fwd_rule(
inp: jnp.ndarray,
row_id_map: jnp.ndarray,
merging_probs: Optional[jnp.ndarray],
pad_offsets: Optional[jnp.ndarray],
) -> Tuple[
jnp.ndarray,
Tuple[
jnp.ndarray,
Optional[jnp.ndarray],
jnp.ndarray,
Optional[jnp.ndarray],
int,
int,
int,
int,
],
]:
"""Forward pass rule for token_combine."""
# Infer dimensions from row_id_map shape: [num_tokens, num_experts * 2 + 1]
num_tokens = row_id_map.shape[0]
num_experts = (row_id_map.shape[1] - 1) // 2
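    # Example (illustrative numbers): a row_id_map of shape [1024, 9] implies
    # num_tokens = 1024 and num_experts = (9 - 1) // 2 = 4.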
hidden_size = inp.shape[-1]
num_out_tokens = inp.shape[0]
# Call triton extension with or without unpadding
if pad_offsets is not None:
output, _ = unpermute_with_mask_map_and_unpad(
inp,
row_id_map,
merging_probs,
None, # No permuted probs to unpermute
pad_offsets,
num_tokens,
num_experts,
hidden_size,
)
else:
output, _ = unpermute_with_mask_map(
inp,
row_id_map,
merging_probs,
None, # No permuted probs to unpermute
num_tokens,
num_experts,
hidden_size,
)
# Return (primal, residuals)
# Include inp in residuals for backward with merging_probs
residuals = (
row_id_map,
pad_offsets,
inp,
merging_probs,
num_tokens,
num_experts,
hidden_size,
num_out_tokens,
)
return output, residuals


def _token_combine_bwd_rule(
residuals: Tuple[
jnp.ndarray,
Optional[jnp.ndarray],
jnp.ndarray,
Optional[jnp.ndarray],
int,
int,
int,
int,
],
g: jnp.ndarray,
) -> Tuple[jnp.ndarray, None, Optional[jnp.ndarray], None]:
"""Backward pass rule for token_combine.
Returns gradients for: (inp, row_id_map, merging_probs, pad_offsets)
row_id_map and pad_offsets are integer arrays, so their gradients are None.
"""
(
row_id_map,
pad_offsets,
fwd_input,
merging_probs,
num_tokens,
num_experts,
hidden_size,
num_out_tokens,
) = residuals
output_grad = g
with_merging_probs = merging_probs is not None
if with_merging_probs:
# Use specialized backward kernel that properly scales by merging_probs
if pad_offsets is not None:
inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs_and_unpad(
output_grad,
row_id_map,
fwd_input,
merging_probs,
pad_offsets,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
)
            # The backward kernel only writes to positions that tokens map to.
            # Padded positions may contain uninitialized (NaN) values, so they
            # are replaced with zeros.
inp_grad = jnp.where(jnp.isnan(inp_grad), 0.0, inp_grad)
else:
inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs(
output_grad,
row_id_map,
fwd_input,
merging_probs,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
)
else:
# Simple case: just permute gradients back
if pad_offsets is not None:
# Note: align_size uses default (128) since buffer sizes are already
# determined from forward pass (stored in residuals as num_out_tokens)
inp_grad, _ = permute_with_mask_map_and_pad(
output_grad,
row_id_map,
None,
pad_offsets,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
align_size=128, # Default, sizes already computed in forward
)
            # The permute kernel only writes to positions that tokens map to.
            # Padded positions may contain uninitialized (NaN) values, so they
            # are replaced with zeros.
inp_grad = jnp.where(jnp.isnan(inp_grad), 0.0, inp_grad)
else:
inp_grad, _ = permute_with_mask_map(
output_grad,
row_id_map,
None,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
)
merging_probs_grad = None
# Return gradients for: inp, row_id_map, merging_probs, pad_offsets
# row_id_map and pad_offsets are integer arrays, so their gradients are None
return inp_grad, None, merging_probs_grad, None


_token_combine.defvjp(_token_combine_fwd_rule, _token_combine_bwd_rule)
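

# A self-contained sketch (not part of the API) of the product rule that the
# merging-probs backward above implements for a single token: the combined
# output is a probability-weighted sum of expert rows, so the gradient w.r.t.
# each expert row is scaled by its prob, and the gradient w.r.t. each prob is
# the dot product of the upstream gradient with that expert's row (which is
# why the forward input is kept in the residuals).
def _example_merging_probs_grads():
    expert_rows = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # [num_experts=2, hidden=2]
    probs = jnp.array([0.25, 0.75])  # merging probs for one token

    def combine(rows, p):
        return (p[:, None] * rows).sum(axis=0)  # weighted sum -> [hidden]

    upstream = jnp.ones(2)  # incoming gradient of the combined output
    _, vjp_fn = jax.vjp(combine, expert_rows, probs)
    rows_grad, probs_grad = vjp_fn(upstream)
    # rows_grad == probs[:, None] * upstream;  probs_grad == expert_rows @ upstream
    return rows_grad, probs_grad

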
# =============================================================================
# Chunk Sort with VJP
# =============================================================================


def sort_chunks_by_index(
inp: jnp.ndarray,
split_sizes: jnp.ndarray,
sorted_indices: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
Sort chunks of tokens according to sorted indices.
Parameters
----------
inp : jnp.ndarray
Input tensor of shape [batch, sequence, hidden_size] or [num_tokens, hidden_size].
split_sizes : jnp.ndarray
Sizes of each chunk of shape [num_splits].
sorted_indices : jnp.ndarray
Permutation indices for chunks of shape [num_splits].
Returns
-------
output : jnp.ndarray
Sorted output tensor of shape [num_tokens, hidden_size].
row_id_map : jnp.ndarray
Row ID map for reversing the sort.
"""
return _sort_chunks_by_index(inp, split_sizes, sorted_indices)
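

# A minimal usage sketch (hypothetical values and dtypes). Three chunks of
# sizes [2, 3, 1] are reordered according to sorted_indices = [2, 0, 1]; the
# exact permutation convention is defined by the underlying kernel.
def _example_sort_chunks():
    inp = jnp.arange(12, dtype=jnp.float32).reshape(6, 2)  # 6 tokens, hidden=2
    split_sizes = jnp.array([2, 3, 1], dtype=jnp.int32)
    sorted_indices = jnp.array([2, 0, 1], dtype=jnp.int32)
    output, row_id_map = sort_chunks_by_index(inp, split_sizes, sorted_indices)
    return output, row_id_map

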
@partial(jax.custom_vjp, nondiff_argnums=(1, 2))
def _sort_chunks_by_index(
inp: jnp.ndarray,
split_sizes: jnp.ndarray,
sorted_indices: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Internal sort_chunks_by_index with custom VJP."""
(output, row_id_map), _ = _sort_chunks_by_index_fwd_rule(inp, split_sizes, sorted_indices)
return output, row_id_map


def _sort_chunks_by_index_fwd_rule(
inp: jnp.ndarray,
split_sizes: jnp.ndarray,
sorted_indices: jnp.ndarray,
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, int, int]]:
"""Forward pass rule for sort_chunks_by_index."""
# Validate input dimensions
assert inp.ndim in [2, 3], f"inp must be 2D or 3D, got {inp.ndim}D"
# Infer dimensions from input shape
num_tokens = inp.shape[0] * inp.shape[1] if inp.ndim == 3 else inp.shape[0]
hidden_size = inp.shape[-1]
num_splits = split_sizes.shape[0]
row_id_map = make_chunk_sort_map(split_sizes, sorted_indices, num_tokens, num_splits)
output, _ = sort_chunks_by_map(
inp,
row_id_map,
None, # No probs
num_tokens,
hidden_size,
is_forward=True,
)
# Return (primals, residuals)
residuals = (row_id_map, num_tokens, hidden_size)
return (output, row_id_map), residuals


def _sort_chunks_by_index_bwd_rule(
_split_sizes: jnp.ndarray,
_sorted_indices: jnp.ndarray,
residuals: Tuple[jnp.ndarray, int, int],
g: Tuple[jnp.ndarray, jnp.ndarray],
) -> Tuple[jnp.ndarray]:
"""Backward pass rule for sort_chunks_by_index."""
row_id_map, num_tokens, hidden_size = residuals
output_grad, _ = g
# Backward: reverse the sort
inp_grad, _ = sort_chunks_by_map(
output_grad,
row_id_map,
None,
num_tokens,
hidden_size,
is_forward=False,
)
return (inp_grad,)


_sort_chunks_by_index.defvjp(_sort_chunks_by_index_fwd_rule, _sort_chunks_by_index_bwd_rule)