Unverified Commit df39a7c2 authored by Paweł Gadziński, committed by GitHub

Docs fix (#2301)



* init
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* lines length
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* subtitle --- fix in many files:
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* cross entropy _input -> input rename
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* cross entropy _input -> input rename
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* a lot of small fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* torch_version() change
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add missing module and fix warnings
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* removed trailing whitespace
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Update docs/api/pytorch.rst
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* Fix import
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix more imports
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix NumPy docstring parameter spacing and indentation

- Standardize parameter documentation to use 'param : type' format (space before and after colon) per NumPy style guide
- Fix inconsistent indentation in cpu_offload.py docstring
- Modified 51 Python files across transformer_engine/pytorch
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent ca468ebe
@@ -32,9 +32,9 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a
 from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
 from ..quantize import (
     Quantizer,
-    QuantizeLayout,
     DelayedScaleQuantizer,
     ScalingMode,
+    QuantizeLayout,
 )
......
@@ -39,12 +39,12 @@ from ..quantize import (
     Quantizer,
     GroupedQuantizer,
     QuantizerSet,
-    QuantizeLayout,
     noop_quantizer_set,
     is_fp8_gemm_with_all_layouts_supported,
     apply_padding_to_scale_inv,
     get_quantize_config_with_recipe,
     get_global_quantize_recipe,
+    QuantizeLayout,
 )
 from .misc import get_padded_spec, is_all_reduce_in_float32
 from ..sharding import (
......
@@ -116,7 +116,7 @@ def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis=-1):
     transpose. Note, transpose_axis should be greater than static_axis_boundary
     examples:
-        X in shape (dim0, dim1, dim2, dim3, dim4)
+        X of shape (dim0, dim1, dim2, dim3, dim4)
         static_axis_boundary == -1, transpose_axis == 2
         Xt = (dim2, dim3, dim4, dim0, dim1)
......
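For reference, the permutation in the example above can be reproduced with a few lines of NumPy. This is an illustrative sketch (the helper name `multidim_transpose_perm` and treating `static_axis_boundary == -1` as "no static leading axes" are assumptions, not the library's implementation):

```python
import numpy as np

def multidim_transpose_perm(ndim, static_axis_boundary=-1, transpose_axis=-1):
    # Axes up to and including static_axis_boundary stay in place; the
    # remaining axes are rotated so that transpose_axis becomes the first
    # movable axis, as in the docstring example.
    static = list(range(static_axis_boundary + 1))
    movable = list(range(static_axis_boundary + 1, ndim))
    cut = transpose_axis - (static_axis_boundary + 1)
    return static + movable[cut:] + movable[:cut]

x = np.zeros((2, 3, 4, 5, 6))  # (dim0, dim1, dim2, dim3, dim4)
perm = multidim_transpose_perm(x.ndim, static_axis_boundary=-1, transpose_axis=2)
print(np.transpose(x, perm).shape)  # (4, 5, 6, 2, 3) == (dim2, dim3, dim4, dim0, dim1)
```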
@@ -35,9 +35,9 @@ from ..sharding import (
 from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
 from ..quantize import (
     Quantizer,
-    QuantizeLayout,
     DelayedScaleQuantizer,
     ScalingMode,
+    QuantizeLayout,
 )
......
@@ -40,11 +40,11 @@ from ..quantize import (
     GroupedScaledTensor1x,
     Quantizer,
     GroupedQuantizer,
-    QuantizeLayout,
     ScalingMode,
     compute_scale_from_amax,
     NoScaleTensor,
     get_rht_matrix,
+    QuantizeLayout,
 )
......
@@ -21,12 +21,12 @@ from .quantize import (
     ScaledTensorFactory,
     ScaledTensor,
     ScalingMode,
-    QuantizeLayout,
     QuantizerSet,
     noop_quantizer_set,
     with_sharding_constraint_by_logical_axes,
     is_fp8_gemm_with_all_layouts_supported,
     TensorUsage,
+    QuantizeLayout,
 )
......
@@ -279,26 +279,26 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
 layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
     Indicate the type of layer normalization.
 zero_centered_gamma : bool, default = False
-    If set to `True`, the LayerNorm formula changes to
+    If set to ``True``, the LayerNorm formula changes to
     .. math::
-        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
+        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} \cdot
         (1 + \gamma) + \beta
-    This parameter is only applicable for 'layernorm'.
-    The default of `scale_init` will also be changed. See `scale_init`.
+    This parameter is only applicable for ``'layernorm'``.
+    The default of ``scale_init`` will also be changed. See ``scale_init``.
 scale_init : Initializer, default = None
     Used for initializing scale factors :math:`\gamma`.
-    If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
-    If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
-    Otherwise, scale_init is `flax.linen.initializers.ones`.
-    It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+    If ``None`` is provided, scale_init is set according to the value of zero_centered_gamma.
+    If zero_centered_gamma is set to ``True``, then scale_init is ``flax.linen.initializers.zeros``.
+    Otherwise, scale_init is ``flax.linen.initializers.ones``.
+    It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
 scale_axes : Tuple[str, ...], default = ('embed', )
     The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
 bias_init : Initializer, default = flax.linen.initializers.zeros
     Used for initializing shift factors :math:`\beta`,
     only used when :attr:`layernorm_type='layernorm'`.
-    It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+    It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
 bias_axes : Tuple[str, ...], default = ('embed', )
     The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
     only used when :attr:`layernorm_type='layernorm'`.
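As a minimal numerical sketch of the zero-centered-gamma formula above (plain jax.numpy, not the module's code; with gamma initialized to zeros the effective scale 1 + gamma starts at one):

```python
import jax.numpy as jnp

def layernorm_zero_centered(x, gamma, beta, eps=1e-6):
    # y = (x - E[x]) / sqrt(Var[x] + eps) * (1 + gamma) + beta
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps) * (1.0 + gamma) + beta

x = jnp.ones((2, 8))
y = layernorm_zero_centered(x, gamma=jnp.zeros(8), beta=jnp.zeros(8))
```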
@@ -424,15 +424,15 @@ class DenseGeneral(TransformerEngineBase):
 kernel_init : Initializer, default =
     flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
     Used for initializing weights.
-    It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+    It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
 kernel_axes : Tuple[str, ...], default = ()
     The name of axes used to shard the weights with a corresponding mesh.
 use_bias: bool, default = False
     Indicate whether to enable bias shifting.
-    If set to False, the layer will not learn an additive bias.
+    If set to ``False``, the layer will not learn an additive bias.
 bias_init: Initializer, default = flax.linen.initializers.zeros
     Used for initializing bias, only used when :attr:`use_bias=True`.
-    It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+    It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
 bias_axes: Tuple[str, ...], default = ()
     The name of axes used to shard bias with a corresponding mesh,
     only used when :attr:`use_bias=True`.
@@ -443,12 +443,12 @@ class DenseGeneral(TransformerEngineBase):
     :attr:`enable_low_rank_adaptation=True`
 low_rank_adaptation_alpha: float, default = None
     The alpha for computing the scaling factor of LoRA output.
-    :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
+    :math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
 axis: Union[Iterable[int], int], default = -1
     An integer tuple with axes to apply the transformation on.
 input_axes: Tuple[str, ...], default = None
     Indicate the logical axes of sharding constraint to the input, like
-    (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
+    ``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
     sharding constraint.
 Optimization parameters
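The LoRA scaling rule quoted above reads, in code, roughly as follows (a sketch; the names `lora_a`/`lora_b` and the parameter layout are assumptions for illustration):

```python
import jax.numpy as jnp

def dense_with_lora(x, kernel, lora_a, lora_b, alpha=None):
    # kernel: [d_in, d_out]; lora_a: [d_in, rank]; lora_b: [rank, d_out]
    rank = lora_a.shape[-1]
    lora_out = (x @ lora_a) @ lora_b
    if alpha is not None:            # None means no scaling
        lora_out = (alpha / rank) * lora_out
    return x @ kernel + lora_out
```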
@@ -597,48 +597,48 @@ class LayerNormDenseGeneral(TransformerEngineBase):
 epsilon : float, default = 1e-6
     A value added to the denominator of layer normalization for numerical stability.
 zero_centered_gamma : bool, default = False
-    If set to `True`, the LayerNorm formula changes to
+    If set to ``True``, the LayerNorm formula changes to
     .. math::
-        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
+        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} \cdot
         (1 + \gamma) + \beta
-    This parameter is only applicable for 'layernorm'.
-    The default of `scale_init` will also be changed. See `scale_init`
+    This parameter is only applicable for ``'layernorm'``.
+    The default of ``scale_init`` will also be changed. See ``scale_init``
 scale_init : Initializer, default = None
     Used for initializing scale factors :math:`\gamma`.
-    If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
-    If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
-    Otherwise, scale_init is `flax.linen.initializers.ones`.
-    It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+    If ``None`` is provided, scale_init is set according to the value of zero_centered_gamma.
+    If zero_centered_gamma is set to ``True``, then scale_init is ``flax.linen.initializers.zeros``.
+    Otherwise, scale_init is ``flax.linen.initializers.ones``.
+    It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
 scale_axes : Tuple[str, ...], default = ('embed', )
     The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
     only used when :attr:`enable_layernorm=True`.
 ln_bias_init: Initializer, default = flax.linen.initializers.zeros
     Used for initializing shift factors :math:`\beta`,
     only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
-    It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+    It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
 ln_bias_axes: Tuple[str, ...], default = ('embed', )
     The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
     It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
 kernel_init : Initializer, default =
     flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
     Used for initializing weights.
-    It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+    It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
 kernel_axes : Tuple[str, ...], default = ()
     The name of axes used to shard the weights with a corresponding mesh.
 use_bias: bool, default = False
     Indicate whether to enable bias shifting.
-    If set to False, the layer will not learn an additive bias.
+    If set to ``False``, the layer will not learn an additive bias.
 bias_init: Initializer, default = flax.linen.initializers.zeros
     Used for initializing bias, only used when :attr:`use_bias=True`.
-    It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+    It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
 bias_axes: Tuple[str, ...], default = ()
     The name of axes used to shard bias with a corresponding mesh,
     only used when :attr:`use_bias=True`.
 return_layernorm_output: bool, default = False
     Indicate whether to return the output of layer normalization.
-    If set False, return None as the second tensor in outputs.
+    If set ``False``, return ``None`` as the second tensor in outputs.
 enable_low_rank_adaptation: bool, default = False
     Indicate whether to enable low rank adaptation for each dense layer.
 low_rank_adaptation_dim: int, default = 32
@@ -646,16 +646,16 @@ class LayerNormDenseGeneral(TransformerEngineBase):
     :attr:`enable_low_rank_adaptation=True`
 low_rank_adaptation_alpha: float, default = None
     The alpha for computing the scaling factor of LoRA output.
-    :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
+    :math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
 axis: Union[Iterable[int], int], default = -1
     An integer tuple with axes to apply the transformation on.
 layernorm_input_axes: Tuple[str, ...], default = None
     Indicate the logical axes of sharding constraint to the input of layernorm, like
-    (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
+    ``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
     sharding constraint.
 dot_input_axes: Tuple[str, ...], default = None
     Indicate the logical axes of sharding constraint to the input of dot, like
-    (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
+    ``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
     sharding constraint.
 Optimization parameters
@@ -887,34 +887,34 @@ class LayerNormMLP(TransformerEngineBase):
 epsilon : float, default = 1e-6
     A value added to the denominator of layer normalization for numerical stability.
 zero_centered_gamma : bool, default = False
-    If set to `True`, the LayerNorm formula changes to
+    If set to ``True``, the LayerNorm formula changes to
     .. math::
-        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
+        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} \cdot
         (1 + \gamma) + \beta
-    This parameter is only applicable for 'layernorm'.
-    The default of `scale_init` will also be changed. See `scale_init`.
+    This parameter is only applicable for ``'layernorm'``.
+    The default of ``scale_init`` will also be changed. See ``scale_init``.
 scale_init : Initializer, default = None
     Used for initializing scale factors :math:`\gamma`.
-    If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
-    If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
-    Otherwise, scale_init is `flax.linen.initializers.ones`.
-    It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+    If ``None`` is provided, scale_init is set according to the value of zero_centered_gamma.
+    If zero_centered_gamma is set to ``True``, then scale_init is ``flax.linen.initializers.zeros``.
+    Otherwise, scale_init is ``flax.linen.initializers.ones``.
+    It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
 scale_axes : Tuple[str, ...], default = ('embed', )
     The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
     only used when :attr:`enable_layernorm=True`.
 ln_bias_init: Initializer, default = flax.linen.initializers.zeros
     Used for initializing shift factors :math:`\beta`,
     only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
-    It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+    It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
 ln_bias_axes: Tuple[str, ...], default = ('embed', )
     The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
     Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
 kernel_init : Initializer, default =
     flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
     Used for initializing the weights of both dense layer transformations.
-    It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+    It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
 kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
     The name of axes used to shard the weights with a corresponding mesh for
     the weight of the first dense layer transformation.
@@ -923,10 +923,10 @@ class LayerNormMLP(TransformerEngineBase):
     the weight of the second dense layer transformation.
 use_bias: bool, default = False
     Indicate whether to enable bias shifting.
-    If set to False, the layer will not learn an additive bias.
+    If set to ``False``, the layer will not learn an additive bias.
 bias_init: Initializer, default = flax.linen.initializers.zeros
     Used for initializing bias, only used when :attr:`use_bias=True`.
-    It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+    It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
 bias_axes_1: Tuple[str, ...], default = ('mlp',)
     The name of axes used to shard bias with a corresponding mesh for
     the weight of the first dense layer transformation.
@@ -937,7 +937,7 @@ class LayerNormMLP(TransformerEngineBase):
     Only used when :attr:`use_bias=True`.
 return_layernorm_output: bool, default = False
     Indicate whether to return the output of layer normalization.
-    If set False, return None as the second tensor in outputs.
+    If set ``False``, return ``None`` as the second tensor in outputs.
 activations: Sequence[Union[str, Callable]], default = ('gelu',)
     The sequence of activation functions to apply after the first dense layer transformation.
     Each activation has its own transformation layer.
@@ -958,20 +958,20 @@ class LayerNormMLP(TransformerEngineBase):
     :attr:`enable_low_rank_adaptation=True`.
 low_rank_adaptation_alpha: float, default = None
     The alpha for computing the scaling factor of LoRA output.
-    :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
+    :math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
 axis: Union[Iterable[int], int], default = -1
     An integer tuple with axes to apply the transformation on.
 layernorm_input_axes: Tuple[str, ...], default = None
     Indicate the logical axes of sharding constraint to the input of layernorm, like
-    (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
+    ``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
     sharding constraint.
 dot_1_input_axes: Tuple[str, ...], default = None
     Indicate the logical axes of sharding constraint to the input of 1st dot, like
-    (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
+    ``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
     sharding constraint.
 dot_2_input_axes: Tuple[str, ...], default = None
     Indicate the logical axes of sharding constraint to the input of 2nd dot, like
-    (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
+    ``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
     sharding constraint.
 ffn1_ckpt_name: str = "ffn1"
     Checkpoint name for the output of the first fully-connected layer in the MLP block.
......
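For intuition about the activations sequence, here is a sketch of how a two-entry sequence such as ('gelu', 'linear') is commonly combined into a gated MLP, each activation with its own transformation layer. The GEGLU-style pairing is an assumption for illustration, not necessarily the module's exact kernel packing:

```python
import jax
import jax.numpy as jnp

def gated_mlp(x, w1_gelu, w1_linear, w2):
    # Two parallel first-layer transformations, one per activation;
    # their elementwise product feeds the second dense layer.
    hidden = jax.nn.gelu(x @ w1_gelu) * (x @ w1_linear)
    return hidden @ w2
```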
@@ -469,7 +469,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
     The hidden dimension of each attention head.
 num_attention_heads: int
     The number of attention heads.
-num_gqa_groups: int, default = `None`
+num_gqa_groups: int, default = None
     Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
     Grouped Query Attention is described in
     `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
@@ -482,32 +482,45 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
 attn_mask_type: str, default = 'causal'
     This parameter specifies the type of attention mask to be applied during the softmax
     operation.
-    Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}
+    Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}.
     Each described below:
-    * no_mask: No attention mask is applied. This means the attention will consider the
+    * ``no_mask``: No attention mask is applied. This means the attention will consider the
       full sequence without any restrictions.
-    * padding: Indicates the presence of padding at the end of each sequence.
-      Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
+    * ``padding``: Indicates the presence of padding at the end of each sequence.
+      Users must provide a mask with the shape ``[batch, 1, max_seqlen_q, max_seqlen_kv]`` in the
       :attr:`__call__` method to specify the padding positions.
-    * causal: An upper triangular mask is applied to the softmax inputs,
+    * ``causal``: An upper triangular mask is applied to the softmax inputs,
       ensuring that the prediction for a certain position is only dependent on known outputs
      from positions before it.
-    * causal_padding / padding_causal: A combination of both causal and padding masks.
-      Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.
-    .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.
-    .. note:: THD format only supports 'padding' or 'causal_padding' mask type.
-    attn_mask_type     mask/sequence_descriptor    SWA     softmax type
-    --------------------------------------------------------------------------------------------
-    no_mask            None                        None    SCALED
-    causal             None                        None    SCALED_UPPER_TRIANG_MASKED
-    causal             None                        Yes     SCALED_MASKED
-    padding            Required                    Yes/No  SCALED_MASKED
-    padding_causal     Required                    Yes/No  SCALED_MASKED
+    * ``causal_padding`` / ``padding_causal``: A combination of both causal and padding masks.
+      Both ``'causal_padding'`` and ``'padding_causal'`` are acceptable and have the same effect.
+    |
+    .. note:: :attr:`mask` in :attr:`__call__` is ignored for ``'no_mask'`` and ``'causal'``.
+    |
+    .. note:: THD format only supports ``'padding'`` or ``'causal_padding'`` mask type.
+    |
+    .. table::
+        :widths: auto
+        ================== ============ ========== ==============================
+        attn_mask_type     mask/sd      SWA        softmax type
+        ================== ============ ========== ==============================
+        no_mask            None         None       SCALED
+        causal             None         None       SCALED_UPPER_TRIANG_MASKED
+        causal             None         Yes        SCALED_MASKED
+        padding            Required     Yes/No     SCALED_MASKED
+        padding_causal     Required     Yes/No     SCALED_MASKED
+        ================== ============ ========== ==============================
+    where sd stands for sequence_descriptor.
 attn_bias_type: Optional[str], default = None
     Type of the attention bias passed in the attention.
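A small sketch of building the boolean padding mask in the ``[batch, 1, max_seqlen_q, max_seqlen_kv]`` shape the docstring asks for, where True marks positions to mask out (the helper name and signature are illustrative):

```python
import jax.numpy as jnp

def make_padding_mask(q_lengths, kv_lengths, max_seqlen_q, max_seqlen_kv):
    # q_lengths, kv_lengths: [batch] actual (unpadded) sequence lengths
    q_pad = jnp.arange(max_seqlen_q)[None, :] >= q_lengths[:, None]     # [batch, sq]
    kv_pad = jnp.arange(max_seqlen_kv)[None, :] >= kv_lengths[:, None]  # [batch, skv]
    mask = q_pad[:, :, None] | kv_pad[:, None, :]                       # [batch, sq, skv]
    return mask[:, None, :, :]                                          # [batch, 1, sq, skv]
```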
@@ -553,22 +566,40 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
     Sliding window size. The default value is no sliding window.
 max_segments_per_seq: Optional[int], default = 1
     The maximum number of segments per sequence, also used for THD format (sequence packing).
-context_parallel_causal_load_balanced (bool):
+context_parallel_causal_load_balanced: bool
     Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
-context_parallel_axis (str): The name of the context parallel axis.
-context_parallel_strategy (CPStrategy): The strategy of context parallel. 0: DEFAULT, 1: ALL_GATHER, 2: RING.
-context_checkpoint_name (str): The name of the context checkpoint in the forward pass of fused attention.
+context_parallel_axis: str
+    The name of the context parallel axis.
+context_parallel_strategy: CPStrategy
+    The strategy of context parallel. 0: DEFAULT, 1: ALL_GATHER, 2: RING.
+context_checkpoint_name: str
+    The name of the context checkpoint in the forward pass of fused attention.
 softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
-    softmax type as described in this paper:
+    Softmax type as described in the paper
     `Efficient Streaming Language Models with Attention Sinks
     <https://arxiv.org/pdf/2309.17453v3>`_.
-    For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
-    'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
-    'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
-    'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
-    where alpha is a learnable parameter in shape [h].
-    'off-by-one' and 'learnable' softmax types are also called sink attention
-    ('zero sink' and 'learnable sink').
+    For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``:
+    * ``'vanilla'``:
+      .. math::
+          Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}
+    * ``'off-by-one'``:
+      .. math::
+          Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}
+    * ``'learnable'``:
+      .. math::
+          Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}
+    where :math:`\alpha` is a learnable parameter of shape ``[h]``.
+    ``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
+    (``'zero sink'`` and ``'learnable sink'``).
 Optimization parameters
 -----------------------
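To make the three softmax variants concrete, a standalone jax.numpy sketch that mirrors the formulas above literally (no max-subtraction; production kernels use a numerically stable form in which the sink term is rescaled by exp(-max) as well):

```python
import jax.numpy as jnp

def sink_softmax(scores, softmax_type="vanilla", alpha=None):
    # scores: [b, h, s_q, s_kv]; alpha: [h], used only for 'learnable'
    e = jnp.exp(scores)
    denom = e.sum(axis=-1, keepdims=True)
    if softmax_type == "off-by-one":
        denom = denom + 1.0                                   # implicit zero-logit sink
    elif softmax_type == "learnable":
        denom = denom + jnp.exp(alpha)[None, :, None, None]   # per-head learnable sink
    return e / denom
```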
@@ -631,7 +662,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
 mask: jax.numpy.ndarray, default = None
     Boolean tensor used to mask out the attention softmax input.
     :attr:`True` means to mask out the corresponding values.
-    Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
+    Ignored when :attr:`self.attn_mask_type` is either ``'no_mask'`` or ``'causal'``.
 bias: jax.numpy.ndarray, default = None
     A tensor used to shift attention softmax input.
 *:
@@ -818,7 +849,7 @@ def rotary_pos_emb(
 ):
     """
     Rotary Positional Embedding
-    x should be in shape of
+    x should be of shape
     [Batch, Seqlen, ..., Heads, Hidden] if transpose_batch_sequence is False, or
     [Seqlen, Batch, ..., Heads, Hidden] if transpose_batch_sequence is True.
     """
@@ -956,7 +987,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
     The hidden dimension of each attention head.
 num_attention_heads: int
     The number of attention heads.
-num_gqa_groups: int, default = `None`
+num_gqa_groups: int, default = None
     Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
     Grouped Query Attention is described in
     `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
@@ -969,28 +1000,28 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
 attn_mask_type: str, default = 'causal'
     This parameter specifies the type of attention mask to be applied during the softmax
     operation.
-    Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}
+    Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}.
     Each described below:
-    * no_mask: No attention mask is applied. This means the attention will consider the
+    * ``no_mask``: No attention mask is applied. This means the attention will consider the
      full sequence without any restrictions.
-    * padding: Indicates the presence of padding at the end of each sequence.
-      Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
+    * ``padding``: Indicates the presence of padding at the end of each sequence.
+      Users must provide a mask with the shape ``[batch, 1, max_seqlen_q, max_seqlen_kv]`` in the
      :attr:`__call__` method to specify the padding positions.
-    * causal: An upper triangular mask is applied to the softmax inputs,
+    * ``causal``: An upper triangular mask is applied to the softmax inputs,
      ensuring that the prediction for a certain position is only dependent on known outputs
      from positions before it.
-    * causal_padding / padding_causal: A combination of both causal and padding masks.
-      Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.
-    .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.
+    * ``causal_padding`` / ``padding_causal``: A combination of both causal and padding masks.
+      Both ``'causal_padding'`` and ``'padding_causal'`` are acceptable and have the same effect.
+    .. note:: :attr:`mask` in :attr:`__call__` is ignored for ``'no_mask'`` and ``'causal'``.
 attn_bias_type: Optional[str], default = None
     Type of the attention bias passed in the attention.
-    Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
+    Available options: ``{'no_bias', 'pre_scale_bias', 'post_scale_bias'}``.
     When default is present, the type is automatically decided by the MHA's bias parameter.
-    Where it is `post_scale_bias` if there is bias. Otherwise `no_bias` is used.
+    Where it is ``'post_scale_bias'`` if there is bias. Otherwise ``'no_bias'`` is used.
 dropout_rng_name: str, default = 'dropout'
     The key in given RNGs via flax.linen.Module.apply that is used
     to generate Dropout masks in the core attention.
@@ -999,27 +1030,27 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
 layernorm_epsilon: float, default = 1e-6
     A value added to the denominator of layer normalization for numerical stability.
 zero_centered_gamma: bool, default = False
-    If set to `True`, the LayerNorm formula changes to
+    If set to ``True``, the LayerNorm formula changes to
     .. math::
-        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
+        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} \cdot
         (1 + \gamma) + \beta
-    This parameter is only applicable for 'layernorm'.
+    This parameter is only applicable for ``'layernorm'``.
 kernel_init: Initializer, default =
-    flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
+    ``flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')``
     Used for initializing the QKV and output projection weights.
-    It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+    It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
 use_bias: bool, default = False
     Indicate whether or not to enable bias shifting for QKV and output projections.
-    If set to False, the layer will not learn additive biases.
-bias_init: Initializer, default = flax.linen.initializers.zeros
+    If set to ``False``, the layer will not learn additive biases.
+bias_init: Initializer, default = ``flax.linen.initializers.zeros``
     Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`.
-    It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+    It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
 input_layernorm: bool, default = True
-    If set to False, layer normalization to the input is not applied.
+    If set to ``False``, layer normalization to the input is not applied.
 return_layernorm_output: bool, default = False
-    If set to True, output of layernorm is returned from the forward together with the output
+    If set to ``True``, output of layernorm is returned from the forward together with the output
     of the linear transformation.
     Example use case: residual connection for transformer module is taken post layernorm.
 enable_rotary_pos_emb: bool, default = False
@@ -1029,17 +1060,17 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
     only used when :attr:`enable_rotary_pos_emb=True`
 rotary_pos_emb_group_method: str, default = 'consecutive'
     Indicate the method to coupled the coordinates. It should be one of
-    ['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2`
-    , d is the hidden dimension. 'consecutive' pairs index :math:`i` with :math:`i + 1`.
+    ``['consecutive', 'alternate']``. ``'alternate'`` is to pair index :math:`i` with :math:`i + d/2`
+    , d is the hidden dimension. ``'consecutive'`` pairs index :math:`i` with :math:`i + 1`.
 low_rank_adaptation_scope: str, default = 'none'
     Indicate the scope to apply low rank adaptation. It should be one of
-    ['none', 'all', 'qkv_proj', 'output_proj', 'exclude_qkv_proj', 'exclude_output_proj']
+    ``['none', 'all', 'qkv_proj', 'output_proj', 'exclude_qkv_proj', 'exclude_output_proj']``
 low_rank_adaptation_dim: int, default = 32
     The dimension for low rank adaptation, only used when
     :attr:`enable_low_rank_adaptation=True`
 low_rank_adaptation_alpha: float, default = None
     The alpha for computing the scaling factor of LoRA output.
-    :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
+    :math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
 enable_sequence_parallel: bool, default = False
     Whether to enable sequence parallelism to operations except dot.
 num_heads: int, default = None
@@ -1066,8 +1097,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
     should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
 scale_attn_logits: bool, default = False
     Indicate whether to scale attention logits.
-    If set to True, :math:`\frac{Q}{\sqrt{head\_dim}*K}`,
-    else :math:`Q*K`
+    If set to True, :math:`\frac{Q \cdot K^T}{\sqrt{head\_dim}}`,
+    else :math:`Q \cdot K^T`
 scaled_query_init: bool, default = True
     Whether to scale WQ on initialization by :math:`\frac{1}{\sqrt{head\_dim}}`
 float32_logits: bool, default = False
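The scale_attn_logits formula above is the standard 1/sqrt(head_dim) logit scaling; as a one-line sketch:

```python
import jax.numpy as jnp

def attn_logits(q, k, scale_attn_logits=True):
    # q: [..., s_q, d], k: [..., s_kv, d]
    logits = jnp.einsum("...qd,...kd->...qk", q, k)
    return logits / jnp.sqrt(q.shape[-1]) if scale_attn_logits else logits
```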
@@ -1078,16 +1109,31 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
 window_size: Optional[Tuple[int, int]], default = None
     Sliding window size. Default value is no sliding window.
 softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
-    softmax type as described in this paper:
+    Softmax type as described in the paper
     `Efficient Streaming Language Models with Attention Sinks
     <https://arxiv.org/pdf/2309.17453v3>`_.
-    For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
-    'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
-    'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
-    'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
-    where alpha is a learnable parameter in shape [h].
-    'off-by-one' and 'learnable' softmax types are also called sink attention
-    ('zero sink' and 'learnable sink').
+    For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``:
+    * ``'vanilla'``:
+      .. math::
+          Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}
+    * ``'off-by-one'``:
+      .. math::
+          Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}
+    * ``'learnable'``:
+      .. math::
+          Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}
+    where :math:`\alpha` is a learnable parameter of shape ``[h]``.
+    ``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
+    (``'zero sink'`` and ``'learnable sink'``).
 """
 head_dim: int
@@ -1202,7 +1248,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
 mask: jax.numpy.ndarray, default = None
     Boolean tensor used to mask out the attention softmax input.
     :attr:`True` means mask out the corresponding values.
-    Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
+    Ignored when :attr:`self.attn_mask_type` is either ``'no_mask'`` or ``'causal'``.
 bias: jax.numpy.ndarray, default = None
     A tensor used to shift the attention softmax input.
 *
@@ -1688,7 +1734,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
     Intermediate size to which input samples are projected.
 num_attention_heads: int, default = 8
     Number of attention heads in the transformer layer.
-num_gqa_groups: int, default = `None`
+num_gqa_groups: int, default = None
     Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
     Grouped Query Attention is described in
     `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
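For num_gqa_groups, the usual grouped-query trick is to broadcast each KV head across its group of query heads. A sketch of that expansion (an illustration of GQA in general, not TE's fused kernel):

```python
import jax.numpy as jnp

def expand_kv_heads(kv, num_attention_heads, num_gqa_groups):
    # kv: [B, S, num_gqa_groups, D] -> [B, S, num_attention_heads, D]
    assert num_attention_heads % num_gqa_groups == 0
    return jnp.repeat(kv, num_attention_heads // num_gqa_groups, axis=2)
```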
@@ -1722,31 +1768,31 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
     The key in given RNGs via flax.linen.Module.apply that for
     generating Dropout masks in the Multi-Head Attention.
 mha_kernel_init: Initializer, default =
-    flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
+    ``flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')``
     Used for initializing weights of QKV and Output projection weights.
-    It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+    It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
 mlp_kernel_init: Initializer, default =
-    flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
+    ``flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')``
     Used for initializing weights of FC1 and FC2 layers.
-    It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+    It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
 mlp_activations: Sequence[str], default = ('gelu', )
     The sequence of activation functions to apply after the first linear transformation.
     Each activation has its own transformation layer.
 mlp_activation_params: dict = None
-    This is only used when ('clamped_silu', 'clamped_linear') is in :attr:`mlp_activations`. At the moment
-    ClampedSwiglu is the only activation that requires parameters.
+    This is only used when ``('clamped_silu', 'clamped_linear')`` is in :attr:`mlp_activations`. At the moment
+    ``ClampedSwiglu`` is the only activation that requires parameters.
 use_bias: bool, default = False
     Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
-    If set to False, the layer will not learn additive biases.
-bias_init: Initializer, default = flax.linen.initializers.zeros
+    If set to ``False``, the layer will not learn additive biases.
+bias_init: Initializer, default = ``flax.linen.initializers.zeros``
     Used for initializing bias of QKVO projections,
     FC1 and FC2. It is only used when :attr:`use_bias=True`.
-    It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+    It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
 apply_residual_connection_post_layernorm: bool, default = False
-    If set to True, residual connections are taken from the output
+    If set to ``True``, residual connections are taken from the output
     of layer norm (default is taken from input of layer norm)
 output_layernorm: bool, default = False
-    If set to True, layer normalization is applied on the output side,
+    If set to ``True``, layer normalization is applied on the output side,
     after the final dropout-add. default behavior is to apply layer
     normalization on the input side, before the QKV transformation.
 float32_attention_logits: bool, default = False
...@@ -1754,43 +1800,43 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1754,43 +1800,43 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
For fused attention backend, the accumulation is always float32 without the perf overhead.
layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER
If set to TransformerLayerType.DECODER, an additional cross-attention block
is added after self-attention. This can be used for structures like T5
Transformer in conjunction with the TransformerLayerType.ENCODER option.
self_attn_mask_type: str, default = 'causal'
This parameter specifies the type of attention mask to be applied during the softmax
operation in the self attention.
Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}.
Each is described below:

* ``no_mask``: No attention mask is applied. This means the self attention will consider the
  full sequence without any restrictions.
* ``padding``: Indicates the presence of padding at the end of each sequence.
  Users must provide a mask with the shape ``[batch, 1, max_seqlen_q, max_seqlen_kv]`` in the
  :attr:`__call__` method to specify the padding positions.
* ``causal``: An upper triangular mask is applied to the softmax inputs,
  ensuring that the prediction for a certain position is only dependent on known outputs
  from positions before it.
* ``causal_padding`` / ``padding_causal``: A combination of both causal and padding masks.
  Both ``'causal_padding'`` and ``'padding_causal'`` are acceptable and have the same effect.

.. note:: :attr:`attention_mask` in :attr:`__call__` is ignored for ``'no_mask'`` and ``'causal'``.
self_attn_bias_type: Optional[str], default = None
Type of the attention bias passed into the self attention.
Available options: ``{'no_bias', 'pre_scale_bias', 'post_scale_bias'}``.
When left as the default (``None``), the type is decided automatically from the MHA's bias
parameter: ``'post_scale_bias'`` if a bias is present, otherwise ``'no_bias'``.
enable_relative_embedding: bool, default = True
Whether to enable relative embedding as shifting of attention logits.
relative_embedding: flax.linen.Module, default = None
The module for relative embedding execution, only used when
:attr:`enable_relative_embedding=True`. Default is ``None``, which will create
an instance of RelativePositionBiases if :attr:`enable_relative_embedding=True`.
Default: ``RelativePositionBiases(num_buckets=32, max_distance=128,
num_attention_heads=self.num_attention_heads, dtype=self.dtype,
embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'),
name='relpos_bias')``
enable_rotary_pos_emb: bool, default = False
Whether to enable rotary position embedding to projected query and key in MHA.
rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
@@ -1798,34 +1844,49 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
only used when :attr:`enable_rotary_pos_emb=True`
rotary_pos_emb_group_method: str, default = 'consecutive'
Indicate the method to couple the coordinates. It should be one of
``['consecutive', 'alternate']``. ``'alternate'`` is to pair index :math:`i` with :math:`i + d/2`,
where :math:`d` is the hidden dimension. ``'consecutive'`` pairs index :math:`i` with
:math:`i + 1`.
low_rank_adaptation_scope: str, default = 'none'
Indicate the scope to apply low rank adaptation. It should be one of
``['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
'exclude_output_proj', 'exclude_mlp']``
low_rank_adaptation_dim: int, default = 32
The dimension for low rank adaptation, only used when
:attr:`enable_low_rank_adaptation=True`
low_rank_adaptation_alpha: float, default = None
The alpha for computing the scaling factor of LoRA output:
:math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
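As a minimal sketch of that scaling rule (plain Python; ``lora_scale`` is an
illustrative helper, not this module's internals)::

    def lora_scale(alpha, rank):
        """Scaling factor for the LoRA output; alpha=None means no scaling."""
        return 1.0 if alpha is None else alpha / rank

    assert lora_scale(16.0, 32) == 0.5   # alpha=16, rank=32 halves the LoRA output
    assert lora_scale(None, 32) == 1.0   # None disables scaling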
enable_sequence_parallel: bool, default = False
Whether to enable sequence parallelism for operations other than dot.
window_size: Optional[Tuple[int, int]], default = None
Sliding window size. Default value is no sliding window.
softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
Softmax type as described in the paper
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.

For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``:

* ``'vanilla'``:

  .. math::
    Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}

* ``'off-by-one'``:

  .. math::
    Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}

* ``'learnable'``:

  .. math::
    Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}

  where :math:`\alpha` is a learnable parameter of shape ``[h]``.

``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
(``'zero sink'`` and ``'learnable sink'``).
Only supported for fused attention backend.
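For intuition, a reference implementation of the three variants might look like the
following PyTorch sketch (illustrative only, not the fused kernel; the sink is modeled
as one extra logit per row, which keeps the computation numerically stable)::

    import torch

    def sink_softmax(S, softmax_type="vanilla", alpha=None):
        """S: attention scores of shape [b, h, s_q, s_kv]; alpha: shape [h]."""
        if softmax_type == "vanilla":
            return torch.softmax(S, dim=-1)
        b, h, s_q, _ = S.shape
        if softmax_type == "off-by-one":
            sink = S.new_zeros(b, h, s_q, 1)                    # fixed zero-logit sink
        elif softmax_type == "learnable":
            sink = alpha.view(1, h, 1, 1).expand(b, h, s_q, 1)  # learnable sink
        else:
            raise ValueError(softmax_type)
        full = torch.cat([sink, S], dim=-1)                     # [b, h, s_q, 1 + s_kv]
        return torch.softmax(full, dim=-1)[..., 1:]             # drop the sink column

Note that the rows of the result sum to less than one, since some probability mass is
absorbed by the sink.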
Optimization parameters
@@ -1836,19 +1897,19 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
When > 0.0, applies stochastic depth per sample in the main
path of the residual block.
fuse_qkv_params: bool, default = True
If set to ``True``, ``TransformerLayer`` module exposes a single fused
parameter for query-key-value for self-attention and key-value for
cross-attention.
transpose_batch_sequence: bool, default = False
Indicate whether the batch and sequence-length axes of the input tensors
are transposed. If set to ``True``, the input tensors
should be in ``(seqlen, batch, hidden)``, otherwise ``(batch, seqlen, hidden)``.
scale_attn_logits: bool, default = False
Indicate whether to scale attention logits.
If set to ``True``, :math:`\frac{Q \cdot K^T}{\sqrt{head\_dim}}`,
else :math:`Q \cdot K^T`.
scaled_query_init: bool, default = True
Whether to scale WQ on initialization by :math:`\sqrt{head\_dim}`.
"""
hidden_size: int = 512
@@ -1931,7 +1992,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
attention_mask : jax.numpy.ndarray, default = None
Boolean tensor used to mask out self-attention softmax input.
:attr:`True` means mask out the corresponding values.
Ignored when :attr:`self.self_attn_mask_type` is either ``'no_mask'`` or ``'causal'``.
encoder_decoder_mask: jax.numpy.ndarray, default = None
Boolean tensor used to mask out cross-attention softmax input when
:attr:`layer_type=TransformerLayerType.DECODER`.
@@ -7,22 +7,14 @@
# pylint: disable=wrong-import-position
import functools

import torch

from transformer_engine.common import load_framework_extension
from transformer_engine.pytorch.torch_version import torch_version

assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}."

load_framework_extension("torch")

from transformer_engine.pytorch.module import LayerNormLinear
from transformer_engine.pytorch.module import Linear
@@ -152,25 +152,25 @@ __all__ = ["DotProductAttention"]
class DotProductAttention(TransformerEngineBaseModule):
r"""Allows the model to jointly attend to information from different
representation subspaces as described in the paper:
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

.. note::

Argument :attr:`attention_mask` in the ``forward`` call is only used when
:attr:`attn_mask_type` includes ``"padding"`` or ``"arbitrary"``.

.. warning::

FlashAttention uses a non-deterministic algorithm for optimal performance. To observe
deterministic behavior at the cost of performance, use FlashAttention version >= ``2.4.1``
and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
to disable ``flash-attn`` entirely, set :attr:`NVTE_FLASH_ATTN=0`.

.. note::

Transformer Engine stores the FP8 metadata under a ``._extra_state`` key when checkpointing.
As the FP8 attention support expands from one backend to multiple backends, the location
of that key has also shifted (see `FP8 checkpoint compatibility <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/faq.html#fp8-checkpoint-compatibility>`_).
@@ -182,118 +182,137 @@ class DotProductAttention(TransformerEngineBaseModule):
kv_channels : Union[int, Tuple[int, int]]
the head size in key and value tensors. If the same, :attr:`kv_channels` can be
an integer; if not, :attr:`kv_channels` should be a tuple of two integers.
num_gqa_groups : Optional[int], default = None
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the queries.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. ``num_gqa_groups = num_attention_heads``.
attention_dropout : float, default = 0.0
dropout probability for the dropout op during multi-head attention.
attn_mask_type : str, default = "causal"
type of attention mask passed into softmax operation, options are ``"no_mask"``,
``"padding"``, ``"causal"``, ``"padding,causal"``, ``"causal,padding"``,
``"padding_causal"``, ``"causal_bottom_right"``, ``"padding_causal_bottom_right"``, and
``"arbitrary"``, where ``"padding,causal"``, ``"causal,padding"`` and ``"padding_causal"``
are equivalent. This arg can be overridden by :attr:`attn_mask_type` in the
:meth:`forward` method. It is useful for cases involving compilation/tracing, e.g.
ONNX export, and the forward arg is useful for dynamically changing mask types,
e.g. a different mask for training and inference.
1. For ``"no_mask"``, no attention mask is applied.
2. For ``"causal"``, ``"causal_bottom_right"``, or the causal mask in
``"padding_causal"`` and ``"padding_causal_bottom_right"``, Transformer Engine
calculates and applies an upper triangular mask to the softmax input.
No user input is needed. Causal masks without the ``"bottom_right"`` appendix align
the diagonal line to the top left corner of the softmax matrix. With
``"bottom_right"``, the causal mask is aligned to the bottom right corner, which is
often used in inference/KV caching.
3. For ``"padding"``, or the padding mask in ``"padding_causal"`` and
``"padding_causal_bottom_right"``, users need to provide the locations of padded
tokens, either via :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv` (both of shape
``[batch_size + 1]``), or via :attr:`attention_mask` (one tensor for self-attention
of shape ``[batch_size, 1, 1, max_seqlen_q]``, or two tensors in a tuple for
cross-attention of shapes ``[batch_size, 1, 1, max_seqlen_q]`` and
``[batch_size, 1, 1, max_seqlen_kv]``); see the sketch after this list.
4. For ``"arbitrary"``, users need to provide a mask that is broadcastable to
the shape of softmax input ``[batch_size, num_heads, max_seqlen_q, max_seqlen_kv]``.
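As an illustrative sketch (plain PyTorch; ``seq_lens`` is a hypothetical per-sequence
length tensor, not an argument of this module), a ``"padding"`` mask for self-attention
can be built as follows, with ``True`` marking positions to mask out::

    import torch

    seq_lens = torch.tensor([3, 2, 4])      # true lengths, padded to max_seqlen_q = 4
    max_seqlen_q = 4
    positions = torch.arange(max_seqlen_q)
    attention_mask = positions[None, :] >= seq_lens[:, None]      # [batch, max_seqlen_q]
    attention_mask = attention_mask.view(-1, 1, 1, max_seqlen_q)  # [batch, 1, 1, max_seqlen_q]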
window_size : Optional[Tuple[int, int]], default = None
sliding window size for local attention, where query at position i attends to keys
in ``[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]]`` inclusive. Special cases ``(-1, -1)`` and ``(-1, 0)`` mean no sliding
window and causal mask specifically. Both ``causal`` and ``causal_bottom_right`` masks
map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them based on
``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can
be overridden by :attr:`window_size` in ``forward`` as well.
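A minimal sketch of this window rule (plain PyTorch, not the backend's mask code;
``sliding_window_allowed`` is a hypothetical helper)::

    import torch

    def sliding_window_allowed(seqlen_q, seqlen_k, window_size):
        """True where query i may attend key j; -1 means unbounded on that side."""
        w0, w1 = window_size
        i = torch.arange(seqlen_q).view(-1, 1)
        j = torch.arange(seqlen_k).view(1, -1)
        diag = i + seqlen_k - seqlen_q         # bottom-right aligned diagonal
        lower = (j >= diag - w0) | (w0 < 0)    # w0 = -1: no left bound
        upper = (j <= diag + w1) | (w1 < 0)    # w1 = -1: no right bound
        return lower & upper                   # (-1, 0) reproduces a causal mask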
attention_type : str, default = "self"
type of attention, either ``"self"`` or ``"cross"``.
layer_number : int, default = None
layer number of the current ``DotProductAttention`` when multiple such modules
are concatenated, for instance in consecutive transformer blocks.
qkv_format : str, default = "sbhd"
dimension format for ``query_layer``, ``key_layer`` and ``value_layer``,
{``"sbhd"``, ``"bshd"``, ``"thd"``}. ``s`` stands for the sequence length, ``b`` batch size,
``h`` the number of heads, ``d`` head size, and ``t`` the total number of tokens
in a batch, with ``t = sum(s_i), for i = 0...b-1``. ``"sbhd"`` and ``"bshd"`` formats
are used when sequences in a batch are of equal length or padded to
equal length, and the ``"thd"`` format is used when sequences in a batch
have different lengths. Please note that these formats do not reflect how
tensors ``query_layer``, ``key_layer``, ``value_layer`` are laid out in memory.
For that, please use ``get_qkv_layout`` to gain the layout information.
softmax_scale : Optional[float], default = None
softmax scale for the attention scores. If ``None``, defaults to
``1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])``.
softmax_type : str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
Softmax type as described in the paper
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.

For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``:

* ``'vanilla'``:

  .. math::
    Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}

* ``'off-by-one'``:

  .. math::
    Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}

* ``'learnable'``:

  .. math::
    Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}

  where :math:`\alpha` is a learnable parameter of shape ``[h]``.

``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
(``'zero sink'`` and ``'learnable sink'``).
return_max_logit : Optional[bool], default = False
If true, returns the maximum attention score that can be used in a Muon optimizer to
rescale the Q and K projection weights (see `Muon is Scalable for LLM Training
<https://arxiv.org/pdf/2502.16982>`_).
:math:`\text{max_logit} = \max(S)`, where :math:`S = \text{mask}(Q \cdot K^T \cdot \text{softmax_scale} + \text{bias})`
of shape ``[b, h, s_q, s_kv]``, and :math:`\text{max_logit}` is of shape ``[h]``.
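Since :math:`\text{max_logit}` has shape ``[h]``, it is a per-head reduction over every
other axis; a rough PyTorch sketch (names and shapes illustrative)::

    import torch

    S = torch.randn(2, 8, 16, 16)       # masked, scaled scores [b, h, s_q, s_kv]
    max_logit = S.amax(dim=(0, 2, 3))   # reduce all axes except heads -> shape [h]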
Parallelism parameters
----------------------
sequence_parallel : bool, default = False
if set to ``True``, uses sequence parallelism.
tp_size : int, default = 1
tensor parallel world size.
tp_group : ProcessGroup, default = None
tensor parallel process group.
cp_group : Union[ProcessGroup, List[ProcessGroup]], default = None
context parallel process group.
``ProcessGroup`` is for :attr:`cp_comm_type` of ``"p2p"``, ``"all_gather"``, and ``"a2a"``.
``List[ProcessGroup]`` is for :attr:`cp_comm_type` of ``"a2a+p2p"``, where :attr:`cp_group[0]`
and :attr:`cp_group[1]` are for ``"a2a"`` and ``"p2p"`` communications respectively.
cp_global_ranks : list of global rank IDs, default = None
global rank IDs of GPUs that are in ``cp_group``.
cp_stream : CUDA stream, default = None
context parallelism splits flash attention into multiple steps for
compute and communication overlapping. To address the wave quantization
issue of each split step, we add an additional CUDA stream so that we
can overlap two flash attention kernels.
cp_comm_type : str, default = "p2p"
inter-gpu communication type for context parallelism.
Can be ``"p2p"`` or ``"all_gather"`` or ``"a2a"`` or ``"a2a+p2p"``.

- ``"p2p"``: Exchange KV chunks with P2P communications in ring topology.
  P2P is async and can be overlapped with attention compute.
- ``"all_gather"``: All-gather to get full sequence of KV before attention.
  The all-gather is not async, and cannot be overlapped.
- ``"a2a"``: Like DeepSpeed Ulysses, scatter attention heads across the CP
  group, and gather to get full sequence of QKV.
- ``"a2a+p2p"``: hierarchical CP implementation. First applying a2a to QKV
  across each CP sub-group (e.g., via NVLink), then exchanging KV with
  p2p between sub-groups (e.g., via IBLink).
"""
def __init__(
@@ -468,8 +487,8 @@ class DotProductAttention(TransformerEngineBaseModule):
):
"""
This function helps to load Transformer Engine 1.6 and 1.7 checkpoints, where FP8 attention
metadata is stored under the ``core_attention.fused_attention._extra_state`` key and not the
``core_attention._extra_state`` key. Please see `FP8 checkpoint compatibility
<https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/faq.html#fp8-checkpoint-compatibility>`_ for more details.
"""
fused_attn_key = False
@@ -522,25 +541,26 @@ class DotProductAttention(TransformerEngineBaseModule):
----------
cp_group : Union[ProcessGroup, List[ProcessGroup]]
context parallel process group.
``ProcessGroup`` is for :attr:`cp_comm_type` of ``"p2p"``, ``"all_gather"``, and ``"a2a"``.
``List[ProcessGroup]`` is for :attr:`cp_comm_type` of ``"a2a+p2p"``, where :attr:`cp_group[0]`
and :attr:`cp_group[1]` are for ``"a2a"`` and ``"p2p"`` communications respectively.
cp_global_ranks : List[int]
list of global ranks in the context group.
cp_stream : torch.cuda.Stream
cuda stream for context parallel execution.
cp_comm_type : str, default = "p2p"
inter-gpu communication type for context parallelism.
Can be ``"p2p"`` or ``"all_gather"`` or ``"a2a"`` or ``"a2a+p2p"``.

- ``"p2p"``: Exchange KV chunks with P2P communications in ring topology.
  P2P is async and can be overlapped with attention compute.
- ``"all_gather"``: All-gather to get full sequence of KV before attention.
  The all-gather is not async, and cannot be overlapped.
- ``"a2a"``: Like DeepSpeed Ulysses, scatter attention heads across the CP
  group, and gather to get full sequence of QKV.
- ``"a2a+p2p"``: hierarchical CP implementation. First applying a2a to QKV
  across each CP sub-group (e.g., via NVLink), then exchanging KV with
  p2p between sub-groups (e.g., via IBLink).
"""
self.cp_group = cp_group
self.cp_global_ranks = cp_global_ranks
@@ -801,13 +821,13 @@ class DotProductAttention(TransformerEngineBaseModule):
fp8_output: Optional[bool] = False,
num_splits: Optional[int] = 1,
) -> torch.Tensor:
r"""
Dot Product Attention Layer.

.. note::

Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
includes ``"padding"`` or ``"arbitrary"``.

.. note::
@@ -846,24 +866,24 @@ class DotProductAttention(TransformerEngineBaseModule):
Pass in :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`, or :attr:`attention_mask`
(which will be converted to :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`), to provide
the real sequence length information. For example, a batch of 3 sequences
``[a a a b b c c c c]`` can be padded to ``[a a a PAD b b PAD PAD c c c c]``, and the cumulative
sequence length tensors would be
:attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = ``[0, 3, 5, 9]`` for self-attention.

2. Do not perform padding on training data. Use :attr:`qkv_format` = "thd" and
:attr:`attn_mask_type` = {"padding", "padding_causal", "padding_causal_bottom_right"}.
Pass in :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`, or :attr:`attention_mask`,
as in option 1. For example, a batch of 3 sequences ``[a a a b b c c c c]`` can be processed
without any padding, and the sequence length tensors would be
:attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = ``[0, 3, 5, 9]`` for self-attention.

In certain use cases, a varying number of identifier tokens are inserted between
sequences. These tokens do not participate in the attention calculation.
:attr:`cu_seqlens_q_padded` and :attr:`cu_seqlens_kv_padded` must be specified
in such cases to correctly identify the start and end of each sequence in a batch.
For example, a batch of 3 sequences ``[a a a 1 b b 2 2 c c c c 3]`` would have
:attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = ``[0, 3, 5, 9]``, and
:attr:`cu_seqlens_q_padded` = :attr:`cu_seqlens_kv_padded` = ``[0, 4, 8, 13]``
for self-attention.
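A minimal sketch of how such cumulative tensors are formed (``seq_lens`` is an
illustrative tensor of true sequence lengths, not an argument of this module)::

    import torch

    seq_lens = torch.tensor([3, 2, 4], dtype=torch.int32)   # [a a a], [b b], [c c c c]
    cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)
    # cu_seqlens is now tensor([0, 3, 5, 9], dtype=torch.int32)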
.. note::
@@ -898,81 +918,81 @@ class DotProductAttention(TransformerEngineBaseModule):
value_layer : torch.Tensor
Value tensor.
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
default = None. Boolean tensor(s) used to mask out attention softmax input.
It should be ``None`` for causal masks and ``"no_mask"``. For padding masks, it should be
a single tensor of ``[batch_size, 1, 1, seqlen_q]`` for self-attention, and a tuple of
two tensors of shapes ``[batch_size, 1, 1, seqlen_q]`` and ``[batch_size, 1, 1, seqlen_kv]``
for cross-attention. For ``"arbitrary"`` mask, it should be of a shape broadcastable
to ``[batch_size, num_heads, max_seqlen_q, max_seqlen_kv]``. A ``True`` value means
the corresponding position is masked out and a ``False`` means that position
is allowed to participate in attention.
qkv_format: str, default = None
If provided, overrides :attr:`qkv_format` from initialization.
cu_seqlens_q: Optional[torch.Tensor], default = None
Cumulative sum of sequence lengths (without offset) in a batch for ``query_layer``,
with shape ``[batch_size + 1]`` and dtype torch.int32.
See :ref:`note<cu_seqlens note>` for more details.
cu_seqlens_kv: Optional[torch.Tensor], default = None
Cumulative sum of sequence lengths (without offset) in a batch for ``key_layer``
and ``value_layer``, with shape ``[batch_size + 1]`` and dtype torch.int32.
See :ref:`note<cu_seqlens note>` for more details.
cu_seqlens_q_padded: Optional[torch.Tensor], default = None
Cumulative sum of sequence lengths (with offset) in a batch for
``query_layer``, with shape ``[batch_size + 1]`` and dtype torch.int32.
When there is no padding between sequences in a batch,
:attr:`cu_seqlens_q_padded` = :attr:`cu_seqlens_q`.
See :ref:`note<cu_seqlens note>` for more details.
cu_seqlens_kv_padded: Optional[torch.Tensor], default = None
Cumulative sum of sequence lengths (with offset) in a batch for ``key_layer``
and ``value_layer``, with shape ``[batch_size + 1]`` and dtype torch.int32.
When there is no padding between sequences in a batch,
:attr:`cu_seqlens_kv_padded` = :attr:`cu_seqlens_kv`.
See :ref:`note<cu_seqlens note>` for more details.
max_seqlen_q: Optional[int], default = None
Maximum sequence length in ``query_layer``.
See :ref:`note<max_seqlen note>` for more details.
max_seqlen_kv: Optional[int], default = None
Maximum sequence length in ``key_layer`` and ``value_layer``.
See :ref:`note<max_seqlen note>` for more details.
attn_mask_type: {'no_mask', 'padding', 'causal', 'padding,causal', 'causal,padding',
'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right',
'arbitrary'}, default = None. Type of attention mask passed into
softmax operation. 'padding,causal', 'causal,padding' and 'padding_causal'
are equivalent. By default, causal masks are aligned to the top left corner
of the softmax matrix. When ``"bottom_right"`` is specified in the mask type,
causal masks are aligned to the bottom right corner.
window_size: Optional[Tuple[int, int]], default = None
Sliding window size for local attention.
checkpoint_core_attention : bool, default = False
If true, forward activations for attention are recomputed
during the backward pass in order to save memory that would
otherwise be occupied to store the forward activations until
backprop.
core_attention_bias_type: str, default = "no_bias"
Bias type, {``"no_bias"``, ``"pre_scale_bias"``, ``"post_scale_bias"``, ``"alibi"``}
core_attention_bias: Optional[torch.Tensor], default = None
Bias tensor for :math:`Q \cdot K^T`, shape ``[1, num_head, max_seqlen_q, max_seqlen_kv]``.
It should be ``None`` for ``"no_bias"`` and ``"alibi"`` bias types.
alibi_slopes: Optional[torch.Tensor], default = None
ALiBi slopes in FP32 and shape ``[nheads]`` or ``[batch_size, nheads]``.
It adds a bias of ``(-alibi_slope * (i + seqlen_k - seqlen_q - j))``
to the attention score of query i and key j.
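A hedged sketch of that bias formula (plain PyTorch; shapes and slope values are
illustrative only, not this module's internals)::

    import torch

    nheads, seqlen_q, seqlen_k = 8, 4, 4
    alibi_slopes = 2.0 ** (-torch.arange(1, nheads + 1, dtype=torch.float32))
    i = torch.arange(seqlen_q).view(1, -1, 1)   # query positions
    j = torch.arange(seqlen_k).view(1, 1, -1)   # key positions
    bias = -alibi_slopes.view(-1, 1, 1) * (i + seqlen_k - seqlen_q - j)
    # bias has shape [nheads, seqlen_q, seqlen_k]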
fast_zero_fill: bool, default = True
Whether to use the fast path to set output tensors to 0 or not.
inference_params: Optional[InferenceParams], default = None
Optimizes execution performance during inference by caching Keys and Values of the
current decoding iteration. These cached values are appended to the K and V values
computed in previous iterations, eliminating the need to recalculate them for the
entire sequence.
Initialization of ``inference_params`` is required prior to use to ensure sufficient
memory allocation.
Adjustments of the ``sequence_len_offset`` should be done after a complete forward pass.
If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand.
Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient.
pad_between_seqs: Optional[bool], default = None
If ``None``, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
If ``True``, there are padding tokens between individual sequences in a packed batch.
fp8_output: Optional[bool], default = False
Whether to enforce output to be in FP8 or not.
num_splits: Optional[int], default = 1
Optional split control for FlashAttention-3 only. When set, this value is forwarded
@@ -175,65 +175,65 @@ class AttentionParams:
Parameters
----------
qkv_type : Union[torch.Tensor, Float8Tensor], default = torch.Tensor
Type of query/key/value tensors, {`torch.Tensor`, `Float8Tensor`}.
qkv_dtype : torch.dtype, default = torch.bfloat16
Data type of query/key/value tensors.
qkv_layout : str, default = "sbh3d"
Query/key/value tensor memory layout.
batch_size : int, default = 1
Batch size.
num_heads : int, default = 16
Number of attention heads in the query tensor.
num_gqa_groups : int, default = 16
Number of attention heads in key and value tensors.
max_seqlen_q : int, default = 128
Maximum sequence length of the query tensor.
max_seqlen_kv : int, default = 128
Maximum sequence length of the key and value tensors.
head_dim_qk : int, default = 64
The size of each attention head in query and key tensors.
head_dim_v : int, default = 64
The size of each attention head in the value tensor.
attn_mask_type : str, default = "no_mask"
Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`,
`causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}
window_size : Tuple[int, int], default = None
Sliding window attention size.
alibi_slopes_shape : Optional[Union[torch.Size, List]], default = None
Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`.
core_attention_bias_type : str, default = "no_bias"
Attention bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}.
core_attention_bias_shape : str, default = "1hss"
Attention bias shape, {`1hss`, `b1ss`, `bhss`}.
core_attention_bias_requires_grad : bool, default = True
Whether attention bias requires gradient.
pad_between_seqs : bool, default = False
Whether there is padding between sequences in a batch.
This only applies to `qkv_format=thd`.
attention_dropout : float, default = 0.0
Attention dropout.
context_parallel : bool, default = False
Whether context parallelism is used or not.
cp_comm_type : str, default = "p2p"
The communication type of context parallelism.
deterministic : bool, default = False
Whether to run `DotProductAttention` with determinism or not.
is_training : bool, default = True
Whether in training mode (`True`) or inference mode (`False`).
fp8 : bool, default = False
Whether `DotProductAttention` is in an `autocast` region.
fp8_meta : Optional[Dict[str, Any]], default = None
The FP8 metadata tensor of `DotProductAttention`.
inference_params : Optional[InferenceParams], default = None
Inference-related parameters. See InferenceParams for details.
softmax_type : str, default = "vanilla"
The type of softmax operation. See DotProductAttention for details.
return_max_logit : bool, default = False
Whether to output max_logit.
cuda_graph : bool, default = False
Whether support for cuda graph capture is needed or not.
num_splits : int, default = 1
The number of kernels to split attention to.
"""
@@ -298,15 +298,15 @@ def get_attention_backend(
Returns
-------
use_flash_attention : bool
Whether the `FlashAttention` backend has been selected.
use_fused_attention : bool
Whether the `FusedAttention` backend has been selected.
fused_attention_backend : tex.NVTE_Fused_Attn_Backend
If `use_fused_attention = True`, one of `FusedAttention`'s three sub-backends, else `None`.
use_unfused_attention : bool
Whether the `UnfusedDotProductAttention` backend has been selected.
available_backends : List[bool]
All available backends that could support the provided input. A list of Booleans
in the form of ``[use_flash_attention, use_fused_attention, use_unfused_attention]``.
@@ -835,8 +835,8 @@ def get_attention_backend(
# ----------------------------------------------------------------------------------------
# no_mask                     | None                                  | All
# padding                     |                                       | All
#   self-attention            | One tensor of shape [b, 1, 1, sq]     |
#   cross-attention           | Tuple of two tensors of shapes        |
#                             | [b, 1, 1, sq] and [b, 1, 1, skv]      |
# causal                      | None                                  |
#   self-attention            |                                       | All
@@ -846,7 +846,7 @@ def get_attention_backend(
#   cross-attention           |                                       | FusedAttention, UnfusedDotProductAttention
# causal_bottom_right         | None                                  | All
# padding_causal_bottom_right | Same as "padding"                     | All
# arbitrary                   | One tensor of shape broadcastable to  | UnfusedDotProductAttention
#                             | [b, h, sq, skv]                       |
if attn_mask_type == "arbitrary":
if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or (
@@ -1271,42 +1271,42 @@ def get_full_mask(
Parameters
----------
max_seqlen_q : int
Maximum sequence length for queries.
max_seqlen_kv : int
Maximum sequence length for keys and values.
attn_mask_type : str, default = "no_mask"
Attention mask type, {``"no_mask"``, ``"padding"``, ``"causal"``, ``"padding_causal"``,
``"causal_bottom_right"``, ``"padding_causal_bottom_right"``, ``"arbitrary"``}
attention_mask : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
default = None
Boolean tensor(s) used to mask out attention softmax input. Please see DotProductAttention
for the requirements of `attention_mask` for different `attn_mask_type`s.
window_size : Tuple[int, int], default = None
Sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
`attn_mask_type`.
attention_type : str, default = "self"
Attention type, {"self", "cross"}
bottom_right_alignment : bool, default = True
Whether to align the diagonal of the sliding window attention to the bottom right (`True`)
or top left (`False`) corner of the softmax matrix. Ignored if `attn_mask_type` explicitly
specifies "causal" or "causal_bottom_right".

Returns
-------
attn_mask_type : str
For sliding window attention (>=0, >0), "arbitrary"; otherwise, the same as input `attn_mask_type`.
attention_mask : torch.Tensor
The full attention mask based on `attn_mask_type`, `attention_mask` and `window_size`.
actual_seqlens_q : torch.Tensor
For padding masks, the actual sequence lengths for queries, of shape [batch_size].
For other masks, `None`.
actual_seqlens_kv : Optional[torch.Tensor], default = None
For padding masks, the actual sequence lengths for keys and values, of shape [batch_size].
For other masks, `None`.
"""
# perform basic checks
@@ -1392,29 +1392,29 @@ def get_alibi(
""" """
Parameters Parameters
---------- ----------
num_heads: int num_heads : int
Number of heads. Number of heads.
max_seqlen_q: int max_seqlen_q : int
Maximum sequence length for queries. Maximum sequence length for queries.
max_seqlen_kv: int max_seqlen_kv : int
Maximum sequence length for keys and values. Maximum sequence length for keys and values.
actual_seqlens_q: Optional[torch.Tensor], default = `None` actual_seqlens_q : Optional[torch.Tensor], default = None
Actual sequence lengths for queries, in shape [batch_size]. Actual sequence lengths for queries, of shape [batch_size].
actual_seqlens_kv: Optional[torch.Tensor], default = `None` actual_seqlens_kv : Optional[torch.Tensor], default = None
Actual sequence lengths for keys and values, in shape [batch_size]. Actual sequence lengths for keys and values, of shape [batch_size].
alibi_slopes: Optional[torch.Tensor], default = `None` alibi_slopes : Optional[torch.Tensor], default = None
Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads]. Custom ALiBi slopes, FP32, CUDA tensor, of shape [num_heads] or [batch_size, num_heads].
bias_dtype: Optional[torch.dtype], default = `None` bias_dtype : Optional[torch.dtype], default = None
Dtype of the generated ALiBi bias. If None, use torch.float32. Dtype of the generated ALiBi bias. If None, use torch.float32.
bottom_right_alignment: bool, default = `True` bottom_right_alignment : bool, default = True
Whether to align the diagonal of the ALiBi bias to the bottom right corner of Whether to align the diagonal of the ALiBi bias to the bottom right corner of
the matrix (`True`) or top left (`False`). the matrix (`True`) or top left (`False`).
Returns Returns
---------- ----------
alibi_slopes: torch.Tensor alibi_slopes : torch.Tensor
ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads]. ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads].
alibi_bias: torch.Tensor alibi_bias : torch.Tensor
ALiBi bias in FP32 or `bias_dtype`. Its shape is ALiBi bias in FP32 or `bias_dtype`. Its shape is
(1) [1, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in [num_heads] shape, (1) [1, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in [num_heads] shape,
and `actual_seqlens_q` and `actual_seqlens_kv` are `None`; or and `actual_seqlens_q` and `actual_seqlens_kv` are `None`; or
@@ -1818,18 +1818,18 @@ def get_qkv_format(
Parameters
----------
qkv_layout : str
    Memory layout of `q`, `k` and `v`. See get_qkv_layout() for more details.
inference_params : InferenceParams, default = None
    InferenceParams related to KV caching.

Returns
----------
qkv_format : str, default = sbhd
    Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}.
q_format : str
    Format of the `q` tensor, {`bshd`, `sbhd`, `thd`}.
kv_format : str
    Format of the `k` and `v` tensors, {`bshd`, `sbhd`, `thd`}.
"""
splited = qkv_layout.replace("paged_kv_", "").split("_")
@@ -1855,23 +1855,23 @@ def get_qkv_layout(
Parameters
----------
q : torch.Tensor
    Query tensor.
k : torch.Tensor
    Key tensor.
v : torch.Tensor
    Value tensor.
qkv_format : str, default = sbhd
    Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for
    the sequence length dimension, `b` batch size, `h` the number of attention heads,
    `d` head size, and `t` the total number of tokens in a batch, i.e.
    `t = sum(s_i) for i = 0...b-1`.
inference_params : InferenceParams, default = None
    InferenceParams related to KV caching.

Returns
----------
qkv_layout : str
    Memory layout of `q`, `k` and `v`. Each `qkv_layout` maps to a pair of `q_format` and
    `kv_format` in {`bshd`, `sbhd`, `thd`}. The `paged_kv_` prefix is used to indicate that
    paged KV caching is in play. A few examples of the layouts are as follows.
@@ -1893,18 +1893,18 @@ def get_qkv_layout(
    `thd_2bshd`: {`thd_bshd_bshd`, `paged_kv_thd_bshd_bshd`}
    `thd_2sbhd`: {`thd_sbhd_sbhd`, `paged_kv_thd_sbhd_sbhd`}
q : torch.Tensor
    Query tensor. It may be different from input `q` as we try to fit tensors to
    a supported layout.
k : torch.Tensor
    Key tensor. It may be different from input `k` as we try to fit tensors to
    a supported layout.
v : torch.Tensor
    Value tensor. It may be different from input `v` as we try to fit tensors to
    a supported layout.
q_format : str
    Format of the query tensor, {`bshd`, `sbhd`, `thd`}.
kv_format : str
    Format of the key and value tensors, {`bshd`, `sbhd`, `thd`}.
"""
...
@@ -98,29 +98,29 @@ class InferenceParams:
Parameters
----------
max_batch_size : int
    Maximum batch size in inference
max_sequence_length : int
    Maximum sequence length in inference
num_heads_kv : int
    Number of attention heads in keys and values
head_dim_k : int
    Head size for keys
dtype : torch.dtype
    Data type of the KV cache
head_dim_v : int, default = None
    Head size for values. If None, initialized as head_dim_k.
is_paged : bool, default = False
    Whether the KV cache is paged (True) or non-paged (False)
total_num_pages : int, default = None
    Total number of pages in the KV cache. Required for is_paged = True.
page_size : int, default = None
    Page size of the KV cache. Required for is_paged = True.
max_ctx_len : int, default = None
    Maximum context length in inference. 1 <= max_ctx_len <= max_sequence_length.
qkv_format : str, default = "bshd"
    Format of the incoming query/key/value tensors in current iteration
custom_cache_manager : KVCacheManager, default = None
    Custom cache manager, with KVCacheManager as the base class.
"""
@@ -525,9 +525,9 @@ class NonPagedKVCacheManager(KVCacheManager):
new_v: torch.Tensor
    New value tokens for layer_number in current inference iteration
cu_new_seqlens: torch.Tensor
    Cumulative sequence lengths for new_k and new_v, of shape [batch_size + 1]
cu_cached_seqlens: torch.Tensor
    Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in),
    of shape [batch_size + 1]
qkv_format: str
    Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'}
@@ -701,7 +701,7 @@ class PagedKVCacheManager(KVCacheManager):
        return [x.page_id for x in self.allocated_pages[seq]]

    def get_page_table(self, sequences: List[int]):
        """Get the page table, of shape [batch_size, max_pages_per_seq]"""
        page_table = torch.Tensor(
            [
                self.get_page_list(seq) + [0] * (self.max_pages_per_seq - self.get_page_count(seq))
@@ -783,9 +783,9 @@ class PagedKVCacheManager(KVCacheManager):
new_v: torch.Tensor
    New value tokens for layer_number in current inference iteration
cu_new_seqlens: torch.Tensor
    Cumulative sequence lengths for new_k and new_v, of shape [batch_size + 1]
cu_cached_seqlens: torch.Tensor
    Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in),
    of shape [batch_size + 1]
qkv_format: str
    Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'}
...
@@ -50,8 +50,8 @@ class MultiheadAttention(torch.nn.Module):
.. note::

    Argument :attr:`attention_mask` in the :meth:`forward() <MultiheadAttention.forward>`
    method is only used when :attr:`attn_mask_type` includes ``"padding"`` or ``"arbitrary"``.

Parameters
----------
@@ -59,57 +59,56 @@ class MultiheadAttention(torch.nn.Module):
    size of each input sample.
num_attention_heads : int
    number of attention heads in the transformer layer.
kv_channels : int, default = None
    number of key-value channels. defaults to
    :attr:`hidden_size` / :attr:`num_attention_heads` if ``None``.
attention_dropout : float, default = 0.1
    dropout probability for the dropout op during multi-head attention.
layernorm_epsilon : float, default = 1e-5
    a value added to the denominator of layer normalization
    for numerical stability.
init_method : Callable, default = None
    used for initializing weights of QKV and FC1 weights in the following way:
    ``init_method(weight)``. When set to ``None``, defaults to
    ``torch.nn.init.normal_(mean=0.0, std=0.023)``.
output_layer_init_method : Callable, default = None
    used for initializing weights of PROJ and FC2 in the following way:
    ``output_layer_init_method(weight)``. When set to ``None``, defaults to
    ``torch.nn.init.normal_(mean=0.0, std=0.023)``.
layer_number : int, default = None
    layer number of the current ``TransformerLayer`` when multiple such modules are
    concatenated to form a transformer block.
attn_mask_type : {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
    'padding_causal_bottom_right', 'arbitrary'}, default = "causal"
    type of attention mask passed into softmax operation. Overridden by
    :attr:`attn_mask_type` in the :meth:`forward` method. The :meth:`forward`
    arg is useful for dynamically changing mask types, e.g. a different
    mask for training and inference. The :meth:`__init__` arg is useful for cases
    involving compilation/tracing, e.g. ONNX export.
window_size : Optional[Tuple[int, int]], default = None
    sliding window size for local attention, where query at position i attends to keys
    in ``[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
    + window_size[1]]`` inclusive. Special cases ``(-1, -1)`` and ``(-1, 0)`` mean no
    sliding window and causal mask specifically. Both ``"causal"`` and
    ``"causal_bottom_right"`` masks map to ``window_size = (-1, 0)`` and Transformer
    Engine distinguishes them based on ``attn_mask_type``. Similar to
    :attr:`attn_mask_type`, ``window_size`` can be overridden by :attr:`window_size`
    in :meth:`forward` as well.
num_gqa_groups : int, default = None
    number of GQA groups in the transformer layer.
    Grouped Query Attention is described in
    `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
    This only affects the keys and values, not the queries.
    GQA-1 is equivalent to Multi-Query Attention
    (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
    is equivalent to MHA, i.e. ``num_gqa_groups = num_attention_heads``.
return_layernorm_output : bool, default = False
    if set to ``True``, output of layernorm is returned from the :meth:`forward` method
    together with the output of the linear transformation.
    Example use case: residual connection for transformer module is
    taken post layernorm.
input_layernorm : bool, default = False
    if set to ``True``, layer normalization to the input is applied.
attention_type : { 'self', 'cross' }, default = 'self'
    type of attention applied.
zero_centered_gamma : bool, default = False
    if set to ``True``, gamma parameter in LayerNorm is initialized to 0 and
@@ -120,103 +119,118 @@ class MultiheadAttention(torch.nn.Module):
    (1 + \gamma) + \beta
normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
    type of normalization applied.
qkv_weight_interleaved : bool, default = True
    if set to ``False``, the QKV weight is interpreted as a concatenation of
    query, key, and value weights along the ``0th`` dimension. The default
    interpretation is that the individual ``q``, ``k``, and ``v`` weights for each
    attention head are interleaved. This parameter is set to ``False`` when
    using :attr:`fuse_qkv_params=False`.
rotary_pos_interleaved : bool, default = False
    whether to use interleaved rotary position embeddings.
bias : bool, default = True
    if set to ``False``, the transformer layer will not learn any additive biases.
device : Union[torch.device, str], default = "cuda"
    The device on which the parameters of the model will be allocated. It is the user's
    responsibility to ensure all parameters are moved to the GPU before running the
    forward pass.
qkv_format : str, default = "sbhd"
    dimension format for ``query_layer``, ``key_layer`` and ``value_layer``,
    {``"sbhd"``, ``"bshd"``}. ``s`` stands for the sequence length, ``b`` batch size,
    ``h`` the number of heads and ``d`` head size. ``"sbhd"`` and ``"bshd"`` formats
    are used when sequences in a batch are of equal length or padded to
    equal length. Please note that these formats do not reflect how
    tensors ``query_layer``, ``key_layer``, ``value_layer`` are laid out in memory.
    For that, please use ``get_qkv_layout`` to gain the layout information.
name : str, default = None
    name of the module, currently used for debugging purposes.
softmax_type : str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
    Softmax type as described in the paper
    `Efficient Streaming Language Models with Attention Sinks
    <https://arxiv.org/pdf/2309.17453v3>`_.

    For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``:

    * ``'vanilla'``:

      .. math::
         S_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}

    * ``'off-by-one'``:

      .. math::
         S_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}

    * ``'learnable'``:

      .. math::
         S_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}

    where :math:`\alpha` is a learnable parameter of shape ``[h]``.

    ``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
    (``'zero sink'`` and ``'learnable sink'``).
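A hedged pure-PyTorch reference sketch of the three variants (this is not the fused implementation); ``alpha`` is the learnable sink parameter of shape ``[h]``::

    import torch

    def sink_softmax(scores, softmax_type="vanilla", alpha=None):
        # scores: [b, h, sq, skv]
        if softmax_type == "vanilla":
            return torch.softmax(scores, dim=-1)
        # Subtract the row max for numerical stability; the sink term in the
        # denominator must then be rescaled by exp(-m) as well.
        m = scores.amax(dim=-1, keepdim=True)
        e = torch.exp(scores - m)
        if softmax_type == "off-by-one":
            sink = torch.exp(-m)                        # exp(0 - m): the "zero sink"
        else:  # "learnable"
            sink = torch.exp(alpha.view(1, -1, 1, 1) - m)
        return e / (sink + e.sum(dim=-1, keepdim=True))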
Parallelism parameters
----------------------
set_parallel_mode : bool, default = False
    if set to ``True``, QKV and FC1 layers are used as Column Parallel
    whereas PROJ and FC2 are used as Row Parallel as described
    `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
sequence_parallel : bool, default = False
    if set to ``True``, uses sequence parallelism.
tp_group : ProcessGroup, default = None
    tensor parallel process group.
tp_size : int, default = 1
    used as TP (tensor parallel) world size when TP groups are not formed during
    initialization. In this case, users must call the
    ``set_tensor_parallel_group(tp_group)`` method on the initialized module before the
    forward pass to supply the tensor parallel group needed for tensor and sequence
    parallel collectives.

Optimization parameters
-----------------------
fuse_wgrad_accumulation : bool, default = False
    if set to ``True``, enables fusing of creation and accumulation of
    the weight gradient. When enabled, it is assumed that the weights
    have an additional ``main_grad`` attribute (used instead of the
    regular ``grad``) which is a pre-allocated buffer of the correct
    size to accumulate gradients in.
params_dtype : torch.dtype, default = torch.get_default_dtype()
    it controls the type used to allocate the initial parameters. Useful when
    the model is trained with lower precision and the original FP32 parameters
    would not fit in GPU memory.
return_bias : bool, default = False
    when set to ``True``, this module will not apply the additive bias itself, but
    instead return the bias value during the :meth:`forward` method together with the
    output of the linear transformation :math:`y = xA^T`. This is useful when
    the bias addition can be fused to subsequent operations.
fuse_qkv_params : bool, default = False
    if set to ``True``, ``TransformerLayer`` module exposes a single fused
    parameter for query-key-value. This enables optimizations such as QKV
    fusion without concatenations/splits and also enables the argument
    ``fuse_wgrad_accumulation``.
qk_norm_type : Optional[str], default = None
    type of normalization to apply to query and key tensors.
    Options: ``None``, ``'L2Normalization'``, ``'RMSNorm'``, ``'LayerNorm'``.
    When ``None``, no normalization is applied.
    When ``'L2Normalization'``, L2 normalization is applied to query and key tensors.
    When ``'RMSNorm'``, RMS normalization is applied to query and key tensors.
    When ``'LayerNorm'``, layer normalization is applied to query and key tensors.
    Normalization is applied after RoPE (if applicable) but before attention computation
    when ``qk_norm_before_rope`` is ``False``. This follows e.g. the Llama4 approach
    for QK normalization to improve training stability and model performance.
qk_norm_eps : float, default = 1e-6
    epsilon value for normalization of query and key tensors.
    Only used when ``qk_norm_type`` is not ``None``.
qk_norm_before_rope : bool, default = False
    if set to ``True``, query and key normalization is applied before rotary position
    embedding. When ``False`` (default), normalization is applied after RoPE.
    This parameter allows supporting different architectural variants that apply
    QK normalization at different points.
seq_length : Optional[int], default = None
    sequence length of input samples. Needed for JIT Warmup, a technique where jit
    fused functions are warmed up before training to ensure same kernels are used for
    forward propagation and activation recompute phase.
micro_batch_size : Optional[int], default = None
    batch size per training step. Needed for JIT Warmup, a technique where jit
    fused functions are warmed up before training to ensure same kernels are
    used for forward propagation and activation recompute phase.
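A minimal usage sketch (shapes and hyperparameters are illustrative)::

    import torch
    import transformer_engine.pytorch as te

    mha = te.MultiheadAttention(
        hidden_size=1024,
        num_attention_heads=16,
        attn_mask_type="causal",
        qkv_format="sbhd",
    ).cuda()

    x = torch.randn(128, 2, 1024, device="cuda")  # [s, b, hidden] for qkv_format="sbhd"
    out = mha(x)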
@@ -535,7 +549,7 @@ class MultiheadAttention(torch.nn.Module):
Parameters
----------
tp_group : ProcessGroup, default = None
    tensor parallel process group.
"""
self.tp_group = tp_group
@@ -555,25 +569,26 @@ class MultiheadAttention(torch.nn.Module):
----------
cp_group : Union[ProcessGroup, List[ProcessGroup]]
    context parallel process group.
    ``ProcessGroup`` is for :attr:`cp_comm_type` of ``"p2p"``, ``"all_gather"``, and ``"a2a"``.
    ``List[ProcessGroup]`` is for :attr:`cp_comm_type` of ``"a2a+p2p"``, where :attr:`cp_group[0]`
    and :attr:`cp_group[1]` are for ``"a2a"`` and ``"p2p"`` communications respectively.
cp_global_ranks : List[int]
    list of global ranks in the context group.
cp_stream : torch.cuda.Stream
    cuda stream for context parallel execution.
cp_comm_type : str, default = "p2p"
    inter-gpu communication type for context parallelism.
    Can be ``"p2p"`` or ``"all_gather"`` or ``"a2a"`` or ``"a2a+p2p"``.

    - ``"p2p"``: Exchange KV chunks with P2P communications in ring topology.
      P2P is async and can be overlapped with attention compute.
    - ``"all_gather"``: All-gather to get full sequence of KV before attention.
      The all-gather is not async, and cannot be overlapped.
    - ``"a2a"``: Like DeepSpeed Ulysses, scatter attention heads across the CP
      group, and gather to get full sequence of QKV.
    - ``"a2a+p2p"``: hierarchical CP implementation. First applying a2a to QKV
      across each CP sub-group (e.g., via NVLink), then exchanging KV with
      p2p between sub-groups (e.g., via IBLink).
"""
if isinstance(cp_group, dist_group_type):
    self.cp_size = get_distributed_world_size(cp_group)
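A hedged sketch of wiring up hierarchical context parallelism, reusing the ``mha`` module from the earlier sketch (group construction is application-specific, and every rank must take part in each ``new_group`` call)::

    import torch
    import torch.distributed as dist

    # 8 ranks split into two sub-groups of 4: a2a within a node, p2p across nodes.
    a2a_group = dist.new_group(ranks=[0, 1, 2, 3])   # e.g. connected via NVLink
    p2p_group = dist.new_group(ranks=[0, 4])         # e.g. connected via IBLink

    mha.set_context_parallel_group(
        cp_group=[a2a_group, p2p_group],             # cp_group[0]=a2a, cp_group[1]=p2p
        cp_global_ranks=list(range(8)),
        cp_stream=torch.cuda.Stream(),
        cp_comm_type="a2a+p2p",
    )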
@@ -622,39 +637,39 @@ class MultiheadAttention(torch.nn.Module):
fast_zero_fill: bool = True,
pad_between_seqs: Optional[bool] = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
r"""
Forward propagation for MultiheadAttention layer.

.. note::

    Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
    includes ``"padding"`` or ``"arbitrary"``.

Parameters
----------
hidden_states : torch.Tensor
    Input tensor.
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
    default = None. Boolean tensor(s) used to mask out attention softmax input.
    It should be ``None`` for causal masks and ``"no_mask"``. For padding masks, it should be
    a single tensor of ``[batch_size, 1, 1, seqlen_q]`` for self-attention, and a tuple of
    two tensors of shapes ``[batch_size, 1, 1, seqlen_q]`` and ``[batch_size, 1, 1, seqlen_kv]``
    for cross-attention. For ``"arbitrary"`` mask, it should be of a shape broadcastable to
    ``[batch_size, num_heads, max_seqlen_q, max_seqlen_kv]``. A ``True`` value means
    the corresponding position is masked out and a ``False`` means that position
    is allowed to participate in attention.
attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
    'padding_causal_bottom_right', 'arbitrary'}, default = None
    type of attention mask passed into softmax operation. By default,
    causal masks are aligned to the top left corner of the softmax matrix.
    When ``"bottom_right"`` is specified in the mask type, causal masks are
    aligned to the bottom right corner.
window_size: Optional[Tuple[int, int]], default = None
    sliding window size for local attention.
encoder_output : Optional[torch.Tensor], default = None
    Output of the encoder block to be fed into the decoder block if using
    ``layer_type="decoder"``.
is_first_microbatch : {True, False, None}, default = None
    During training using either gradient accumulation or
    pipeline parallelism a minibatch of data is further split
@@ -668,46 +683,46 @@ class MultiheadAttention(torch.nn.Module):
    * it also allows skipping gradient accumulation during the
      first microbatch (since it is the first gradient being
      produced)
checkpoint_core_attention: bool, default = False
    If ``True``, forward activations for core attention are recomputed
    during the backward pass in order to save memory that would
    otherwise be occupied to store the forward activations until
    backprop.
rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = None
    Embeddings for query and key tensors for applying rotary position
    embedding. By default no input embedding is applied.
core_attention_bias_type: str, default = "no_bias"
    Bias type, {``"no_bias"``, ``"pre_scale_bias"``, ``"post_scale_bias"``, ``"alibi"``}
core_attention_bias: Optional[torch.Tensor], default = None
    Bias tensor for :math:`Q \cdot K^T`, shape ``[1, num_head, max_seqlen_q, max_seqlen_kv]``.
    It should be ``None`` for ``"no_bias"`` and ``"alibi"`` bias types.
alibi_slopes: Optional[torch.Tensor], default = None
    ALiBi slopes in FP32 and shape ``[nheads]`` or ``[batch_size, nheads]``.
    It adds a bias of ``(-alibi_slope * (i + seqlen_k - seqlen_q - j))``
    to the attention score of query i and key j.
cu_seqlens_q: Optional[torch.Tensor], default = None
    Cumulative sum of sequence lengths (without offset) in a batch for ``query_layer``,
    with shape ``[batch_size + 1]`` and dtype torch.int32.
cu_seqlens_kv: Optional[torch.Tensor], default = None
    Cumulative sum of sequence lengths (without offset) in a batch for ``key_layer``
    and ``value_layer``, with shape ``[batch_size + 1]`` and dtype torch.int32.
cu_seqlens_q_padded: Optional[torch.Tensor], default = None
    Cumulative sum of sequence lengths (with offset) in a batch for ``query_layer``,
    with shape ``[batch_size + 1]`` and dtype torch.int32.
cu_seqlens_kv_padded: Optional[torch.Tensor], default = None
    Cumulative sum of sequence lengths (with offset) in a batch for ``key_layer``
    and ``value_layer``, with shape ``[batch_size + 1]`` and dtype torch.int32.
max_seqlen_q: Optional[int], default = None
    Maximum sequence length in ``query_layer``.
    Calculated from ``cu_seqlens_q`` if not provided.
max_seqlen_kv: Optional[int], default = None
    Maximum sequence length in ``key_layer`` and ``value_layer``.
    Calculated from ``cu_seqlens_kv`` if not provided.
fast_zero_fill: bool, default = True
    Whether to set output tensors to 0 or not before use.
pad_between_seqs: Optional[bool], default = None
    If ``None``, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
    If ``True``, there are padding tokens between individual sequences in a packed batch.
"""
# hidden_states: [sq, b, h]
...
@@ -287,16 +287,16 @@ def _apply_rotary_pos_emb_base(
Parameters
----------
t : torch.Tensor
    Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which rotary positional
    embedding will be applied.
freqs : torch.Tensor
    Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` or `[s2, b, 1, d2]`
    and dtype 'float', with `s2 >= s` and `d2 <= d`.
tensor_format : {'sbhd', 'bshd'}, default = 'sbhd'
    Should be `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape
    `[seq, bs, ...]`.
interleaved : bool, default = False
    Whether to use interleaved rotary position embedding.
"""
# [seq, 1, 1, dim] -> [1, seq, 1, dim] or
@@ -324,7 +324,7 @@ def _get_freqs_on_this_cp_rank(
"""Get the position embedding on the current context parallel rank. """Get the position embedding on the current context parallel rank.
Args: Args:
freqs: torch.Tensor. Positional embedding tensor in shape `[s2, 1, 1, d2]`. freqs: torch.Tensor. Positional embedding tensor of shape `[s2, 1, 1, d2]`.
seqlen: int. Length of the current sequence. seqlen: int. Length of the current sequence.
cp_size: int. Context parallel world size. cp_size: int. Context parallel world size.
cp_rank: int. Context parallel rank. cp_rank: int. Context parallel rank.
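A hedged sketch of the load-balanced slicing commonly used with causal attention, where the sequence is cut into 2*cp_size chunks and rank i keeps chunks i and 2*cp_size-1-i (treat the exact scheme as an assumption)::

    import torch

    def freqs_on_cp_rank(freqs, seqlen, cp_size, cp_rank):
        chunk = seqlen // (2 * cp_size)
        lo = freqs[cp_rank * chunk:(cp_rank + 1) * chunk]
        hi = freqs[(2 * cp_size - cp_rank - 1) * chunk:(2 * cp_size - cp_rank) * chunk]
        return torch.cat((lo, hi), dim=0)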
@@ -372,29 +372,29 @@ def apply_rotary_pos_emb(
Parameters
----------
t : torch.Tensor
    Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
    rotary positional embedding will be applied.
freqs : torch.Tensor
    Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
    with `s2 >= s` and `d2 <= d`.
start_positions : torch.Tensor, default = None.
    Tokens in a sequence `i` should be applied with position encoding offset by
    `start_positions[i]`. If `start_positions=None`, there's no offset.
tensor_format : {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
    is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
    of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
interleaved : bool, default = False
    Whether to use interleaved rotary position embedding.
fused : bool, default = False
    Whether to use a fused RoPE implementation.
cu_seqlens : torch.Tensor, default = None.
    Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
    dtype torch.int32. Only valid when `tensor_format` is 'thd'.
    Should be `cu_seqlens_padded` when cp_size > 1.
cp_size : int, default = 1.
    Context parallel world size. Only valid when `tensor_format` is 'thd' and `fused` is True.
cp_rank : int, default = 0.
    Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True.
"""
assert (
@@ -492,32 +492,32 @@ def apply_fused_qkv_rotary_pos_emb(
Parameters
----------
qkv : torch.Tensor
    Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which
    rotary positional embedding will be applied. This tensor has q, k, v concatenated
    along the last dimension.
q_freqs : torch.Tensor
    Rotary positional embedding Q tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
    with `s2 >= s` and `d2 <= d`.
k_freqs : torch.Tensor
    Rotary positional embedding K tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
    with `s2 >= s` and `d2 <= d`.
qkv_split_arg_list : List[int]
    List of integers that specify the split of the qkv tensor. The list should have
    3 elements: the first is the number of elements in the q tensor, the second is
    the number of elements in the k tensor, and the third is the number of elements
    in the v tensor. The sum of the elements should equal the last dimension of the
    qkv tensor.
start_positions : torch.Tensor, default = None.
    Tokens in a sequence `i` should be applied with position encoding offset by
    `start_positions[i]`. If `start_positions=None`, there's no offset.
tensor_format : {'sbhd', 'bshd'}, default = 'sbhd'
    is `bshd` if `qkv` is of shape `[bs, seq, ...]`, or `sbhd` if `qkv` is
    of shape `[seq, bs, ...]`.
interleaved : bool, default = False
    Whether to use interleaved rotary position embedding.
cp_size : int, default = 1.
    Context parallel world size.
cp_rank : int, default = 0.
    Context parallel rank.
"""
...
@@ -146,89 +146,89 @@ def fused_attn_fwd(
    Parameters
    ----------
    is_training : bool
        if True, runs training and produces auxiliary tensors aux_ctx_tensors
        for the backward; if False, runs inference and doesn't produce aux_ctx_tensors
    max_seqlen_q : int
        max sequence length for Q, used for padding;
        may be larger than max(seqlens_q),
        seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    max_seqlen_kv : int
        max sequence length for K and V, used for padding;
        may be larger than max(seqlens_kv),
        seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
    cu_seqlens_q : torch.Tensor
        cumulative sequence lengths for Q; shape [batch_size + 1]
    cu_seqlens_kv : torch.Tensor
        cumulative sequence lengths for K and V; shape [batch_size + 1]
    q : torch.Tensor
        input tensor Q; shape sbhd, bshd or thd (see `qkv_layout` for details)
    k : torch.Tensor
        input tensor K; shape sbhd, bshd or thd (see `qkv_layout` for details)
    v : torch.Tensor
        input tensor V; shape sbhd, bshd or thd (see `qkv_layout` for details)
    fake_dtype : tex.DType
        data type of Q, K and V; the actual dtype in high precision, or the fake
        (high-precision) dtype when running in FP8; given as torch.dtype
    fused_attention_backend : tex.NVTE_Fused_Attn_Backend
        please see FusedAttention module for details on supported backends.
    attn_bias : torch.Tensor, default = None
        input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
        shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q, k and v
    cu_seqlens_q_padded : torch.Tensor, default = None
        cumulative sequence offsets for Q; shape [batch_size + 1]
    cu_seqlens_kv_padded : torch.Tensor, default = None
        cumulative sequence offsets for KV; shape [batch_size + 1]
    page_table_k : torch.Tensor, default = None
        page table for K cache; shape [batch_size, max_pages_per_seq_k]
    page_table_v : torch.Tensor, default = None
        page table for V cache; shape [batch_size, max_pages_per_seq_v]
    s_quantizer : Quantizer, default = None
        Quantizer object for the intermediate value S.
    o_quantizer : Quantizer, default = None
        Quantizer object for the output of the attention.
    attn_scale : float, default = None
        if not None, use attn_scale as the attention scale for Q*K.T BMM;
        if None, use 1.0/sqrt(head_dim_qk) as the default
    dropout : float, default = 0.0
        dropout probability, 0.0 means no dropout, 1.0 means no output;
        dropout must be 0.0 if is_training is False
    fast_zero_fill : bool, default = True
        if True, initializes the output tensor O to zero using the fast filling method;
        if False, uses PyTorch's .fill_() method
    qkv_layout : str, default = "sbh3d"
        layout of Q, K and V;
        {"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd",
        "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd",
        "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"}
    attn_bias_type : str, default = "no_bias"
        type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
    attn_mask_type : str, default = "padding"
        type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
    softmax_type : str, default = "vanilla"
        type of the attention softmax; {"vanilla", "off-by-one", "learnable"}
    window_size : Tuple[int, int], default = (-1, -1)
        sliding window size for local attention, where query at position i attends to keys
        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
        + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
        window and causal mask specifically.
    rng_gen : torch.Generator, default = None
        random number generator;
        if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
    softmax_offset : torch.Tensor, default = None
        softmax offset tensor of shape [1, h_q, 1, 1].
        See softmax_type in DotProductAttention for details.
    return_max_logit : bool, default = False
        whether to return the maximum attention score
    cuda_graph : bool, default = False
        whether or not cuda graph capture is enabled.

    Returns
    ----------
    o : torch.Tensor
        output tensor O, of the attention calculation; same data type as Q, K and V;
        same shape as Q
    aux_ctx_tensors : List[torch.Tensor]
        auxiliary output tensors used for the backward;
        if is_training is True, aux_ctx_tensors = [softmax-related tensors, rng_state]
        if is_training is False, aux_ctx_tensors = None
@@ -252,7 +252,7 @@ def fused_attn_fwd(
        rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen
            state of the random number generator;
            [seed, offset], dtype uint64
    max_logit : if return_max_logit = True, shape [h] and same data type as O; otherwise None
    """
    if attn_scale is None:
@@ -377,89 +377,89 @@ def fused_attn_bwd(
    Parameters
    ----------
    max_seqlen_q : int
        max sequence length for Q, used for padding; may be larger than max(seqlens_q),
        seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    max_seqlen_kv : int
        max sequence length for K and V, used for padding;
        may be larger than max(seqlens_kv),
        seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
    cu_seqlens_q : torch.Tensor
        cumulative sequence lengths for Q; shape [batch_size + 1]
    cu_seqlens_kv : torch.Tensor
        cumulative sequence lengths for K and V; shape [batch_size + 1]
    q : torch.Tensor
        input tensor Q; shape sbhd, bshd or thd (see `qkv_layout` for details)
    k : torch.Tensor
        input tensor K; shape sbhd, bshd or thd (see `qkv_layout` for details)
    v : torch.Tensor
        input tensor V; shape sbhd, bshd or thd (see `qkv_layout` for details)
    o : torch.Tensor
        input tensor O (output of forward); same data type as Q, K and V;
        same shape as Q
    d_o : torch.Tensor
        input tensor dO (gradient of O); same data type as Q, K and V;
        same shape as Q
    fake_dtype : tex.DType
        data type of Q, K and V; the actual dtype in high precision, or the fake
        (high-precision) dtype when running in FP8; given as torch.dtype
    dqkv_dtype : tex.DType
        data type of dQ, dK and dV; in tex.DType, not torch.dtype
    aux_ctx_tensors : List[torch.Tensor]
        auxiliary output tensors of the forward pass when its is_training is True,
        e.g. aux_ctx_tensors = [M, ZInv, rng_state]
    fused_attention_backend : tex.NVTE_Fused_Attn_Backend
        please see FusedAttention module for details on supported backends.
    cu_seqlens_q_padded : torch.Tensor, default = None
        cumulative sequence offsets for Q; shape [batch_size + 1]
    cu_seqlens_kv_padded : torch.Tensor, default = None
        cumulative sequence offsets for KV; shape [batch_size + 1]
    s_quantizer : Quantizer, default = None
        Quantizer object for the intermediate value S.
    dp_quantizer : Quantizer, default = None
        Quantizer object for the intermediate value dP.
    dqkv_quantizer : Quantizer, default = None
        Quantizer object for the output values of the fused_attn_bwd.
    dropout : float, default = 0.0
        dropout probability, 0.0 means no dropout, 1.0 means no output;
        dropout must be 0.0 if is_training is False
    fast_zero_fill : bool, default = True
        if True, initializes the output tensor O to zero using the fast filling method;
        if False, uses PyTorch's .fill_() method
    qkv_layout : str, default = "sbh3d"
        layout of Q, K and V;
        {"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd",
        "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd",
        "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"}
    attn_bias_type : str, default = "no_bias"
        type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
    attn_mask_type : str, default = "padding"
        type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
    softmax_type : str, default = "vanilla"
        type of the attention softmax; {"vanilla", "off-by-one", "learnable"}
    window_size : Tuple[int, int], default = (-1, -1)
        sliding window size for local attention, where query at position i attends to keys
        in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
        + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
        window and causal mask specifically.
    deterministic : bool, default = False
        whether to execute the backward pass with deterministic behaviour.
    cuda_graph : bool, default = False
        whether or not cuda graph capture is enabled.

    Returns
    ----------
    d_q : torch.Tensor
        gradient tensor of Q; same data type and shape as Q
    d_k : torch.Tensor
        gradient tensor of K; same data type and shape as K
    d_v : torch.Tensor
        gradient tensor of V; same data type and shape as V
    d_bias : torch.Tensor, optional
        gradient tensor of Bias when attn_bias_type is "pre_scale_bias"
        or "post_scale_bias"; same data type and shape as Bias
    d_softmax_offset : torch.Tensor, optional
        gradient tensor of softmax offset of shape [1, h_q, 1, 1].
        See softmax_type in DotProductAttention for details.
    """
    if attn_scale is None:
@@ -657,60 +657,64 @@ def get_cpu_offload_context(
    Parameters
    ----------
    enabled : bool, default = False
        When set to True, CPU Offloading functionality is enabled.
    num_layers : int, default = 1
        Determines the number of layers
        you want to offload activations/weights for.
    model_layers : int, default = 1
        Number of layers in the model that will be used under this context.
    offload_activations : bool, default = True
        Deprecated.
    offload_weights : bool, default = True
        Deprecated.
    double_buffering : bool, default = False
        Deprecated.
    retain_pinned_cpu_buffers : bool, default = False
        If True, the pinned CPU buffers are retained after offloading
        and reused for the next iteration. This is useful for CUDA graph capture.
    manual_synchronization : bool, default = False
        If True, synchronization is done manually by the user, and an additional
        return value, manual_controller, is provided. See the manual
        synchronization notes below.
    offload_stream : torch.cuda.Stream, default = None
        If provided, this stream is used for offloading and reloading.
        Otherwise, a new stream is allocated internally. It can be other than None
        only if manual_synchronization is True.

    Notes
    -----
    **Manual synchronization:**

    By default, layers are offloaded/reloaded asynchronously
    with respect to the current forward/backward stream with predefined synchronization,
    to ensure that activation memory usage is equal to
    ``(num_layers - num_offloaded_layers) * T``, where ``T`` is the memory footprint of a layer.

    For more control over the offloading and reloading process, you can set ``manual_synchronization=True``.
    In this case, an additional argument, ``manual_controller``, is returned.
    The ``manual_controller`` provides the following methods:

    - ``start_offload_layer(layer_id: int)``
    - ``release_activation_forward_gpu_memory(layer_id: int)``
    - ``start_reload_layer(layer_id: int)``

    If none of these methods are invoked for a given layer, that layer will not be offloaded or reloaded.
    If ``start_offload_layer()`` is called for a layer, offload copies for that layer begin asynchronously on the offload stream.
    Since GPU activations must be kept in memory until the copy is finished, pointers to all activations are stored.
    To release this memory, you need to call ``release_activation_forward_gpu_memory(layer_id)``.
    This method makes the current stream wait for an event recorded on the offload stream after all tensors from the layer have been offloaded.
    The ``start_reload_layer()`` method is used to start reloading a layer.
    Each tensor reload is awaited to finish before ``tensor_pop()`` for that tensor is called on the current stream.
    You can provide an ``offload_stream`` to be used for offload and reload operations.
    This allows for more detailed synchronization, such as delaying the start of offloading.

    **Example:**

    .. code-block:: python

        offload_stream = torch.cuda.Stream()
        cpu_offload_context, sync_function, manual_controller = get_cpu_offload_context(
            enabled=True, model_layers=num_layers, manual_synchronization=True, offload_stream=offload_stream)
@@ -732,10 +736,10 @@ def get_cpu_offload_context(
        for i in range(num_layers):
            out[i].sum().backward()

    **V1 code path:**

    If you want to use the v1 code path for offloading,
    please set the environment variable ``NVTE_CPU_OFFLOAD_V1`` to 1.
    """
    if NVTE_CPU_OFFLOAD_V1:
...
@@ -685,18 +685,18 @@ def get_cpu_offload_context(
    Parameters
    ----------
    enabled : bool, default = `False`
        When set to True, CPU Offloading functionality is enabled.
    num_layers : int, default = 1
        Determines the number of transformer layers
        you want to offload activations/weights for.
    model_layers : int, default = 1
        Number of layers in the model that will be used under this context.
    offload_activations : bool, default = `True`
        When set to `True`, offloads the activations for the TE layer.
    offload_weights : bool, default = `True`
        When set to `True`, offloads the weights for the TE layer.
    double_buffering : bool, default = `False`
        When set to `True`, uses double buffering for offloading.
    """
...
@@ -4,6 +4,9 @@
"""Cross Entropy Loss API"""
from typing import Optional
import warnings
import torch

import transformer_engine.pytorch.triton.cross_entropy as triton_cross_entropy
@@ -23,7 +26,7 @@ class CrossEntropyFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        inp,
        target,
        label_smoothing=0.0,
        reduce_loss=False,
@@ -37,7 +40,7 @@ class CrossEntropyFunction(torch.autograd.Function):
        Parameters:
            ctx : The context object.
            inp (tensor): The input tensor of shape (B, SQ, V) or (SQ, B, V) where B is batch size, SQ is sequence length, V is vocab size.
            target (tensor): The target tensor of shape (B, SQ) or (SQ, B) where each value is in [0, V-1].
            label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
            reduce_loss (bool): If true, returns the averaged loss across the B*SQ dimension.
@@ -47,8 +50,8 @@ class CrossEntropyFunction(torch.autograd.Function):
        Returns:
            tensor: The computed loss.
        """
        loss, inp = triton_cross_entropy.cross_entropy_forward(
            inp,
            target,
            label_smoothing,
            reduce_loss,
@@ -56,7 +59,7 @@ class CrossEntropyFunction(torch.autograd.Function):
            ignore_idx,
        )
        ctx.save_for_backward(inp.detach())
        ctx.is_cg_capturable = is_cg_capturable
        return loss
@@ -72,12 +75,10 @@ class CrossEntropyFunction(torch.autograd.Function):
        Returns:
            tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
        """
        (inp,) = ctx.saved_tensors
        inp = triton_cross_entropy.cross_entropy_backward(inp, grad_output, ctx.is_cg_capturable)
        return (
            inp,
            None,
            None,
            None,
@@ -87,4 +88,65 @@ class CrossEntropyFunction(torch.autograd.Function):
        )

def parallel_cross_entropy(
    inp: torch.Tensor,
    target: torch.Tensor,
    label_smoothing: float = 0.0,
    reduce_loss: bool = False,
    dist_process_group: Optional[torch.distributed.ProcessGroup] = None,
    ignore_idx: int = -100,
    is_cg_capturable: bool = False,
    *,
    _input: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Cross Entropy loss with optional distributed reduction.

    The input tensor can be in BF16/FP32; the loss and gradient calculation happens in
    FP32 only. The returned loss is always in FP32, and the input gradients are upcast
    to the datatype of the input.

    If ``dist_process_group`` is passed for distributed loss calculation, the input to each
    distributed rank should be ``(*, V/world_size)``. Note that each of the ranks should
    get equal shards along the V dimension.

    Parameters
    ----------
    inp : torch.Tensor
        The input tensor of shape ``(B, SQ, V)`` or ``(SQ, B, V)`` where B is batch size,
        SQ is sequence length, V is vocab size.
    target : torch.Tensor
        The target tensor of shape ``(B, SQ)`` or ``(SQ, B)`` where each value is in ``[0, V-1]``.
    label_smoothing : float, default = 0.0
        The amount of smoothing when computing the loss, where 0.0 means no smoothing.
    reduce_loss : bool, default = False
        If True, returns the averaged loss across the B*SQ dimension.
    dist_process_group : torch.distributed.ProcessGroup, default = None
        The distributed process group the loss computation is split across, None if on 1 device.
    ignore_idx : int, default = -100
        The index for which loss and gradients are made to zero.
    is_cg_capturable : bool, default = False
        Whether the operation is CUDA graph capturable.

    Returns
    -------
    torch.Tensor
        The computed loss.
    """
    # Handle backward compatibility with the deprecated _input parameter
    if _input is not None:
        warnings.warn(
            "The '_input' parameter is deprecated. Please use 'inp' instead.",
            FutureWarning,
        )
        inp = _input
    return CrossEntropyFunction.apply(
        inp,
        target,
        label_smoothing,
        reduce_loss,
        dist_process_group,
        ignore_idx,
        is_cg_capturable,
    )
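A minimal single-device usage sketch of the function above; the shapes are hypothetical, and a CUDA device is assumed since the underlying kernels are Triton-based:

.. code-block:: python

    import torch

    B, SQ, V = 2, 8, 32  # batch, sequence length, vocab size
    logits = torch.randn(B, SQ, V, device="cuda", requires_grad=True)
    target = torch.randint(0, V, (B, SQ), device="cuda")

    loss = parallel_cross_entropy(logits, target, reduce_loss=True)
    loss.backward()  # gradients arrive in logits.grad, upcast to logits.dtype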
@@ -30,7 +30,7 @@ except ImportError:
import transformer_engine_torch as tex
from transformer_engine.pytorch.triton.pad import pad_columnwise_scale_inv
from .torch_version import torch_version
from .utils import (
    is_non_tn_fp8_gemm_supported,
    safely_set_viewless_tensor_data,
@@ -642,18 +642,18 @@ def checkpoint(
    Parameters
    ----------
    function : Callable
        pytorch module used to run the forward and backward passes using
        the specified :attr:`args` and :attr:`kwargs`.
    distribute_saved_activations : bool, default = False
        if set to ``True`` and ``use_reentrant=True``, the first tensor argument is distributed
        across the specified tensor parallel group (``tp_group``) before saving it for the
        backward pass. This has no effect when ``use_reentrant=False``.
    get_rng_state_tracker : Callable, default = None
        python callable which returns an instance of :class:`CudaRNGStatesTracker`.
    tp_group : ProcessGroup, default = None
        tensor parallel process group. Used only when ``distribute_saved_activations=True``
        and ``use_reentrant=True``. If ``None``, it falls back to the default group.
    use_reentrant : bool, default = True
        perform checkpointing in reentrant mode.
    args : tuple
@@ -778,8 +778,8 @@ class CudaRNGStatesTracker:
    For model parallelism, multiple RNG states need to simultaneously exist in order
    to execute operations in or out of the model parallel region. This class keeps
    track of the various RNG states and provides utility methods to maintain them and
    execute parts of the model under a given RNG setting. Using the :meth:`add` method, a
    cuda rng state is initialized based on the input ``seed`` and is assigned to ``name``.
    Later, by forking the rng state, we can perform operations and return to our starting
    cuda state.
    """
@@ -812,7 +812,9 @@ class CudaRNGStatesTracker:
        Set the rng states. For efficiency purposes, we do not
        check the size of seed for compatibility.

        Parameters
        ----------
        states : Dict[str, torch.Tensor]
            A mapping from string names to RNG states.
        """
        self.states_ = states
@@ -821,9 +823,11 @@ class CudaRNGStatesTracker:
""" """
Adds a new RNG state. Adds a new RNG state.
name: str Parameters
----------
name : str
string identifier for the RNG state. string identifier for the RNG state.
seed: int seed : int
PyTorch seed for the RNG state. PyTorch seed for the RNG state.
""" """
# Check seed is not already used. # Check seed is not already used.
...@@ -857,7 +861,9 @@ class CudaRNGStatesTracker: ...@@ -857,7 +861,9 @@ class CudaRNGStatesTracker:
Fork the cuda rng state, perform operations, and exit with Fork the cuda rng state, perform operations, and exit with
the original state. the original state.
name: str Parameters
----------
name : str
string identifier for the RNG state. string identifier for the RNG state.
""" """
# Check if we have added the state # Check if we have added the state
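A short sketch of the ``add``/``fork`` workflow described above; the state name and seed are hypothetical:

.. code-block:: python

    import torch

    tracker = CudaRNGStatesTracker()
    tracker.add("model-parallel-rng", 1234)

    with tracker.fork("model-parallel-rng"):
        # Ops here draw from the "model-parallel-rng" state; the previous
        # CUDA RNG state is restored on exit.
        mask = torch.rand(8, device="cuda")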
@@ -2003,7 +2009,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
    Parameters
    ----------
    fsdp_root : torch.nn.Module
        FSDP-wrapped root module that may contain FSDP-wrapped TE modules.
    """
    assert isinstance(fsdp_root, FSDP), "Root module must be FSDP-wrapped."
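A hypothetical call sequence consistent with the assertion above: wrap the root module with FSDP first, then pass it to the helper:

.. code-block:: python

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    fsdp_root = FSDP(model)  # model is assumed to contain TE modules
    prepare_te_modules_for_fsdp(fsdp_root)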
...
@@ -28,7 +28,7 @@ def onnx_export(enabled: bool = False) -> Generator[None, None, None]:
    Parameters
    ----------
    enabled : bool, default = False
        whether or not to enable export
    """
...