Unverified Commit df39a7c2 authored by Paweł Gadziński, committed by GitHub

Docs fix (#2301)



* init
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* line lengths
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* subtitle --- fix in many files
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* cross entropy _input -> input rename
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* cross entropy _input -> input rename
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* a lot of small fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* torch_version() change
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add missing module and fix warnings
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* removed trailing whitespace
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Update docs/api/pytorch.rst
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* Fix import
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix more imports
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix NumPy docstring parameter spacing and indentation

- Standardize parameter documentation to use 'param : type' format (space before and after colon) per NumPy style guide
- Fix inconsistent indentation in cpu_offload.py docstring
- Modified 51 Python files across transformer_engine/pytorch
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent ca468ebe
......@@ -32,9 +32,9 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import (
Quantizer,
QuantizeLayout,
DelayedScaleQuantizer,
ScalingMode,
QuantizeLayout,
)
......
......@@ -39,12 +39,12 @@ from ..quantize import (
Quantizer,
GroupedQuantizer,
QuantizerSet,
QuantizeLayout,
noop_quantizer_set,
is_fp8_gemm_with_all_layouts_supported,
apply_padding_to_scale_inv,
get_quantize_config_with_recipe,
get_global_quantize_recipe,
QuantizeLayout,
)
from .misc import get_padded_spec, is_all_reduce_in_float32
from ..sharding import (
......
......@@ -116,7 +116,7 @@ def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis=-1):
transpose. Note, transpose_axis should be greater than static_axis_boundary
examples:
X in shape (dim0, dim1, dim2, dim3, dim4)
X of shape (dim0, dim1, dim2, dim3, dim4)
static_axis_boundary == -1, transpose_axis == 2
Xt = (dim2, dim3, dim4, dim0, dim1)
......
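As a quick check of the transpose convention documented in multidim_transpose above, here is a minimal NumPy sketch for the static_axis_boundary == -1 case; multidim_transpose_axes is a hypothetical helper, not the library function:

    import numpy as np

    def multidim_transpose_axes(ndim, transpose_axis):
        # Axes at and after transpose_axis move to the front
        # (static_axis_boundary == -1 case only).
        axes = list(range(ndim))
        return tuple(axes[transpose_axis:] + axes[:transpose_axis])

    x = np.zeros((2, 3, 4, 5, 6))  # (dim0, dim1, dim2, dim3, dim4)
    perm = multidim_transpose_axes(x.ndim, transpose_axis=2)
    assert perm == (2, 3, 4, 0, 1)  # matches the Xt example above
    xt = np.transpose(x, perm)      # shape (4, 5, 6, 2, 3)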
......@@ -35,9 +35,9 @@ from ..sharding import (
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import (
Quantizer,
QuantizeLayout,
DelayedScaleQuantizer,
ScalingMode,
QuantizeLayout,
)
......
......@@ -40,11 +40,11 @@ from ..quantize import (
GroupedScaledTensor1x,
Quantizer,
GroupedQuantizer,
QuantizeLayout,
ScalingMode,
compute_scale_from_amax,
NoScaleTensor,
get_rht_matrix,
QuantizeLayout,
)
......
......@@ -21,12 +21,12 @@ from .quantize import (
ScaledTensorFactory,
ScaledTensor,
ScalingMode,
QuantizeLayout,
QuantizerSet,
noop_quantizer_set,
with_sharding_constraint_by_logical_axes,
is_fp8_gemm_with_all_layouts_supported,
TensorUsage,
QuantizeLayout,
)
......
......@@ -279,26 +279,26 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
Indicate the type of layer normalization.
zero_centered_gamma : bool, default = False
If set to `True`, the LayerNorm formula changes to
If set to ``True``, the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} \cdot
(1 + \gamma) + \beta
This parameter is only applicable for 'layernorm'.
The default of `scale_init` will also be changed. See `scale_init`.
This parameter is only applicable for ``'layernorm'``.
The default of ``scale_init`` will also be changed. See ``scale_init``.
scale_init : Initializer, default = None
Used for initializing scale factors :math:`\gamma`.
If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
Otherwise, scale_init is `flax.linen.initializers.ones`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
If ``None`` is provided, scale_init is set according to the value of zero_centered_gamma.
If zero_centered_gamma is set to ``True``, then scale_init is ``flax.linen.initializers.zeros``.
Otherwise, scale_init is ``flax.linen.initializers.ones``.
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
scale_axes : Tuple[str, ...], default = ('embed', )
The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
bias_init : Initializer, default = flax.linen.initializers.zeros
Used for initializing shift factors :math:`\beta`,
only used when :attr:`layernorm_type='layernorm'`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
bias_axes : Tuple[str, ...], default = ('embed', )
The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
only used when :attr:`layernorm_type='layernorm'`.
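To make the ``zero_centered_gamma`` formula above concrete, here is a minimal reference implementation in plain ``jax.numpy`` (a sketch of the math only, not the Transformer Engine kernel):

    import jax.numpy as jnp

    def layernorm_reference(x, gamma, beta, eps=1e-6, zero_centered_gamma=False):
        # (x - E[x]) / sqrt(Var[x] + eps), normalized over the last axis
        mean = jnp.mean(x, axis=-1, keepdims=True)
        var = jnp.var(x, axis=-1, keepdims=True)
        y = (x - mean) / jnp.sqrt(var + eps)
        # With zero_centered_gamma=True the scale is (1 + gamma), so a
        # zeros-initialized gamma starts out as the identity scale.
        scale = 1.0 + gamma if zero_centered_gamma else gamma
        return y * scale + beta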
......@@ -424,15 +424,15 @@ class DenseGeneral(TransformerEngineBase):
kernel_init : Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
Used for initializing weights.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
kernel_axes : Tuple[str, ...], default = ()
The name of axes used to shard the weights with a corresponding mesh.
use_bias: bool, default = False
Indicate whether to enable bias shifting.
If set to False, the layer will not learn an additive bias.
If set to ``False``, the layer will not learn an additive bias.
bias_init: Initializer, default = flax.linen.initializers.zeros
Used for initializing bias, only used when :attr:`use_bias=True`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
bias_axes: Tuple[str, ...], default = ()
The name of axes used to shard bias with a corresponding mesh,
only used when :attr:`use_bias=True`.
......@@ -443,12 +443,12 @@ class DenseGeneral(TransformerEngineBase):
:attr:`enable_low_rank_adaptation=True`
low_rank_adaptation_alpha: float, default = None
The alpha for computing the scaling factor of LoRA output.
:math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
:math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on.
input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
sharding constraint.
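The LoRA scaling rule above, :math:`\frac{alpha}{rank} \cdot lora\_output`, amounts to the following sketch (hypothetical helper; the actual module folds this into the dense computation):

    import jax.numpy as jnp

    def lora_output(x, w_a, w_b, alpha=None):
        # Low-rank update: x @ w_a @ w_b, with w_a of shape (features, rank)
        # and w_b of shape (rank, out_features).
        out = x @ w_a @ w_b
        if alpha is None:  # None means no scaling, as documented above
            return out
        rank = w_a.shape[-1]
        return (alpha / rank) * out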
Optimization parameters
......@@ -597,48 +597,48 @@ class LayerNormDenseGeneral(TransformerEngineBase):
epsilon : float, default = 1e-6
A value added to the denominator of layer normalization for numerical stability.
zero_centered_gamma : bool, default = False
If set to `True`, the LayerNorm formula changes to
If set to ``True``, the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} \cdot
(1 + \gamma) + \beta
This parameter is only applicable for 'layernorm'.
The default of `scale_init` will also be changed. See `scale_init`
This parameter is only applicable for ``'layernorm'``.
The default of ``scale_init`` will also be changed. See ``scale_init``
scale_init : Initializer, default = None
Used for initializing scale factors :math:`\gamma`.
If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
Otherwise, scale_init is `flax.linen.initializers.ones`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
If ``None`` is provided, scale_init is set according to the value of zero_centered_gamma.
If zero_centered_gamma is set to ``True``, then scale_init is ``flax.linen.initializers.zeros``.
Otherwise, scale_init is ``flax.linen.initializers.ones``.
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
scale_axes : Tuple[str, ...], default = ('embed', )
The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
only used when :attr:`enable_layernorm=True`.
ln_bias_init: Initializer, default = flax.linen.initializers.zeros
Used for initializing shift factors :math:`\beta`,
only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
ln_bias_axes: Tuple[str, ...], default = ('embed', )
The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
kernel_init : Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
Used for initializing weights.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
kernel_axes : Tuple[str, ...], default = ()
The name of axes used to shard the weights with a corresponding mesh.
use_bias: bool, default = False
Indicate whether to enable bias shifting.
If set to False, the layer will not learn an additive bias.
If set to ``False``, the layer will not learn an additive bias.
bias_init: Initializer, default = flax.linen.initializers.zeros
Used for initializing bias, only used when :attr:`use_bias=True`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
bias_axes: Tuple[str, ...], default = ()
The name of axes used to shard bias with a corresponding mesh,
only used when :attr:`use_bias=True`.
return_layernorm_output: bool, default = False
Indicate whether to return the output of layer normalization.
If set False, return None as the second tensor in outputs.
If set ``False``, return ``None`` as the second tensor in outputs.
enable_low_rank_adaptation: bool, default = False
Indicate whether to enable low rank adaptation for each dense layer.
low_rank_adaptation_dim: int, default = 32
......@@ -646,16 +646,16 @@ class LayerNormDenseGeneral(TransformerEngineBase):
:attr:`enable_low_rank_adaptation=True`
low_rank_adaptation_alpha: float, default = None
The alpha for computing the scaling factor of LoRA output.
:math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
:math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on.
layernorm_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input of layernorm, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
sharding constraint.
dot_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input of dot, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
sharding constraint.
Optimization parameters
......@@ -887,34 +887,34 @@ class LayerNormMLP(TransformerEngineBase):
epsilon : float, default = 1e-6
A value added to the denominator of layer normalization for numerical stability.
zero_centered_gamma : bool, default = False
If set to `True`, the LayerNorm formula changes to
If set to ``True``, the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} \cdot
(1 + \gamma) + \beta
This parameter is only applicable for 'layernorm'.
The default of `scale_init` will also be changed. See `scale_init`.
This parameter is only applicable for ``'layernorm'``.
The default of ``scale_init`` will also be changed. See ``scale_init``.
scale_init : Initializer, default = None
Used for initializing scale factors :math:`\gamma`.
If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
Otherwise, scale_init is `flax.linen.initializers.ones`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
If ``None`` is provided, scale_init is set according to the value of zero_centered_gamma.
If zero_centered_gamma is set to ``True``, then scale_init is ``flax.linen.initializers.zeros``.
Otherwise, scale_init is ``flax.linen.initializers.ones``.
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
scale_axes : Tuple[str, ...], default = ('embed', )
The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
only used when :attr:`enable_layernorm=True`.
ln_bias_init: Initializer, default = flax.linen.initializers.zeros
Used for initializing shift factors :math:`\beta`,
only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
ln_bias_axes: Tuple[str, ...], default = ('embed', )
The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
kernel_init : Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
Used for initializing the weights of both dense layer transformations.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
The name of axes used to shard the weights with a corresponding mesh for
the weight of the first dense layer transformation.
......@@ -923,10 +923,10 @@ class LayerNormMLP(TransformerEngineBase):
the weight of the second dense layer transformation.
use_bias: bool, default = False
Indicate whether to enable bias shifting.
If set to False, the layer will not learn an additive bias.
If set to ``False``, the layer will not learn an additive bias.
bias_init: Initializer, default = flax.linen.initializers.zeros
Used for initializing bias, only used when :attr:`use_bias=True`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
bias_axes_1: Tuple[str, ...], default = ('mlp',)
The name of axes used to shard bias with a corresponding mesh for
the weight of the first dense layer transformation.
......@@ -937,7 +937,7 @@ class LayerNormMLP(TransformerEngineBase):
Only used when :attr:`use_bias=True`.
return_layernorm_output: bool, default = False
Indicate whether to return the output of layer normalization.
If set False, return None as the second tensor in outputs.
If set ``False``, return ``None`` as the second tensor in outputs.
activations: Sequence[Union[str, Callable]], default = ('gelu',)
The sequence of activation functions to apply after the first dense layer transformation.
Each activation has its own transformation layer.
......@@ -958,20 +958,20 @@ class LayerNormMLP(TransformerEngineBase):
:attr:`enable_low_rank_adaptation=True`.
low_rank_adaptation_alpha: float, default = None
The alpha for computing the scaling factor of LoRA output.
:math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
:math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on.
layernorm_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input of layernorm, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
sharding constraint.
dot_1_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input of 1st dot, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
sharding constraint.
dot_2_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input of 2nd dot, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
sharding constraint.
ffn1_ckpt_name: str = "ffn1"
Checkpoint name for the output of the first fully-connected layer in the MLP block.
......
......@@ -469,7 +469,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
The hidden dimension of each attention head.
num_attention_heads: int
The number of attention heads.
num_gqa_groups: int, default = `None`
num_gqa_groups: int, default = None
Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
......@@ -482,32 +482,45 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_mask_type: str, default = 'causal'
This parameter specifies the type of attention mask to be applied during the softmax
operation.
Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}
Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}.
Each described below:
* no_mask: No attention mask is applied. This means the attention will consider the
* ``no_mask``: No attention mask is applied. This means the attention will consider the
full sequence without any restrictions.
* padding: Indicates the presence of padding at the end of each sequence.
Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
* ``padding``: Indicates the presence of padding at the end of each sequence.
Users must provide a mask with the shape ``[batch, 1, max_seqlen_q, max_seqlen_kv]`` in the
:attr:`__call__` method to specify the padding positions.
* causal: An upper triangular mask is applied to the softmax inputs,
* ``causal``: An upper triangular mask is applied to the softmax inputs,
ensuring that the prediction for a certain position is only dependent on known outputs
from positions before it.
* causal_padding / padding_causal: A combination of both causal and padding masks.
Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.
* ``causal_padding`` / ``padding_causal``: A combination of both causal and padding masks.
Both ``'causal_padding'`` and ``'padding_causal'`` are acceptable and have the same effect.
.. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.
|
.. note:: THD format only supports 'padding' or 'causal_padding' mask type.
.. note:: :attr:`mask` in :attr:`__call__` is ignored for ``'no_mask'`` and ``'causal'``.
attn_mask_type mask/sequence_descriptor SWA softmax type
--------------------------------------------------------------------------------------------
|
.. note:: THD format only supports ``'padding'`` or ``'causal_padding'`` mask type.
|
.. table::
:widths: auto
================== ============ ========== ==============================
attn_mask_type mask/sd SWA softmax type
================== ============ ========== ==============================
no_mask None None SCALED
causal None None SCALED_UPPER_TRIANG_MASKED
causal None Yes SCALED_MASKED
padding Required Yes/No SCALED_MASKED
padding_causal Required Yes/No SCALED_MASKED
================== ============ ========== ==============================
where sd stands for sequence_descriptor.
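For the ``'padding'`` and ``'causal_padding'`` options above, the mask passed to :attr:`__call__` has shape ``[batch, 1, max_seqlen_q, max_seqlen_kv]`` with ``True`` marking positions to mask out. A minimal sketch of building such a mask from per-sequence lengths for self-attention (hypothetical helper):

    import jax.numpy as jnp

    def make_causal_padding_mask(seqlens, max_seqlen):
        # True means "mask out" in this module's convention.
        pos = jnp.arange(max_seqlen)
        pad = pos[None, :] >= seqlens[:, None]                    # [batch, s]
        padding = pad[:, None, None, :] | pad[:, None, :, None]   # padded q or kv
        causal = pos[None, :] > pos[:, None]                      # upper triangular
        return padding | causal[None, None, :, :]                 # [batch, 1, s, s]

    mask = make_causal_padding_mask(jnp.array([5, 3]), max_seqlen=8)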
attn_bias_type: Optional[str], default = None
Type of the attention bias passed in the attention.
......@@ -553,22 +566,40 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
Sliding window size. The default value is no sliding window.
max_segments_per_seq: Optional[int], default = 1
The maximum number of segments per sequence, also used for THD format (sequence packing).
context_parallel_causal_load_balanced (bool):
context_parallel_causal_load_balanced: bool
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
context_parallel_strategy (CPStrategy): The strategy of context parallel. 0: DEFAULT, 1: ALL_GATHER, 2: RING.
context_checkpoint_name (str): The name of the context checkpoint in the forward pass of fused attention.
context_parallel_axis: str
The name of the context parallel axis.
context_parallel_strategy: CPStrategy
The context parallelism strategy. 0: DEFAULT, 1: ALL_GATHER, 2: RING.
context_checkpoint_name: str
The name of the context checkpoint in the forward pass of fused attention.
softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
softmax type as described in this paper:
Softmax type as described in the paper
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter in shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention
('zero sink' and 'learnable sink').
For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``:
* ``'vanilla'``:
.. math::
Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}
* ``'off-by-one'``:
.. math::
Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}
* ``'learnable'``:
.. math::
Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}
where :math:`\alpha` is a learnable parameter of shape ``[h]``.
``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
(``'zero sink'`` and ``'learnable sink'``).
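A plain ``jax.numpy`` sketch of the three softmax variants defined above (reference math only; it omits the usual max-subtraction for numerical stability that a real kernel would use):

    import jax.numpy as jnp

    def sink_softmax(scores, softmax_type="vanilla", alpha=None):
        # scores: [b, h, s_q, s_kv]; alpha: learnable sink logits of shape [h]
        e = jnp.exp(scores)
        denom = jnp.sum(e, axis=-1, keepdims=True)
        if softmax_type == "off-by-one":   # zero sink: implicit extra logit of 0
            denom = denom + 1.0
        elif softmax_type == "learnable":  # learnable sink per head
            denom = denom + jnp.exp(alpha)[None, :, None, None]
        return e / denom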
Optimization parameters
-----------------------
......@@ -631,7 +662,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
mask: jax.numpy.ndarray, default = None
Boolean tensor used to mask out the attention softmax input.
:attr:`True` means to mask out the corresponding values.
Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
Ignored when :attr:`self.attn_mask_type` is either ``'no_mask'`` or ``'causal'``.
bias: jax.numpy.ndarray, default = None
A tensor used to shift attention softmax input.
*:
......@@ -818,7 +849,7 @@ def rotary_pos_emb(
):
"""
Rotary Positional Embedding
x should be in shape of
x should be of shape
[Batch, Seqlen, ..., Heads, Hidden] if transpose_batch_sequence is False, or
[Seqlen, Batch, ..., Heads, Hidden] if transpose_batch_sequence is True.
"""
......@@ -956,7 +987,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
The hidden dimension of each attention head.
num_attention_heads: int
The number of attention heads.
num_gqa_groups: int, default = `None`
num_gqa_groups: int, default = None
Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
......@@ -969,28 +1000,28 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_mask_type: str, default = 'causal'
This parameter specifies the type of attention mask to be applied during the softmax
operation.
Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}
Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}.
Each described below:
* no_mask: No attention mask is applied. This means the attention will consider the
* ``no_mask``: No attention mask is applied. This means the attention will consider the
full sequence without any restrictions.
* padding: Indicates the presence of padding at the end of each sequence.
Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
* ``padding``: Indicates the presence of padding at the end of each sequence.
Users must provide a mask with the shape ``[batch, 1, max_seqlen_q, max_seqlen_kv]`` in the
:attr:`__call__` method to specify the padding positions.
* causal: An upper triangular mask is applied to the softmax inputs,
* ``causal``: An upper triangular mask is applied to the softmax inputs,
ensuring that the prediction for a certain position is only dependent on known outputs
from positions before it.
* causal_padding / padding_causal: A combination of both causal and padding masks.
Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.
* ``causal_padding`` / ``padding_causal``: A combination of both causal and padding masks.
Both ``'causal_padding'`` and ``'padding_causal'`` are acceptable and have the same effect.
.. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.
.. note:: :attr:`mask` in :attr:`__call__` is ignored for ``'no_mask'`` and ``'causal'``.
attn_bias_type: Optional[str], default = None
Type of the attention bias passed in the attention.
Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
Available options: ``{'no_bias', 'pre_scale_bias', 'post_scale_bias'}``.
When default is present, the type is automatically decided by the MHA's bias parameter.
Where it is `post_scale_bias` if there is bias. Otherwise `no_bias` is used.
It is ``'post_scale_bias'`` if there is bias; otherwise ``'no_bias'`` is used.
dropout_rng_name: str, default = 'dropout'
The key in given RNGs via flax.linen.Module.apply that is used
to generate Dropout masks in the core attention.
......@@ -999,27 +1030,27 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
layernorm_epsilon: float, default = 1e-6
A value added to the denominator of layer normalization for numerical stability.
zero_centered_gamma: bool, default = False
If set to `True`, the LayerNorm formula changes to
If set to ``True``, the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} \cdot
(1 + \gamma) + \beta
This parameter is only applicable for 'layernorm'.
This parameter is only applicable for ``'layernorm'``.
kernel_init: Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
``flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')``
Used for initializing the QKV and output projection weights.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
use_bias: bool, default = False
Indicate whether or not to enable bias shifting for QKV and output projections.
If set to False, the layer will not learn additive biases.
bias_init: Initializer, default = flax.linen.initializers.zeros
If set to ``False``, the layer will not learn additive biases.
bias_init: Initializer, default = ``flax.linen.initializers.zeros``
Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
input_layernorm: bool, default = True
If set to False, layer normalization to the input is not applied.
If set to ``False``, layer normalization to the input is not applied.
return_layernorm_output: bool, default = False
If set to True, output of layernorm is returned from the forward together with the output
If set to ``True``, output of layernorm is returned from the forward together with the output
of the linear transformation.
Example use case: residual connection for transformer module is taken post layernorm.
enable_rotary_pos_emb: bool, default = False
......@@ -1029,17 +1060,17 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
only used when :attr:`enable_rotary_pos_emb=True`
rotary_pos_emb_group_method: str, default = 'consecutive'
Indicate the method to coupled the coordinates. It should be one of
['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2`
, d is the hidden dimension. 'consecutive' pairs index :math:`i` with :math:`i + 1`.
``['consecutive', 'alternate']``. ``'alternate'`` pairs index :math:`i` with :math:`i + d/2`,
where :math:`d` is the hidden dimension. ``'consecutive'`` pairs index :math:`i` with :math:`i + 1`.
low_rank_adaptation_scope: str, default = 'none'
Indicate the scope to apply low rank adaptation. It should be one of
['none', 'all', 'qkv_proj', 'output_proj', 'exclude_qkv_proj', 'exclude_output_proj']
``['none', 'all', 'qkv_proj', 'output_proj', 'exclude_qkv_proj', 'exclude_output_proj']``
low_rank_adaptation_dim: int, default = 32
The dimension for low rank adaptation, only used when
:attr:`enable_low_rank_adaptation=True`
low_rank_adaptation_alpha: float, default = None
The alpha for computing the scaling factor of LoRA output.
:math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
:math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
enable_sequence_parallel: bool, default = False
Whether to enable sequence parallelism to operations except dot.
num_heads: int, default = None
......@@ -1066,8 +1097,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
scale_attn_logits: bool, default = False
Indicate whether to scale attention logits.
If set to True, :math:`\frac{Q}{\sqrt{head\_dim}*K}`,
else :math:`Q*K`
If set to True, :math:`\frac{Q \cdot K^T}{\sqrt{head\_dim}}`,
else :math:`Q \cdot K^T`
scaled_query_init: bool, default = True
Whether to scale WQ on initialization by :math:`\frac{1}{\sqrt{head\_dim}}`
float32_logits: bool, default = False
......@@ -1078,16 +1109,31 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
window_size: Optional[Tuple[int, int]], default = None
Sliding window size. Default value is no sliding window.
softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
softmax type as described in this paper:
Softmax type as described in the paper
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter in shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention
('zero sink' and 'learnable sink').
For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``:
* ``'vanilla'``:
.. math::
Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}
* ``'off-by-one'``:
.. math::
Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}
* ``'learnable'``:
.. math::
Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}
where :math:`\alpha` is a learnable parameter of shape ``[h]``.
``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
(``'zero sink'`` and ``'learnable sink'``).
"""
head_dim: int
......@@ -1202,7 +1248,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
mask: jax.numpy.ndarray, default = None
Boolean tensor used to mask out the attention softmax input.
:attr:`True` means mask out the corresponding values.
Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'.
Ignored when :attr:`self.attn_mask_type` is either ``'no_mask'`` or ``'causal'``.
bias: jax.numpy.ndarray, default = None
A tensor used to shift the attention softmax input.
*
......@@ -1688,7 +1734,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
Intermediate size to which input samples are projected.
num_attention_heads: int, default = 8
Number of attention heads in the transformer layer.
num_gqa_groups: int, default = `None`
num_gqa_groups: int, default = None
Number of GQA groups. When `None` is present, it is equal to num_attention_heads.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
......@@ -1722,31 +1768,31 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
The key in given RNGs via flax.linen.Module.apply that is used for
generating Dropout masks in the Multi-Head Attention.
mha_kernel_init: Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
``flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')``
Used for initializing weights of QKV and Output projection weights.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
mlp_kernel_init: Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
``flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')``
Used for initializing weights of FC1 and FC2 layers.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
mlp_activations: Sequence[str], default = ('gelu', )
The sequence of activation functions to apply after the first linear transformation.
Each activation has its own transformation layer.
mlp_activation_params: dict = None
This is only used when ('clamped_silu', 'clamped_linear') is in :attr:`mlp_activations`. At the moment
ClampedSwiglu is the only activation that requires parameters.
This is only used when ``('clamped_silu', 'clamped_linear')`` is in :attr:`mlp_activations`. At the moment
``ClampedSwiglu`` is the only activation that requires parameters.
use_bias: bool, default = False
Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
If set to False, the layer will not learn additive biases.
bias_init: Initializer, default = flax.linen.initializers.zeros
If set to ``False``, the layer will not learn additive biases.
bias_init: Initializer, default = ``flax.linen.initializers.zeros``
Used for initializing bias of QKVO projections,
FC1 and FC2. It is only used when :attr:`use_bias=True`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
apply_residual_connection_post_layernorm: bool, default = False
If set to True, residual connections are taken from the output
If set to ``True``, residual connections are taken from the output
of layer norm (default is taken from input of layer norm)
output_layernorm: bool, default = False
If set to True, layer normalization is applied on the output side,
If set to ``True``, layer normalization is applied on the output side,
after the final dropout-add. Default behavior is to apply layer
normalization on the input side, before the QKV transformation.
float32_attention_logits: bool, default = False
......@@ -1754,43 +1800,43 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
For fused attention backend, the accumulation is always float32 without the perf overhead.
layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER
If set to TransformerLayerType.DECODER, an additional cross-attention block
is added after self-attention.this can be used for structures like `T5`
is added after self-attention. This can be used for structures like T5
Transformer in conjunction with the TransformerLayerType.ENCODER option.
self_attn_mask_type: str, default = 'causal'
This parameter specifies the type of attention mask to be applied during the softmax
operation in the self attention.
Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}
Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'}.
Each described below:
* no_mask: No attention mask is applied. This means the self attention will consider the
* ``no_mask``: No attention mask is applied. This means the self attention will consider the
full sequence without any restrictions.
* padding: Indicates the presence of padding at the end of each sequence.
Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
* ``padding``: Indicates the presence of padding at the end of each sequence.
Users must provide a mask with the shape ``[batch, 1, max_seqlen_q, max_seqlen_kv]`` in the
:attr:`__call__` method to specify the padding positions.
* causal: An upper triangular mask is applied to the softmax inputs,
* ``causal``: An upper triangular mask is applied to the softmax inputs,
ensuring that the prediction for a certain position is only dependent on known outputs
from positions before it.
* causal_padding / padding_causal: A combination of both causal and padding masks.
Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect.
* ``causal_padding`` / ``padding_causal``: A combination of both causal and padding masks.
Both ``'causal_padding'`` and ``'padding_causal'`` are acceptable and have the same effect.
.. note:: :attr:`attention_mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.
.. note:: :attr:`attention_mask` in :attr:`__call__` is ignored for ``'no_mask'`` and ``'causal'``.
self_attn_bias_type: Optional[str], default = None
Type of the attention bias passed into the self attention.
Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
Available options: ``{'no_bias', 'pre_scale_bias', 'post_scale_bias'}``.
When default is present, the type is automatically decided by the MHA's bias parameter.
Where it is `post_scale_bias` if there is bias. Otherwise `no_bias` is used.
It is ``'post_scale_bias'`` if there is bias; otherwise ``'no_bias'`` is used.
enable_relative_embedding: bool, default = True
Whether to enable relative embedding as shifting of attention logits.
relative_embedding: flax.linen.Module, default = None
The module for relative embedding execution, only used when
:attr:`enable_relative_embedding=True`. Default is None, which will create
:attr:`enable_relative_embedding=True`. Default is ``None``, which will create
an instance of RelativePositionBiases if :attr:`enable_relative_embedding=True`.
Default: RelativePositionBiases( num_buckets=32, max_distance=128,
Default: ``RelativePositionBiases( num_buckets=32, max_distance=128,
num_attention_heads=self.num_attention_heads, dtype=self.dtype,
embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'),
name='relpos_bias')
name='relpos_bias')``
enable_rotary_pos_emb: bool, default = False
Whether to enable rotary position embedding to projected query and key in MHA.
rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
......@@ -1798,34 +1844,49 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
only used when :attr:`enable_rotary_pos_emb=True`
rotary_pos_emb_group_method: str, default = 'consecutive'
Indicate the method to couple the coordinates. It should be one of
['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2`,
where :math:`d` is the hidden dimension. 'consecutive' pairs index :math:`i` with
``['consecutive', 'alternate']``. ``'alternate'`` is to pair index :math:`i` with :math:`i + d/2`,
where :math:`d` is the hidden dimension. ``'consecutive'`` pairs index :math:`i` with
:math:`i + 1`.
low_rank_adaptation_scope: str, default = 'none'
Indicate the scope to apply low rank adaptation. It should be one of
['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
'exclude_output_proj', 'exclude_mlp']
``['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
'exclude_output_proj', 'exclude_mlp']``
low_rank_adaptation_dim: int, default = 32
The dimension for low rank adaptation, only used when
:attr:`enable_low_rank_adaptation=True`
low_rank_adaptation_alpha: float, default = None
The alpha for computing the scaling factor of LoRA output.
:math:`\frac{alpha}{rank} * lora\_output`. None means no scaling.
:math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
enable_sequence_parallel: bool, default = False
Whether to enable sequence parallelism to operations except dot.
window_size: Optional[Tuple[int, int]], default = None
Sliding window size. Default value is no sliding window.
softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
Softmax type as described in this paper:
Softmax type as described in the paper
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter in shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention
('zero sink' and 'learnable sink').
For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``:
* ``'vanilla'``:
.. math::
Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}
* ``'off-by-one'``:
.. math::
Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}
* ``'learnable'``:
.. math::
Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}
where :math:`\alpha` is a learnable parameter of shape ``[h]``.
``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
(``'zero sink'`` and ``'learnable sink'``).
Only supported for fused attention backend.
Optimization parameters
......@@ -1836,19 +1897,19 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
When > 0.0, applies stochastic depth per sample in the main
path of the residual block.
fuse_qkv_params: bool, default = True
If set to True, `TransformerLayer` module exposes a single fused
If set to ``True``, ``TransformerLayer`` module exposes a single fused
parameter for query-key-value for self-attention and key-value for
cross-attention.
transpose_batch_sequence: bool, default = False
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. if set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
and sequence length dimension. If set to ``True``, the input tensors
should be in ``(seqlen, batch, hidden)``, otherwise ``(batch, seqlen, hidden)``.
scale_attn_logits: bool, default = False
Indicate whether to scale attention logits.
if set to True, :math:`\frac{Q}{\sqrt{head_dim}*K}`,
else :math:`Q*K`
scaled_query_init: bool, default = `True`
Whether to scale WQ on initialization by :math:`\sqrt{head_dim}`
if set to ``True``, :math:`\frac{Q \cdot K^T}{\sqrt{head\_dim}}`,
else :math:`Q \cdot K^T`
scaled_query_init: bool, default = True
Whether to scale WQ on initialization by :math:`\sqrt{head\_dim}`
"""
hidden_size: int = 512
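A minimal usage sketch for the layer documented above, assuming the module is exposed as ``transformer_engine.jax.flax.TransformerLayer``; ``mlp_hidden_size`` and ``deterministic`` are assumed names not shown verbatim in the hunks above:

    import jax
    import jax.numpy as jnp
    import transformer_engine.jax.flax as te_flax

    layer = te_flax.TransformerLayer(
        hidden_size=512,
        mlp_hidden_size=2048,            # assumed name of the intermediate size
        num_attention_heads=8,
        self_attn_mask_type="causal",    # attention_mask ignored for 'causal'
        transpose_batch_sequence=False,  # inputs are (batch, seqlen, hidden)
    )
    x = jnp.zeros((2, 128, 512))
    params = layer.init(jax.random.PRNGKey(0), x, deterministic=True)
    y = layer.apply(params, x, deterministic=True)  # assumed dropout switch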
......@@ -1931,7 +1992,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
attention_mask : jax.numpy.ndarray, default = None
Boolean tensor used to mask out self-attention softmax input.
:attr:`True` means mask out the corresponding values.
Ignored when :attr:`self.self_attn_mask_type` is either 'no_mask' or 'causal'.
Ignored when :attr:`self.self_attn_mask_type` is either ``'no_mask'`` or ``'causal'``.
encoder_decoder_mask: jax.numpy.ndarray, default = None
Boolean tensor used to mask out cross-attention softmax input when
:attr:`layer_type=TransformerLayerType.DECODER`.
......
......@@ -7,22 +7,14 @@
# pylint: disable=wrong-import-position
import functools
from packaging.version import Version as PkgVersion
import torch
from transformer_engine.common import load_framework_extension
@functools.lru_cache(maxsize=None)
def torch_version() -> tuple[int, ...]:
"""Get PyTorch version"""
return PkgVersion(str(torch.__version__)).release
from transformer_engine.pytorch.torch_version import torch_version
assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}."
load_framework_extension("torch")
from transformer_engine.pytorch.module import LayerNormLinear
from transformer_engine.pytorch.module import Linear
......
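The inline helper removed above now lives in ``transformer_engine.pytorch.torch_version``; it still returns the release tuple of the installed PyTorch, so version gates compare tuples directly:

    from transformer_engine.pytorch.torch_version import torch_version

    # The removed inline version cached this via functools.lru_cache and
    # returned PkgVersion(torch.__version__).release, e.g. (2, 3, 1).
    if torch_version() >= (2, 1):
        print("PyTorch is new enough for Transformer Engine")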
......@@ -152,25 +152,25 @@ __all__ = ["DotProductAttention"]
class DotProductAttention(TransformerEngineBaseModule):
"""Allows the model to jointly attend to information from different
r"""Allows the model to jointly attend to information from different
representation subspaces as described in the paper:
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
.. note::
Argument :attr:`attention_mask` in the `forward` call is only used when
:attr:`attn_mask_type` includes '"padding"' or `"arbitrary"`.
Argument :attr:`attention_mask` in the ``forward`` call is only used when
:attr:`attn_mask_type` includes ``"padding"`` or ``"arbitrary"``.
.. warning::
FlashAttention uses a non-deterministic algorithm for optimal performance. To observe
deterministic behavior at the cost of performance, use FlashAttention version >= `2.4.1`
deterministic behavior at the cost of performance, use FlashAttention version >= ``2.4.1``
and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
to disable`flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
to disable ``flash-attn`` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
.. note::
Transformer Engine stores the FP8 metadata under a `._extra_state` key when checkpointing.
Transformer Engine stores the FP8 metadata under a ``._extra_state`` key when checkpointing.
As the FP8 attention support expands from one backend to multiple backends, the location
of that key has also shifted (see `FP8 checkpoint compatibility <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/faq.html#fp8-checkpoint-compatibility>`_).
......@@ -182,116 +182,135 @@ class DotProductAttention(TransformerEngineBaseModule):
kv_channels : Union[int, Tuple[int, int]]
the head size in key and value tensors. If the same, :attr:`kv_channels` can be
an integer; if not, :attr:`kv_channels` should be a tuple of two integers.
num_gqa_groups : Optional[int] = None
num_gqa_groups : Optional[int], default = None
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the queries.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
attention_dropout: float, default = 0.0
is equivalent to MHA, i.e. ``num_gqa_groups = num_attention_heads``.
attention_dropout : float, default = 0.0
dropout probability for the dropout op during multi-head attention.
attn_mask_type: str, default = `causal`
type of attention mask passed into softmax operation, options are "`no_mask`",
"`padding`", "`causal`", "`padding,causal`", "`causal,padding`",
"`padding_causal`", "`causal_bottom_right`", "`padding_causal_bottom_right`", and
"`arbitrary`", where "`padding,causal`", "`causal,padding`" and "`padding_causal`"
attn_mask_type : str, default = "causal"
type of attention mask passed into softmax operation, options are ``"no_mask"``,
``"padding"``, ``"causal"``, ``"padding,causal"``, ``"causal,padding"``,
``"padding_causal"``, ``"causal_bottom_right"``, ``"padding_causal_bottom_right"``, and
``"arbitrary"``, where ``"padding,causal"``, ``"causal,padding"`` and ``"padding_causal"``
are equivalent. This arg can be overridden by :attr:`attn_mask_type` in the
`forward` method. It is useful for cases involving compilation/tracing, e.g.
:meth:`forward` method. It is useful for cases involving compilation/tracing, e.g.
ONNX export, and the forward arg is useful for dynamically changing mask types,
e.g. a different mask for training and inference.
1. For "`no_mask`", no attention mask is applied.
2. For "`causal`", "`causal_bottom_right`", or the causal mask in
"`padding_causal`" and "`padding_causal_bottom_right`", Transformer Engine
1. For ``"no_mask"``, no attention mask is applied.
2. For ``"causal"``, ``"causal_bottom_right"``, or the causal mask in
``"padding_causal"`` and ``"padding_causal_bottom_right"``, Transformer Engine
calculates and applies an upper triangular mask to the softmax input.
No user input is needed. Causal masks without the "`bottom_right`" appendix align
No user input is needed. Causal masks without the ``"bottom_right"`` appendix align
the diagonal line to the top left corner of the softmax matrix. With
"`bottom_right`", the causal mask is aligned to the bottom right corner, which is
``"bottom_right"``, the causal mask is aligned to the bottom right corner, which is
often used in inference/KV caching.
3. For "`padding`", or the padding mask in "`padding_causal`" and
"`padding_causal_bottom_right`", users need to provide the locations of padded
tokens, either via :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv` (both in shape
[batch_size + 1]), or via :attr:`attention_mask` (one tensor for self-attention
in shape [batch_size, 1, 1, max_seqlen_q], or two tensors in a tuple for
cross-attention in shapes [batch_size, 1, 1, max_seqlen_q] and
[batch_size, 1, 1, max_seqlen_kv]).
4. For "`arbitrary`", users need to provide a mask that is broadcastable to
the shape of softmax input [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
window_size: Optional[Tuple[int, int]], default = `None`
3. For ``"padding"``, or the padding mask in ``"padding_causal"`` and
``"padding_causal_bottom_right"``, users need to provide the locations of padded
tokens, either via :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv` (both of shape
``[batch_size + 1]``), or via :attr:`attention_mask` (one tensor for self-attention
of shape ``[batch_size, 1, 1, max_seqlen_q]``, or two tensors in a tuple for
cross-attention of shapes ``[batch_size, 1, 1, max_seqlen_q]`` and
``[batch_size, 1, 1, max_seqlen_kv]``).
4. For ``"arbitrary"``, users need to provide a mask that is broadcastable to
the shape of softmax input ``[batch_size, num_heads, max_seqlen_q, max_seqlen_kv]``.
window_size : Optional[Tuple[int, int]], default = None
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
`attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can
be overridden by :attr:`window_size` in `forward` as well.
attention_type: str, default = `self`
type of attention, either "`self`" and "`cross`".
layer_number: int, default = `None`
layer number of the current `DotProductAttention` when multiple such modules
in ``[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]]``
inclusive. Special cases ``(-1, -1)`` and ``(-1, 0)`` mean no sliding
window and causal mask specifically. Both ``causal`` and ``causal_bottom_right`` masks
map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them based on
``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can
be overridden by :attr:`window_size` in ``forward`` as well.
attention_type : str, default = "self"
type of attention, either ``"self"`` or ``"cross"``.
layer_number : int, default = None
layer number of the current ``DotProductAttention`` when multiple such modules
are concatenated, for instance in consecutive transformer blocks.
qkv_format : str, default = "sbhd"
dimension format for ``query_layer``, ``key_layer`` and ``value_layer``,
{``"sbhd"``, ``"bshd"``, ``"thd"``}. ``s`` stands for the sequence length, ``b`` batch size,
``h`` the number of heads, ``d`` head size, and ``t`` the total number of tokens
in a batch, with ``t = sum(s_i), for i = 0...b-1``. ``"sbhd"`` and ``"bshd"`` formats
are used for when sequences in a batch are of equal length or padded to
equal length, and the ``"thd"`` format is used for when sequences in a batch
have different lengths. Please note that these formats do not reflect how
tensors ``query_layer``, ``key_layer``, ``value_layer`` are laid out in memory.
For that, please use ``get_qkv_layout`` to gain the layout information.
softmax_scale : Optional[float], default = None
softmax scale for the attention scores. If ``None``, defaults to
``1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])``.
softmax_type : str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
Softmax type as described in the paper
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``:
* ``'vanilla'``:
.. math::
Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}
* ``'off-by-one'``:
.. math::
Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}
* ``'learnable'``:
.. math::
Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}
where :math:`\alpha` is a learnable parameter of shape ``[h]``.
``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
(``'zero sink'`` and ``'learnable sink'``); a short numerical sketch follows this
docstring.
return_max_logit : Optional[bool], default = False
If true, returns the maximum attention score that can be used in a Muon optimizer to
rescale the Q and K projection weights (see `Muon is Scalable for LLM Training
<https://arxiv.org/pdf/2502.16982>`_).
:math:`\text{max_logit} = \max(S)`, where :math:`S = \text{mask}(Q \cdot K^T \cdot \text{softmax_scale} + \text{bias})` of shape ``[b, h, s_q, s_kv]``,
and :math:`\text{max_logit}` is of shape ``[h]``.
Parallelism parameters
----------------------
sequence_parallel : bool, default = False
if set to ``True``, uses sequence parallelism.
tp_size : int, default = 1
tensor parallel world size.
tp_group : ProcessGroup, default = None
tensor parallel process group.
cp_group : Union[ProcessGroup, List[ProcessGroup]], default = None
context parallel process group.
``ProcessGroup`` is for :attr:`cp_comm_type` of ``"p2p"``, ``"all_gather"``, and ``"a2a"``.
``List[ProcessGroup]`` is for :attr:`cp_comm_type` of ``"a2a+p2p"``, where :attr:`cp_group[0]`
and :attr:`cp_group[1]` are for ``"a2a"`` and ``"p2p"`` communications respectively.
cp_global_ranks : list of global rank IDs, default = None
global rank IDs of GPUs that are in ``cp_group``.
cp_stream : CUDA stream, default = None
context parallelism splits flash attention into multiple steps for
compute and communication overlapping. To address the wave quantization
issue of each split step, we add an additional CUDA stream so that we
can overlap two flash attention kernels.
cp_comm_type : str, default = "p2p"
inter-gpu communication type for context parallelism.
Can be ``"p2p"`` or ``"all_gather"`` or ``"a2a"`` or ``"a2a+p2p"``.
- ``"p2p"``: Exchange KV chunks with P2P communications in ring topology.
P2P is async and can be overlapped with attention compute.
- ``"all_gather"``: All-gather to get full sequence of KV before attention.
The all-gather is not async, and cannot be overlapped.
- ``"a2a"``: Like DeepSpeed Ulysses, scatter attention heads across the CP
group, and gather to get full sequence of QKV.
- ``"a2a+p2p"``: hierarchical CP implementation. First applying a2a to QKV
across each CP sub-group (e.g., via NVLink), then exchanging KV with
p2p between sub-groups (e.g., via IBLink).
"""
@@ -468,8 +487,8 @@ class DotProductAttention(TransformerEngineBaseModule):
):
"""
This function helps to load Transformer Engine 1.6 and 1.7 checkpoints, where FP8 attention
metadata is stored under the ``core_attention.fused_attention._extra_state`` key and not the
``core_attention._extra_state`` key. Please see `FP8 checkpoint compatibility
<https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/faq.html#fp8-checkpoint-compatibility>`_ for more details.
"""
fused_attn_key = False
@@ -522,23 +541,24 @@ class DotProductAttention(TransformerEngineBaseModule):
----------
cp_group : Union[ProcessGroup, List[ProcessGroup]]
context parallel process group.
``ProcessGroup`` is for :attr:`cp_comm_type` of ``"p2p"``, ``"all_gather"``, and ``"a2a"``.
``List[ProcessGroup]`` is for :attr:`cp_comm_type` of ``"a2a+p2p"``, where :attr:`cp_group[0]`
and :attr:`cp_group[1]` are for ``"a2a"`` and ``"p2p"`` communications respectively.
cp_global_ranks : List[int]
list of global ranks in the context group.
cp_stream : torch.cuda.Stream
cuda stream for context parallel execution.
cp_comm_type : str, default = "p2p"
inter-gpu communication type for context parallelism.
Can be ``"p2p"`` or ``"all_gather"`` or ``"a2a"`` or ``"a2a+p2p"``.
- ``"p2p"``: Exchange KV chunks with P2P communications in ring topology.
P2P is async and can be overlapped with attention compute.
- ``"all_gather"``: All-gather to get full sequence of KV before attention.
The all-gather is not async, and cannot be overlapped.
- ``"a2a"``: Like DeepSpeed Ulysses, scatter attention heads across the CP
group, and gather to get full sequence of QKV.
- ``"a2a+p2p"``: hierarchical CP implementation. First applying a2a to QKV
across each CP sub-group (e.g., via NVLink), then exchanging KV with
p2p between sub-groups (e.g., via IBLink).
"""
@@ -801,13 +821,13 @@ class DotProductAttention(TransformerEngineBaseModule):
fp8_output: Optional[bool] = False,
num_splits: Optional[int] = 1,
) -> torch.Tensor:
"""
r"""
Dot Product Attention Layer.
.. note::
Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
includes '"padding"' or `"arbitrary"`.
includes ``"padding"`` or ``"arbitrary"``.
.. note::
@@ -846,24 +866,24 @@ class DotProductAttention(TransformerEngineBaseModule):
Pass in :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`, or :attr:`attention_mask`
(which will be converted to :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`), to provide
the real sequence length information. For example, a batch of 3 sequences
``[a a a b b c c c c]`` can be padded to ``[a a a PAD b b PAD PAD c c c c]``, and the cumulative
sequence length tensors would be
:attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = ``[0, 3, 5, 9]`` for self-attention.
2. Do not perform padding on training data. Use :attr:`qkv_format` = "thd" and
:attr:`attn_mask_type` = {"padding", "padding_causal", "padding_causal_bottom_right"}.
Pass in :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`, or :attr:`attention_mask`,
as in option 1. For example, a batch of 3 sequences ``[a a a b b c c c c]`` can be processed
without any padding, and the sequence length tensors would be
:attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = ``[0, 3, 5, 9]`` for self-attention.
In certain use cases, a varying number of identifier tokens are inserted between
sequences. These tokens do not participate in the attention calculation.
:attr:`cu_seqlens_q_padded` and :attr:`cu_seqlens_kv_padded` must be specified
in such cases to correctly identify the start and end of each sequence in a batch.
For example, a batch of 3 sequences ``[a a a 1 b b 2 2 c c c c 3]`` would have
:attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = ``[0, 3, 5, 9]``, and
:attr:`cu_seqlens_q_padded` = :attr:`cu_seqlens_kv_padded` = ``[0, 4, 8, 13]``
for self-attention.
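To make options 1 and 2 above concrete, the cumulative sequence-length tensors can be built from the per-sequence lengths with a cumulative sum. This is a minimal sketch, not a Transformer Engine helper:

.. code-block:: python

    import torch

    # Batch of 3 sequences [a a a b b c c c c] with lengths 3, 2 and 4.
    seqlens = torch.tensor([3, 2, 4], dtype=torch.int32)
    cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    # cu_seqlens == tensor([0, 3, 5, 9], dtype=torch.int32)

    # With identifier tokens in between, [a a a 1 b b 2 2 c c c c 3], the
    # padded variant accumulates the lengths including those tokens.
    padded_lens = torch.tensor([4, 4, 5], dtype=torch.int32)  # 3+1, 2+2, 4+1
    cu_seqlens_padded = torch.zeros(len(padded_lens) + 1, dtype=torch.int32)
    cu_seqlens_padded[1:] = torch.cumsum(padded_lens, dim=0)
    # cu_seqlens_padded == tensor([0, 4, 8, 13], dtype=torch.int32)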
.. note::
@@ -898,81 +918,81 @@ class DotProductAttention(TransformerEngineBaseModule):
value_layer : torch.Tensor
Value tensor.
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
default = None. Boolean tensor(s) used to mask out attention softmax input.
It should be ``None`` for causal masks and ``"no_mask"``. For padding masks, it should be
a single tensor of ``[batch_size, 1, 1, seqlen_q]`` for self-attention, and a tuple of
two tensors of shapes ``[batch_size, 1, 1, seqlen_q]`` and ``[batch_size, 1, 1, seqlen_kv]``
for cross-attention. For ``"arbitrary"`` mask, it should be of a shape broadcastable
to ``[batch_size, num_heads, max_seqlen_q, max_seqlen_kv]``. A ``True`` value means
the corresponding position is masked out and a ``False`` means that position
is allowed to participate in attention.
qkv_format: str, default = None
If provided, overrides :attr:`qkv_format` from initialization.
cu_seqlens_q: Optional[torch.Tensor], default = None
Cumulative sum of sequence lengths (without offset) in a batch for ``query_layer``,
with shape ``[batch_size + 1]`` and dtype torch.int32.
See :ref:`note<cu_seqlens note>` for more details.
cu_seqlens_kv: Optional[torch.Tensor], default = None
Cumulative sum of sequence lengths (without offset) in a batch for ``key_layer``
and ``value_layer``, with shape ``[batch_size + 1]`` and dtype torch.int32.
See :ref:`note<cu_seqlens note>` for more details.
cu_seqlens_q_padded: Optional[torch.Tensor], default = None
Cumulative sum of sequence lengths (with offset) in a batch for
``query_layer``, with shape ``[batch_size + 1]`` and dtype torch.int32.
When there is no padding between sequences in a batch,
:attr:`cu_seqlens_q_padded` = :attr:`cu_seqlens_q`.
See :ref:`note<cu_seqlens note>` for more details.
cu_seqlens_kv_padded: Optional[torch.Tensor], default = None
Cumulative sum of sequence lengths (with offset) in a batch for ``key_layer``
and ``value_layer``, with shape ``[batch_size + 1]`` and dtype torch.int32.
When there is no padding between sequences in a batch,
:attr:`cu_seqlens_kv_padded` = :attr:`cu_seqlens_kv`.
See :ref:`note<cu_seqlens note>` for more details.
max_seqlen_q: Optional[int], default = None
Maximum sequence length in ``query_layer``.
See :ref:`note<max_seqlen note>` for more details.
max_seqlen_kv: Optional[int], default = None
Maximum sequence length in ``key_layer`` and ``value_layer``.
See :ref:`note<max_seqlen note>` for more details.
See :ref:`note<max_seqlen note>` for more details.
attn_mask_type: {'no_mask', 'padding', 'causal', 'padding,causal', 'causal,padding',
'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right',
'arbitrary'}, default = None. Type of attention mask passed into
softmax operation. 'padding,causal', 'causal,padding' and 'padding_causal'
are equivalent. By default, causal masks are aligned to the top left corner
of the softmax matrix. When ``"bottom_right"`` is specified in the mask type,
causal masks are aligned to the bottom right corner.
window_size: Optional[Tuple[int, int]], default = None
Sliding window size for local attention.
checkpoint_core_attention : bool, default = False
If true, forward activations for attention are recomputed
during the backward pass in order to save memory that would
otherwise be occupied to store the forward activations until
backprop.
core_attention_bias_type: str, default = "no_bias"
Bias type, {``"no_bias"``, ``"pre_scale_bias"``, ``"post_scale_bias"``, ``"alibi"``}
core_attention_bias: Optional[torch.Tensor], default = None
Bias tensor for :math:`Q \cdot K^T`, shape ``[1, num_head, max_seqlen_q, max_seqlen_kv]``.
It should be ``None`` for ``"no_bias"`` and ``"alibi"`` bias types.
alibi_slopes: Optional[torch.Tensor], default = None
ALiBi slopes in FP32 and shape ``[nheads]`` or ``[batch_size, nheads]``.
It adds a bias of ``(-alibi_slope * (i + seqlen_k - seqlen_q - j))``
to the attention score of query i and key j.
fast_zero_fill: bool, default = True
Whether to use the fast path to set output tensors to 0 or not.
inference_params: Optional[InferenceParams], default = None
Optimizes execution performance during inference by caching Keys and Values of the
current decoding iteration. These cached values are appended to the K and V values
computed in previous iterations, eliminating the need to recalculate them for the
entire sequence.
Initialization of ``inference_params`` is required prior to use to ensure sufficient
memory allocation.
Adjustments of the sequence_len_offset should be done after a complete forward pass.
If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand.
Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient.
pad_between_seqs: Optional[bool], default = None
If ``None``, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
If ``True``, there are padding tokens between individual sequences in a packed batch.
fp8_output: Optional[bool], default = False
Whether to enforce output to be in FP8 or not.
num_splits: Optional[int], default = 1
Optional split control for FlashAttention-3 only. When set, this value is forwarded
@@ -175,65 +175,65 @@ class AttentionParams:
Parameters
----------
qkv_type : Union[torch.Tensor, Float8Tensor], default = torch.Tensor
Type of query/key/value tensors, {`torch.Tensor`, `Float8Tensor`}.
qkv_dtype : torch.dtype, default = torch.bfloat16
Data type of query/key/value tensors.
qkv_layout : str, default = "sbh3d"
Query/key/value tensor memory layout.
batch_size : int, default = 1
Batch size.
num_heads : int, default = 16
Number of attention heads in the query tensor.
num_gqa_groups : int, default = 16
Number of attention heads in key and value tensors.
max_seqlen_q : int, default = 128
Maximum sequence length of the query tensor.
max_seqlen_kv : int, default = 128
Maximum sequence length of the key and value tensors.
head_dim_qk : int, default = 64
The size of each attention head in query and key tensors.
head_dim_v : int, default = 64
The size of each attention head in the value tensor.
attn_mask_type : str, default = no_mask
Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`,
`causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}
window_size : Tuple[int, int], default = None
Sliding window attention size.
alibi_slopes_shape : Optional[Union[torch.Size, List]], default = None
Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`.
core_attention_bias_type : str, default = no_bias
Attention bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}.
core_attention_bias_shape : str, default = 1hss
Attention bias shape, {`1hss`, `b1ss`, `bhss`}.
core_attention_bias_requires_grad : bool, default = True
Whether attention bias requires gradient.
pad_between_seqs : bool, default = False
Whether there is padding between sequences in a batch.
This only applies to `qkv_format=thd`.
attention_dropout : float, default = 0.0
Attention dropout.
context_parallel : bool, default = False
Whether context parallelism is used or not.
cp_comm_type : str, default = "p2p"
The communication type of context parallelism.
deterministic : bool, default = False
Whether to run `DotProductAttention` with determinism or not.
is_training : bool, default = True
Whether in training mode (`True`) or inference mode (`False`)
fp8 : bool, default = False
Whether `DotProductAttention` is in an `autocast` region.
fp8_meta : Optional[Dict[str, Any]], default = None
The FP8 metadata tensor of `DotProductAttention`.
inference_params : Optional[InferenceParams], default = None
Inference-related parameters. See InferenceParams for details.
softmax_type : str, default = "vanilla"
The type of softmax operation. See DotProductAttention for details.
return_max_logit : bool, default = False
Whether to output max_logit.
cuda_graph : bool, default = False
Whether support for cuda graph capture is needed or not.
num_splits : int, default = 1
The number of kernels to split attention to.
"""
@@ -298,15 +298,15 @@ def get_attention_backend(
Returns
----------
use_flash_attention : bool
Whether the `FlashAttention` backend has been selected.
use_fused_attention : bool
Whether the `FusedAttention` backend has been selected.
fused_attention_backend : tex.NVTE_Fused_Attn_Backend
If `use_fused_attention = True`, one of `FusedAttention`'s three sub-backends, else `None`.
use_unfused_attention : bool
Whether the `UnfusedDotProductAttention` backend has been selected.
available_backends : List[bool]
All available backends that could support the provided input. A list of Booleans
in the form of [use_flash_attention, use_fused_attention, use_unfused_attention].
"""
@@ -835,8 +835,8 @@ def get_attention_backend(
# ----------------------------------------------------------------------------------------
# no_mask | None | All
# padding | | All
# self-attention | One tensor of shape [b, 1, 1, sq] |
# cross-attention | Tuple of two tensors of shapes |
# | [b, 1, 1, sq] and [b, 1, 1, skv] |
# causal | None |
# self-attention | | All
@@ -846,7 +846,7 @@ def get_attention_backend(
# cross-attention | | FusedAttention, UnfusedDotProductAttention
# causal_bottom_right | None | All
# padding_causal_bottom_right | Same as "padding" | All
# arbitrary | One tensor of shape broadcastable to | UnfusedDotProductAttention
# | [b, h, sq, skv] |
if attn_mask_type == "arbitrary":
if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or (
@@ -1271,42 +1271,42 @@ def get_full_mask(
Parameters
----------
max_seqlen_q : int
Maximum sequence length for queries.
max_seqlen_kv : int
Maximum sequence length for keys and values.
attn_mask_type : str, default = no_mask
Attention mask type, {``"no_mask"``, ``"padding"``, ``"causal"``, ``"padding_causal"``,
``"causal_bottom_right"``, ``"padding_causal_bottom_right"``, ``"arbitrary"``}
attention_mask : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
default = None
Boolean tensor(s) used to mask out attention softmax input. Please see DotProductAttention
for the requirements of `attention_mask` for different `attn_mask_type`s.
window_size : Tuple[int, int], default = None
Sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
`attn_mask_type`. A sketch of the resulting mask follows this docstring.
attention_type : str, default = "self"
Attention type, {"self", "cross"}
bottom_right_alignment : bool, default = True
Whether to align the diagonal of the sliding window attention to the bottom right (`True`)
or top left (`False`) corner of the softmax matrix. Ignored if `attn_mask_type` explicitly
specifies "causal" or "causal_bottom_right".
Returns
----------
attn_mask_type : str
For sliding window attention (>=0, >0), "arbitrary"; otherwise, the same as input `attn_mask_type`
attention_mask : torch.Tensor
The full attention mask based on `attn_mask_type`, `attention_mask` and `window_size`
actual_seqlens_q : torch.Tensor
For padding masks, the actual sequence lengths for queries, of shape [batch_size].
For other masks, `None`.
actual_seqlens_kv : Optional[torch.Tensor], default = None
For padding masks, the actual sequence lengths for keys and values, of shape [batch_size].
For other masks, `None`.
"""
# perform basic checks
@@ -1392,29 +1392,29 @@ def get_alibi(
"""
Parameters
----------
num_heads : int
Number of heads.
max_seqlen_q : int
Maximum sequence length for queries.
max_seqlen_kv : int
Maximum sequence length for keys and values.
actual_seqlens_q : Optional[torch.Tensor], default = None
Actual sequence lengths for queries, of shape [batch_size].
actual_seqlens_kv : Optional[torch.Tensor], default = None
Actual sequence lengths for keys and values, of shape [batch_size].
alibi_slopes : Optional[torch.Tensor], default = None
Custom ALiBi slopes, FP32, CUDA tensor, of shape [num_heads] or [batch_size, num_heads].
bias_dtype : Optional[torch.dtype], default = None
Dtype of the generated ALiBi bias. If None, use torch.float32.
bottom_right_alignment : bool, default = True
Whether to align the diagonal of the ALiBi bias to the bottom right corner of
the matrix (`True`) or top left (`False`).
Returns
----------
alibi_slopes : torch.Tensor
ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads].
alibi_bias : torch.Tensor
ALiBi bias in FP32 or `bias_dtype`. Its shape is
(1) [1, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in [num_heads] shape,
and `actual_seqlens_q` and `actual_seqlens_kv` are `None`; or
@@ -1818,18 +1818,18 @@ def get_qkv_format(
Parameters
----------
qkv_layout : str
Memory layout of `q`, `k` and `v`. See get_qkv_layout() for more details.
inference_params : InferenceParams, default = None
InferenceParams related to KV caching.
Returns
----------
qkv_format : str, default = sbhd
Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}.
q_format : str
Format of the `q` tensor, {`bshd`, `sbhd`, `thd`}.
kv_format : str
Format of the `k` and `v` tensors, {`bshd`, `sbhd`, `thd`}.
"""
splited = qkv_layout.replace("paged_kv_", "").split("_")
@@ -1855,23 +1855,23 @@ def get_qkv_layout(
Parameters
----------
q : torch.Tensor
Query tensor.
k : torch.Tensor
Key tensor.
v : torch.Tensor
Value tensor.
qkv_format : str, default = sbhd
Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for
the sequence length dimension, `b` batch size, `h` the number of attention heads,
`d` head size, and `t` the total number of tokens in a batch, i.e.
`t = sum(s_i) for i = 0...b-1`.
inference_params : InferenceParams, default = None
InferenceParams related to KV caching.
Returns
----------
qkv_layout : str
Memory layout of `q`, `k` and `v`. Each `qkv_layout` maps to a pair of `q_format` and
`kv_format` in {`bshd`, `sbhd`, `thd`}. The `paged_kv_` prefix is used to indicate that
paged KV caching is in play. A few examples of the layouts are as follows.
@@ -1893,18 +1893,18 @@ def get_qkv_layout(
`thd_2bshd`: {`thd_bshd_bshd`, `paged_kv_thd_bshd_bshd`}
`thd_2sbhd`: {`thd_sbhd_sbhd`, `paged_kv_thd_sbhd_sbhd`}
q : torch.Tensor
Query tensor. It may be different from input `q` as we try to fit tensors to
a supported layout.
k : torch.Tensor
Key tensor. It may be different from input `k` as we try to fit tensors to
a supported layout.
v : torch.Tensor
Value tensor. It may be different from input `v` as we try to fit tensors to
a supported layout.
q_format : str
Format of the query tensor, {`bshd`, `sbhd`, `thd`}.
kv_format : str
Format of the key and value tensors, {`bshd`, `sbhd`, `thd`}.
"""
@@ -98,29 +98,29 @@ class InferenceParams:
Parameters
----------
max_batch_size : int
Maximum batch size in inference
max_sequence_length : int
Maximum sequence length in inference
num_heads_kv : int
Number of attention heads in keys and values
head_dim_k : int
Head size for keys
dtype : torch.dtype
Data type of the KV cache
head_dim_v : int, default = None
Head size for values. If None, initialized as head_dim_k.
is_paged : bool, default = False
Whether the KV cache is paged (True) or non-paged (False)
total_num_pages : int, default = None
Total number of pages in the KV cache. Required for is_paged = True.
page_size : int, default = None
Page size of the KV cache. Required for is_paged = True.
max_ctx_len : int, default = None
Maximum context length in inference. 1 <= max_ctx_len <= max_sequence_length.
qkv_format : str, default = "bshd"
Format of the incoming query/key/value tensors in current iteration
custom_cache_manager : KVCacheManager, default = None
Custom cache manager, with KVCacheManager as the base class.
"""
@@ -525,9 +525,9 @@ class NonPagedKVCacheManager(KVCacheManager):
new_v: torch.Tensor
New value tokens for layer_number in current inference iteration
cu_new_seqlens: torch.Tensor
Cumulative sequence lengths for new_k and new_v, of shape [batch_size + 1]
cu_cached_seqlens: torch.Tensor
Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), of shape [batch_size + 1]
qkv_format: str
Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'}
@@ -701,7 +701,7 @@ class PagedKVCacheManager(KVCacheManager):
return [x.page_id for x in self.allocated_pages[seq]]
def get_page_table(self, sequences: List[int]):
"""Get the page table, in shape [batch_size, max_pages_per_seq]"""
"""Get the page table, of shape [batch_size, max_pages_per_seq]"""
page_table = torch.Tensor(
[
self.get_page_list(seq) + [0] * (self.max_pages_per_seq - self.get_page_count(seq))
@@ -783,9 +783,9 @@ class PagedKVCacheManager(KVCacheManager):
new_v: torch.Tensor
New value tokens for layer_number in current inference iteration
cu_new_seqlens: torch.Tensor
Cumulative sequence lengths for new_k and new_v, of shape [batch_size + 1]
cu_cached_seqlens: torch.Tensor
Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), of shape [batch_size + 1]
qkv_format: str
Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'}
@@ -50,8 +50,8 @@ class MultiheadAttention(torch.nn.Module):
.. note::
Argument :attr:`attention_mask` in the :meth:`forward() <MultiheadAttention.forward>` method is only used when
:attr:`attn_mask_type` includes ``"padding"`` or ``"arbitrary"``.
Parameters
----------
@@ -59,57 +59,56 @@ class MultiheadAttention(torch.nn.Module):
size of each input sample.
num_attention_heads : int
number of attention heads in the transformer layer.
kv_channels : int, default = None
number of key-value channels. defaults to
:attr:`hidden_size` / :attr:`num_attention_heads` if ``None``.
attention_dropout : float, default = 0.1
dropout probability for the dropout op during multi-head attention.
layernorm_epsilon : float, default = 1e-5
a value added to the denominator of layer normalization
for numerical stability.
init_method : Callable, default = None
used for initializing weights of QKV and FC1 weights in the following way:
``init_method(weight)``. When set to ``None``, defaults to
``torch.nn.init.normal_(mean=0.0, std=0.023)``.
output_layer_init_method : Callable, default = None
used for initializing weights of PROJ and FC2 in the following way:
``output_layer_init_method(weight)``. When set to ``None``, defaults to
``torch.nn.init.normal_(mean=0.0, std=0.023)``.
layer_number : int, default = None
layer number of the current ``TransformerLayer`` when multiple such modules are
concatenated to form a transformer block.
attn_mask_type : {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
'padding_causal_bottom_right','arbitrary'},
default = "causal"
type of attention mask passed into softmax operation. Overridden by
:attr:`attn_mask_type` in the :meth:`forward` method. The :meth:`forward`
arg is useful for dynamically changing mask types, e.g. a different
mask for training and inference. The :meth:`__init__` arg is useful for cases
involving compilation/tracing, e.g. ONNX export.
window_size : Optional[Tuple[int, int]], default = None
sliding window size for local attention, where query at position i attends to keys
in ``[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]]`` inclusive. Special cases ``(-1, -1)`` and ``(-1, 0)`` mean no sliding
window and causal mask specifically. Both ``"causal"`` and ``"causal_bottom_right"`` masks
map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them based on
``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can
be overridden by :attr:`window_size` in :meth:`forward` as well.
num_gqa_groups : int, default = None
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the querys.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. ``num_gqa_groups = num_attention_heads``.
return_layernorm_output : bool, default = False
if set to ``True``, output of layernorm is returned from the :meth:`forward` method
together with the output of the linear transformation.
Example use case: residual connection for transformer module is
taken post layernorm.
input_layernorm : bool, default = False
if set to ``True``, layer normalization to the input is applied.
attention_type : { 'self', 'cross' }, default = 'self'
type of attention applied.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
@@ -120,103 +119,118 @@ class MultiheadAttention(torch.nn.Module):
(1 + \gamma) + \beta
normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
type of normalization applied.
qkv_weight_interleaved : bool, default = True
if set to ``False``, the QKV weight is interpreted as a concatenation of
query, key, and value weights along the ``0th`` dimension. The default
interpretation is that the individual ``q``, ``k``, and ``v`` weights for each
attention head are interleaved. This parameter is set to ``False`` when
using :attr:`fuse_qkv_params=False`.
rotary_pos_interleaved : bool, default = False
whether to use interleaved rotary position embeddings.
bias : bool, default = True
if set to ``False``, the transformer layer will not learn any additive biases.
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
qkv_format : str, default = "sbhd"
dimension format for ``query_layer``, ``key_layer`` and ``value_layer``,
{``"sbhd"``, ``"bshd"``}. ``s`` stands for the sequence length, ``b`` batch size,
``h`` the number of heads and ``d`` head size. ``"sbhd"`` and ``"bshd"`` formats
are used for when sequences in a batch are of equal length or padded to
equal length. Please note that these formats do not reflect how
tensors ``query_layer``, ``key_layer``, ``value_layer`` are laid out in memory.
For that, please use ``get_qkv_layout`` to gain the layout information.
name : str, default = None
name of the module, currently used for debugging purposes.
softmax_type : str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
Softmax type as described in the paper
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``:
* ``'vanilla'``:
.. math::
Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}
* ``'off-by-one'``:
.. math::
Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}
* ``'learnable'``:
.. math::
Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}
where :math:`\alpha` is a learnable parameter of shape ``[h]``.
``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
(``'zero sink'`` and ``'learnable sink'``).
Parallelism parameters
----------------------
set_parallel_mode : bool, default = False
if set to ``True``, QKV and FC1 layers are used as Column Parallel
whereas PROJ and FC2 is used as Row Parallel as described
`here <https://arxiv.org/pdf/1909.08053.pdf>`_.
sequence_parallel : bool, default = False
if set to ``True``, uses sequence parallelism.
tp_group : ProcessGroup, default = None
tensor parallel process group.
tp_size : int, default = 1
used as TP (tensor parallel) world size when TP groups are not formed during
initialization. In this case, users must call the
``set_tensor_parallel_group(tp_group)`` method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives.
Optimization parameters
-----------------------
fuse_wgrad_accumulation : bool, default = 'False'
if set to ``True``, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights
have an additional ``main_grad`` attribute (used instead of the
regular ``grad``) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
params_dtype : torch.dtype, default = torch.get_default_dtype()
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
return_bias : bool, default = False
when set to ``True``, this module will not apply the additive bias itself, but
instead return the bias value during the :meth:`forward` method together with the
output of the linear transformation :math:`y = xA^T`. This is useful when
the bias addition can be fused to subsequent operations.
fuse_qkv_params : bool, default = False
if set to ``True``, ``TransformerLayer`` module exposes a single fused
parameter for query-key-value. This enables optimizations such as QKV
fusion without concatenations/splits and also enables the argument
``fuse_wgrad_accumulation``.
qk_norm_type : Optional[str], default = None
type of normalization to apply to query and key tensors.
Options: ``None``, ``'L2Normalization'``, ``'RMSNorm'``, ``'LayerNorm'``. When ``None``, no normalization is applied.
When ``'L2Normalization'``, L2 normalization is applied to query and key tensors.
When ``'RMSNorm'``, RMS normalization is applied to query and key tensors.
When ``'LayerNorm'``, layer normalization is applied to query and key tensors.
Normalization is applied after RoPE (if applicable) but before attention computation
when ``qk_norm_before_rope`` is ``False``. This follows, e.g., the Llama4 approach
for QK normalization to improve training stability and model performance.
qk_norm_eps : float, default = 1e-6
epsilon value for normalization of query and key tensors.
Only used when ``qk_norm_type`` is not ``None``.
qk_norm_before_rope : bool, default = False
if set to ``True``, query and key normalization is applied before rotary position
embedding. When ``False`` (default), normalization is applied after RoPE.
This parameter allows supporting different architectural variants that apply
QK normalization at different points.
seq_length : Optional[int], default = None
sequence length of input samples. Needed for JIT Warmup, a technique where jit
fused functions are warmed up before training to ensure same kernels are used for
forward propagation and activation recompute phase.
micro_batch_size : Optional[int], default = None
batch size per training step. Needed for JIT Warmup, a technique where jit
fused functions are warmed up before training to ensure same kernels are
used for forward propagation and activation recompute phase.
@@ -535,7 +549,7 @@ class MultiheadAttention(torch.nn.Module):
Parameters
----------
tp_group : ProcessGroup, default = None
tensor parallel process group.
"""
self.tp_group = tp_group
@@ -555,23 +569,24 @@ class MultiheadAttention(torch.nn.Module):
----------
cp_group : Union[ProcessGroup, List[ProcessGroup]]
context parallel process group.
``ProcessGroup`` is for :attr:`cp_comm_type` of ``"p2p"``, ``"all_gather"``, and ``"a2a"``.
``List[ProcessGroup]`` is for :attr:`cp_comm_type` of ``"a2a+p2p"``, where :attr:`cp_group[0]`
and :attr:`cp_group[1]` are for ``"a2a"`` and ``"p2p"`` communications respectively.
cp_global_ranks : List[int]
list of global ranks in the context group.
cp_stream : torch.cuda.Stream
cuda stream for context parallel execution.
cp_comm_type : str, default = "p2p"
inter-gpu communication type for context parallelism.
Can be ``"p2p"`` or ``"all_gather"`` or ``"a2a"`` or ``"a2a+p2p"``.
- ``"p2p"``: Exchange KV chunks with P2P communications in ring topology.
P2P is async and can be overlapped with attention compute.
- ``"all_gather"``: All-gather to get full sequence of KV before attention.
The all-gather is not async, and cannot be overlapped.
- ``"a2a"``: Like DeepSpeed Ulysses, scatter attention heads across the CP
group, and gather to get full sequence of QKV.
- ``"a2a+p2p"``: hierarchical CP implementation. First applying a2a to QKV
across each CP sub-group (e.g., via NVLink), then exchanging KV with
p2p between sub-groups (e.g., via IBLink).
"""
@@ -622,39 +637,39 @@ class MultiheadAttention(torch.nn.Module):
fast_zero_fill: bool = True,
pad_between_seqs: Optional[bool] = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""
r"""
Forward propagation for MultiheadAttention layer.
.. note::
Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
includes `"padding"` or `"arbitrary"`.
includes ``"padding"`` or ``"arbitrary"``.
Parameters
----------
hidden_states : torch.Tensor
Input tensor.
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
default = None. Boolean tensor(s) used to mask out attention softmax input.
It should be ``None`` for causal masks and ``"no_mask"``. For padding masks, it should be
a single tensor of ``[batch_size, 1, 1, seqlen_q]`` for self-attention, and a tuple of
two tensors of shapes ``[batch_size, 1, 1, seqlen_q]`` and ``[batch_size, 1, 1, seqlen_kv]``
for cross-attention. For ``"arbitrary"`` mask, it should be of a shape broadcastable to
``[batch_size, num_heads, max_seqlen_q, max_seqlen_kv]``. A ``True`` value means
the corresponding position is masked out and a ``False`` means that position
is allowed to participate in attention.
attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
'padding_causal_bottom_right','arbitrary'},
default = None
type of attention mask passed into softmax operation. By default,
causal masks are aligned to the top left corner of the softmax matrix.
When ``"bottom_right"`` is specified in the mask type, causal masks are
aligned to the bottom right corner.
window_size: Optional[Tuple[int, int]], default = None
sliding window size for local attention.
encoder_output : Optional[torch.Tensor], default = None
Output of the encoder block to be fed into the decoder block if using
``layer_type="decoder"``.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split
......@@ -668,46 +683,46 @@ class MultiheadAttention(torch.nn.Module):
* it also allows skipping gradient accumulation during the
first microbatch (since it is the first gradient being
produced)
checkpoint_core_attention: bool, default = False
If ``True``, forward activations for core attention are recomputed
during the backward pass in order to save memory that would
otherwise be occupied to store the forward activations until
backprop.
rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = None
Embeddings for query and key tensors for applying rotary position
embedding. By default no input embedding is applied.
Embeddings for query and key tensors for applying rotary position
embedding. By default no input embedding is applied.
core_attention_bias_type: str, default = "no_bias"
Bias type, {``"no_bias"``, ``"pre_scale_bias"``, ``"post_scale_bias"``, ``"alibi"``}
core_attention_bias: Optional[torch.Tensor], default = None
Bias tensor for :math:`Q \cdot K^T`, shape ``[1, num_head, max_seqlen_q, max_seqlen_kv]``.
It should be ``None`` for ``"no_bias"`` and ``"alibi"`` bias types.
alibi_slopes: Optional[torch.Tensor], default = None
ALiBi slopes in FP32 and shape ``[nheads]`` or ``[batch_size, nheads]``.
It adds a bias of ``(-alibi_slope * (i + seqlen_k - seqlen_q - j))``
to the attention score of query i and key j.
cu_seqlens_q: Optional[torch.Tensor], default = None
Cumulative sum of sequence lengths (without offset) in a batch for ``query_layer``,
with shape ``[batch_size + 1]`` and dtype torch.int32.
cu_seqlens_kv: Optional[torch.Tensor], default = None
Cumulative sum of sequence lengths (without offset) in a batch for ``key_layer``
and ``value_layer``, with shape ``[batch_size + 1]`` and dtype torch.int32.
cu_seqlens_q_padded: Optional[torch.Tensor], default = None
Cumulative sum of sequence lengths (with offset) in a batch for ``query_layer``,
with shape ``[batch_size + 1]`` and dtype torch.int32.
cu_seqlens_kv_padded: Optional[torch.Tensor], default = None
Cumulative sum of sequence lengths (with offset) in a batch for ``key_layer``
and ``value_layer``, with shape ``[batch_size + 1]`` and dtype torch.int32.
max_seqlen_q: Optional[int], default = None
Maximum sequence length in ``query_layer``.
Calculated from ``cu_seqlens_q`` if not provided.
max_seqlen_kv: Optional[int], default = None
Maximum sequence length in ``key_layer`` and ``value_layer``.
Calculated from ``cu_seqlens_kv`` if not provided.
fast_zero_fill: bool, default = True
Whether or not to set output tensors to 0 before use.
pad_between_seqs: Optional[bool], default = None
If ``None``, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
If ``True``, there are padding tokens between individual sequences in a packed batch.
"""
# hidden_states: [sq, b, h]
......
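To ground the argument conventions above, here is a minimal usage sketch. The ``[sq, b, h]`` input shape and the forward-time ``attn_mask_type`` follow the docstring; the concrete sizes and constructor arguments are illustrative assumptions:

.. code-block:: python

    import torch
    from transformer_engine.pytorch import MultiheadAttention

    # Causal self-attention over a [sq, b, h] input (sizes are illustrative).
    mha = MultiheadAttention(hidden_size=1024, num_attention_heads=16).cuda()
    hidden_states = torch.randn(128, 2, 1024, device="cuda")  # [sq, b, h]
    out = mha(hidden_states, attn_mask_type="causal")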
......@@ -287,16 +287,16 @@ def _apply_rotary_pos_emb_base(
Parameters
----------
t : torch.Tensor
Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which rotary positional
embedding will be applied.
freqs : torch.Tensor
Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` or `[s2, b, 1, d2]`
and dtype 'float', with `s2 >= s` and `d2 <= d`.
tensor_format : {'sbhd', 'bshd'}, default = 'sbhd'
Should be `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape
`[seq, bs, ...]`.
interleaved : bool, default = False
Whether to use interleaved rotary position embedding.
"""
# [seq, 1, 1, dim] -> [1, seq, 1, dim] or
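For reference, the non-interleaved rotation this function implements reduces to the standard RoPE identity. A plain-PyTorch sketch, assuming for brevity that the rotary dimension covers the whole head dimension (`d2 == d`):

.. code-block:: python

    import torch

    def rotate_half(t: torch.Tensor) -> torch.Tensor:
        # Swap the two halves of the last dimension and negate the second.
        t1, t2 = t.chunk(2, dim=-1)
        return torch.cat((-t2, t1), dim=-1)

    def rope_reference(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
        # t: [s, b, h, d]; freqs: [s, 1, 1, d], broadcast over batch and heads.
        return t * freqs.cos() + rotate_half(t) * freqs.sin()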
......@@ -324,7 +324,7 @@ def _get_freqs_on_this_cp_rank(
"""Get the position embedding on the current context parallel rank.
Args:
freqs: torch.Tensor. Positional embedding tensor of shape `[s2, 1, 1, d2]`.
seqlen: int. Length of the current sequence.
cp_size: int. Context parallel world size.
cp_rank: int. Context parallel rank.
......@@ -372,29 +372,29 @@ def apply_rotary_pos_emb(
Parameters
----------
t : torch.Tensor
Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
rotary positional embedding will be applied.
freqs : torch.Tensor
Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
with `s2 >= s` and `d2 <= d`.
start_positions : torch.Tensor, default = None.
Tokens in sequence `i` have their position encodings offset by
`start_positions[i]`. If `start_positions=None`, there is no offset.
tensor_format : {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
Should be `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
interleaved : bool, default = False
Whether to use interleaved rotary position embedding.
fused : bool, default = False
Whether to use a fused implementation for applying RoPE.
cu_seqlens : torch.Tensor, default = None.
Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
dtype torch.int32. Only valid when `tensor_format` is 'thd'.
Should be `cu_seqlens_padded` when cp_size > 1.
cp_size : int, default = 1.
Context parallel world size. Only valid when `tensor_format` is 'thd' and `fused` is True.
cp_rank : int, default = 0.
Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True.
"""
assert (
......@@ -492,32 +492,32 @@ def apply_fused_qkv_rotary_pos_emb(
Parameters
----------
qkv : torch.Tensor
Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which
rotary positional embedding will be applied. This tensor has q, k, v concatenated
along the last dimension.
q_freqs : torch.Tensor
Rotary positional embedding Q tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
with `s2 >= s` and `d2 <= d`.
k_freqs : torch.Tensor
Rotary positional embedding K tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
with `s2 >= s` and `d2 <= d`.
qkv_split_arg_list : List[int]
List of integers specifying the split of the qkv tensor. The list should have 3 elements:
the numbers of elements in the q, k, and v tensors, respectively. The sum of the
elements in the list should equal the last dimension of the qkv tensor.
start_positions : torch.Tensor, default = None.
Tokens in sequence `i` have their position encodings offset by
`start_positions[i]`. If `start_positions=None`, there is no offset.
tensor_format : {'sbhd', 'bshd'}, default = 'sbhd'
Should be `bshd` if `qkv` is of shape `[bs, seq, ...]`, or `sbhd` if `qkv` is
of shape `[seq, bs, ...]`.
interleaved : bool, default = False
Whether to use interleaved rotary position embedding.
cp_size : int, default = 1.
Context parallel world size.
cp_rank : int, default = 0.
Context parallel rank.
"""
......
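A call sketch under the shape conventions documented above. The sizes, and the assumption that q, k and v each use a 64-wide last-dimension split, are illustrative; the function is assumed to be in scope:

.. code-block:: python

    import torch

    # Sketch: s=128, b=2, h=16, and a 64-wide split for each of q, k, v,
    # so the packed last dimension is 3 * 64.
    qkv = torch.randn(128, 2, 16, 3 * 64, device="cuda")
    q_freqs = torch.randn(128, 1, 1, 64, device="cuda")
    q, k, v = apply_fused_qkv_rotary_pos_emb(
        qkv, q_freqs, q_freqs, qkv_split_arg_list=[64, 64, 64]
    )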
......@@ -146,89 +146,89 @@ def fused_attn_fwd(
Parameters
----------
is_training : bool
if True, runs training and produces auxiliary tensors aux_ctx_tensors
for the backward; if False, runs inference and doesn't produce aux_ctx_tensors
max_seqlen_q : int
max sequence length for Q, used for padding;
may be larger than max(seqlens_q),
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
max_seqlen_kv : int
max sequence length for K and V, used for padding;
may be larger than max(seqlens_kv),
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
cu_seqlens_q : torch.Tensor
cumulative sequence lengths for Q; shape [batch_size + 1]
cu_seqlens_kv : torch.Tensor
cumulative sequence lengths for K and V; shape [batch_size + 1]
q : torch.Tensor
input tensor Q; shape sbhd, bshd or thd (see `qkv_layout` for details)
k : torch.Tensor
input tensor K; shape sbhd, bshd or thd (see `qkv_layout` for details)
v : torch.Tensor
input tensor V; shape sbhd, bshd or thd (see `qkv_layout` for details)
fake_dtype : tex.DType
data type of Q, K and V in the high-precision case, or the fake dtype
in the FP8 case; given as a torch.dtype
fused_attention_backend : tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends.
attn_bias : torch.Tensor, default = None
input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q, k and v
cu_seqlens_q_padded : torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
cu_seqlens_kv_padded : torch.Tensor, default = None
cumulative sequence offsets for KV; shape [batch_size + 1]
page_table_k : torch.Tensor, default = None
page table for K cache; shape [batch_size, max_pages_per_seq_k]
page_table_v : torch.Tensor, default = None
page table for V cache; shape [batch_size, max_pages_per_seq_v]
s_quantizer : Quantizer, default = None
Quantizer object for the intermediate value S.
o_quantizer : Quantizer, default = None
Quantizer object for the output of the attention.
attn_scale : float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim_qk) as the default
dropout : float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False
fast_zero_fill : bool, default = True
if True, initializes the output tensor O to zero using the fast filling method;
if False, uses PyTorch's .fill_() method
qkv_layout: str, default = "sbh3d"
qkv_layout : str, default = "sbh3d"
layout of Q, K and V;
{"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd",
"bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd",
"t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"}
attn_bias_type: str, default = "no_bias"
attn_bias_type : str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
attn_mask_type : str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
softmax_type: str, default = "vanilla"
softmax_type : str, default = "vanilla"
type of the attention softmax; {"vanilla", "off-by-one", "learnable"}
window_size : Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. The special cases (-1, -1) and (-1, 0) mean no
sliding window and causal mask, respectively.
rng_gen : torch.Generator, default = None
random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
softmax_offset : torch.Tensor, default = None
softmax offset tensor of shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details.
return_max_logit : bool, default = False
whether to return the maximum attention score
cuda_graph : bool, default = False
whether or not cuda graph capture is enabled.
Returns
----------
o : torch.Tensor
output tensor O of the attention calculation; same data type as Q, K and V;
same shape as Q
aux_ctx_tensors : List[torch.Tensor]
auxiliary output tensors used for the backward;
if is_training is True, aux_ctx_tensors = [softmax-related tensors, rng_state]
if is_training is False, aux_ctx_tensors = None
......@@ -252,7 +252,7 @@ def fused_attn_fwd(
rng_state : torch.Tensor, optional, if backend is not F16_max512_seqlen
state of the random number generator;
[seed, offset], dtype uint64
max_logit : if return_max_logit = True, shape [h] and same data type as O; otherwise None
"""
if attn_scale is None:
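The cu_seqlens_* arguments above are inclusive prefix sums of the per-sequence lengths, so a caller can build them as in this sketch (lengths are illustrative):

.. code-block:: python

    import torch

    seqlens_q = torch.tensor([5, 3, 7], dtype=torch.int32, device="cuda")
    cu_seqlens_q = torch.zeros(seqlens_q.numel() + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0, dtype=torch.int32)
    # cu_seqlens_q is [0, 5, 8, 15]; note seqlens_q == cu_seqlens_q[1:] - cu_seqlens_q[:-1]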
......@@ -377,89 +377,89 @@ def fused_attn_bwd(
Parameters
----------
max_seqlen_q : int
max sequence length for Q, used for padding; may be larger than max(seqlens_q),
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
max_seqlen_kv : int
max sequence length for K and V, used for padding;
may be larger than max(seqlens_kv),
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
cu_seqlens_q : torch.Tensor
cumulative sequence lengths for Q; shape [batch_size + 1]
cu_seqlens_kv : torch.Tensor
cumulative sequence lengths for K and V; shape [batch_size + 1]
q : torch.Tensor
input tensor Q; shape sbhd, bshd or thd (see `qkv_layout` for details)
k : torch.Tensor
input tensor K; shape sbhd, bshd or thd (see `qkv_layout` for details)
v : torch.Tensor
input tensor V; shape sbhd, bshd or thd (see `qkv_layout` for details)
o : torch.Tensor
input tensor O (output of forward); same data type as Q, K and V;
same shape as Q
d_o : torch.Tensor
input tensor dO (gradient of O); same data type as Q, K and V;
same shape as Q
fake_dtype : tex.DType
data type of Q, K and V in the high-precision case, or the fake dtype
in the FP8 case; given as a torch.dtype
dqkv_dtype : tex.DType
data type of dQ, dK and dV; in tex.DType, not torch.dtype
aux_ctx_tensors : List[torch.Tensor]
auxiliary output tensors from the forward pass when it was run with is_training = True,
e.g. aux_ctx_tensors = [M, ZInv, rng_state]
fused_attention_backend : tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends.
cu_seqlens_q_padded : torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
cu_seqlens_kv_padded : torch.Tensor, default = None
cumulative sequence offsets for KV; shape [batch_size + 1]
s_quantizer : Quantizer, default = None
Quantizer object for the intermediate value S.
dp_quantizer : Quantizer, default = None
Quantizer object for the intermediate value dP.
dqkv_quantizer : Quantizer, default = None
Quantizer object for the output values of the fused_attn_bwd.
dropout : float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False
fast_zero_fill : bool, default = True
if True, initializes the output tensor O to zero using the fast filling method;
if False, uses PyTorch's .fill_() method
qkv_layout: str, default = "sbh3d"
qkv_layout : str, default = "sbh3d"
layout of Q, K and V;
{"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd",
"bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd",
"t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"}
attn_bias_type: str, default = "no_bias"
attn_bias_type : str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
attn_mask_type : str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
softmax_type: str, default = "vanilla"
softmax_type : str, default = "vanilla"
type of the attention softmax; {"vanilla", "off-by-one", "learnable"}
window_size : Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. The special cases (-1, -1) and (-1, 0) mean no
sliding window and causal mask, respectively.
deterministic : bool, default = False
whether to execute the backward pass with deterministic behaviours.
cuda_graph : bool, default = False
whether or not cuda graph capture is enabled.
Returns
----------
d_q : torch.Tensor
gradient tensor of Q; same data type and shape as Q
d_k : torch.Tensor
gradient tensor of K; same data type and shape as K
d_v : torch.Tensor
gradient tensor of V; same data type and shape as V
d_bias : torch.Tensor, optional
gradient tensor of Bias when attn_bias_type is "pre_scale_bias"
or "post_scale_bias"; same data type and shape as Bias
d_softmax_offset : torch.Tensor, optional
gradient tensor of softmax offset of shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details.
"""
if attn_scale is None:
......
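The window_size semantics above can be restated as a reference predicate; this is a sketch of the documented bounds, not library code (-1 denotes an unbounded side):

.. code-block:: python

    def in_sliding_window(i, j, seqlen_q, seqlen_kv, window_size):
        # True iff query position i may attend to key position j under the
        # inclusive bounds documented for window_size.
        left, right = window_size
        lo = 0 if left == -1 else i + seqlen_kv - seqlen_q - left
        hi = seqlen_kv - 1 if right == -1 else i + seqlen_kv - seqlen_q + right
        return lo <= j <= hi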
......@@ -657,60 +657,64 @@ def get_cpu_offload_context(
Parameters
----------
enabled : bool, default = False
When set to True, CPU Offloading functionality is enabled.
num_layers : int, default = 1
Determines the number of layers
you want to offload activations/weights for.
model_layers : int, default = 1
Number of layers in the model that will be used under this context.
offload_activations : bool, default = True
Deprecated.
offload_weights : bool, default = True
Deprecated.
double_buffering : bool, default = False
Deprecated.
retain_pinned_cpu_buffers : bool, default = False
If True, the pinned CPU buffers are retained after offloading
and reused for the next iteration. This is useful for CUDA graph capture.
manual_synchronization : bool, default = False
If True, the synchronization is done manually by the user.
An additional argument, manual_controller, is returned. See the manual synchronization notes below.
offload_stream : torch.cuda.Stream, default = None
If provided, the offload stream is used for offloading and reloading.
Otherwise, a new stream is allocated internally. It can be other than None
only if manual_synchronization is True.
Notes
-----
**Manual synchronization:**
By default, layers are offloaded/reloaded asynchronously
with respect to the current forward/backward stream with predefined synchronization,
to ensure that activation memory usage is equal to
``(num_layers - num_offloaded_layers) * T``, where ``T`` is the memory footprint of a layer.
For more control over the offloading and reloading process, you can set ``manual_synchronization=True``.
In this case, an additional argument, ``manual_controller``, is returned.
The ``manual_controller`` provides the following methods:
- ``start_offload_layer(layer_id: int)``
- ``release_activation_forward_gpu_memory(layer_id: int)``
- ``start_reload_layer(layer_id: int)``
If none of these methods are invoked for a given layer, that layer will not be offloaded or reloaded.
If ``start_offload_layer()`` is called for a layer, offload copies for that layer begin asynchronously on the offload stream.
Since GPU activations must be kept in memory until the copy is finished, pointers to all activations are stored.
To release this memory, you need to call ``release_activation_forward_gpu_memory(layer_id)``.
This method makes the current stream wait for an event recorded on the offload stream after all tensors from the layer have been offloaded.
The ``start_reload_layer()`` method is used to start reloading a layer.
Each tensor reload is awaited to finish before ``tensor_pop()`` for that tensor is called on the current stream.
You can provide an ``offload_stream`` to be used for offload and reload operations.
This allows for more detailed synchronization, such as delaying the start of offloading.
**Example:**
.. code-block:: python
offload_stream = torch.cuda.Stream()
cpu_offload_context, sync_function, manual_controller = get_cpu_offload_context(
enabled=True, model_layers=num_layers, manual_synchronization=True, offload_stream=offload_stream)
......@@ -732,10 +736,10 @@ def get_cpu_offload_context(
for i in range(num_layers):
out[i].sum().backward()
**V1 code path:**
If you want to use the v1 code path for offloading,
please set the environment variable ``NVTE_CPU_OFFLOAD_V1`` to 1.
"""
if NVTE_CPU_OFFLOAD_V1:
......
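For the default (non-manual) mode, usage reduces to wrapping each layer's forward pass in the returned context and routing activations through the returned synchronization function. A sketch, assuming a two-value return when manual_synchronization is left False, with ``layers`` standing in for the model's modules:

.. code-block:: python

    import torch
    from transformer_engine.pytorch import get_cpu_offload_context

    cpu_offload_context, sync_function = get_cpu_offload_context(
        enabled=True, num_layers=2, model_layers=4
    )
    x = torch.randn(32, 1024, device="cuda")
    for layer in layers:  # `layers`: the model's modules, defined elsewhere
        with cpu_offload_context:
            x = layer(x)
        x = sync_function(x)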
......@@ -685,18 +685,18 @@ def get_cpu_offload_context(
Parameters
----------
enabled : bool, default = `False`
When set to True, CPU Offloading functionality is enabled.
num_layers : int, default = 1
Determines the number of transformer layers
you want to offload activations/weights for.
model_layers : int, default = 1
Number of layers in the model that will be used under this context.
offload_activations : bool, default = `True`
When set to `True`, offloads the activations for the TE layer.
offload_weights : bool, default = `True`
When set to `True`, offloads the weights for the TE layer.
double_buffering : bool, default = `False`
When set to `True`, uses double buffering for offloading.
"""
......
......@@ -4,6 +4,9 @@
"""Cross Entropy Loss API"""
from typing import Optional
import warnings
import torch
import transformer_engine.pytorch.triton.cross_entropy as triton_cross_entropy
......@@ -23,7 +26,7 @@ class CrossEntropyFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
inp,
target,
label_smoothing=0.0,
reduce_loss=False,
......@@ -37,7 +40,7 @@ class CrossEntropyFunction(torch.autograd.Function):
Parameters:
ctx : The context object.
inp (tensor): The input tensor of shape (B, SQ, V) or (SQ, B, V) where B is batch size, SQ is sequence length, V is vocab size.
target (tensor): The target tensor of shape (B,SQ) or (SQ, B) where each value is in [0, V-1].
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduce_loss (bool): If true, returns the averaged loss across the B*SQ dimension.
......@@ -47,8 +50,8 @@ class CrossEntropyFunction(torch.autograd.Function):
Returns:
tensor: The computed loss.
"""
loss, inp = triton_cross_entropy.cross_entropy_forward(
inp,
target,
label_smoothing,
reduce_loss,
......@@ -56,7 +59,7 @@ class CrossEntropyFunction(torch.autograd.Function):
ignore_idx,
)
ctx.save_for_backward(inp.detach())
ctx.is_cg_capturable = is_cg_capturable
return loss
......@@ -72,12 +75,10 @@ class CrossEntropyFunction(torch.autograd.Function):
Returns:
tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
"""
(inp,) = ctx.saved_tensors
inp = triton_cross_entropy.cross_entropy_backward(inp, grad_output, ctx.is_cg_capturable)
return (
inp,
None,
None,
None,
......@@ -87,4 +88,65 @@ class CrossEntropyFunction(torch.autograd.Function):
)
def parallel_cross_entropy(
inp: Optional[torch.Tensor] = None,
target: Optional[torch.Tensor] = None,
label_smoothing: float = 0.0,
reduce_loss: bool = False,
dist_process_group: Optional[torch.distributed.ProcessGroup] = None,
ignore_idx: int = -100,
is_cg_capturable: bool = False,
*,
_input: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Cross Entropy loss with optional distributed reduction.
The input tensor can be in BF16/FP32; the loss and gradient calculation happens in
FP32 only. The returned loss is always in FP32, and the input gradients are cast
to the datatype of the input.
If ``dist_process_group`` is passed for distributed loss calculation, the input to each
distributed rank should be ``(*, V/world_size)``. Note that each of the ranks should
get equal shards along the V dimension.
Parameters
----------
inp : torch.Tensor
The input tensor of shape ``(B, SQ, V)`` or ``(SQ, B, V)`` where B is batch size,
SQ is sequence length, V is vocab size.
target : torch.Tensor
The target tensor of shape ``(B, SQ)`` or ``(SQ, B)`` where each value is in ``[0, V-1]``.
label_smoothing : float, default = 0.0
The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduce_loss : bool, default = False
If True, returns the averaged loss across the B*SQ dimension.
dist_process_group : torch.distributed.ProcessGroup, default = None
The distributed process group the loss computation is split across, None if on 1 device.
ignore_idx : int, default = -100
The index for which loss and gradients are made to zero.
is_cg_capturable : bool, default = False
Whether the operation is CUDA graph capturable.
Returns
-------
torch.Tensor
The computed loss.
"""
# Handle backward compatibility with _input parameter
if _input is not None:
warnings.warn(
"The '_input' parameter is deprecated. Please use 'inp' instead.",
FutureWarning,
)
inp = _input
if inp is None or target is None:
    raise TypeError("parallel_cross_entropy() requires both 'inp' and 'target'")
return CrossEntropyFunction.apply(
inp,
target,
label_smoothing,
reduce_loss,
dist_process_group,
ignore_idx,
is_cg_capturable,
)
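A single-device usage sketch of the renamed API; the vocabulary size and shapes are illustrative:

.. code-block:: python

    import torch
    from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy

    logits = torch.randn(2, 8, 32000, device="cuda", requires_grad=True)  # (B, SQ, V)
    target = torch.randint(0, 32000, (2, 8), device="cuda")               # (B, SQ)
    loss = parallel_cross_entropy(logits, target, reduce_loss=True)
    loss.backward()  # gradients are written back in the dtype of `logits`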
......@@ -30,7 +30,7 @@ except ImportError:
import transformer_engine_torch as tex
from transformer_engine.pytorch.triton.pad import pad_columnwise_scale_inv
from .torch_version import torch_version
from .utils import (
is_non_tn_fp8_gemm_supported,
safely_set_viewless_tensor_data,
......@@ -642,18 +642,18 @@ def checkpoint(
Parameters
----------
function : Callable
pytorch module used to run the forward and backward passes using
the specified :attr:`args` and :attr:`kwargs`.
distribute_saved_activations : bool, default = False
if set to ``True`` and ``use_reentrant=True``, first tensor argument is distributed
across the specified tensor parallel group (``tp_group``) before saving it for the
backward pass. This has no effect when ``use_reentrant=False``.
get_rng_state_tracker : Callable, default = None
python callable which returns an instance of :class:`CudaRNGStatesTracker`.
tp_group : ProcessGroup, default = None
tensor parallel process group. Used only when ``distribute_saved_activations=True``
and ``use_reentrant=True``. If ``None``, it falls back to the default group.
use_reentrant : bool, default = True
perform checkpointing in reentrant mode.
args : tuple
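A minimal usage sketch of this checkpoint API; the module and input are illustrative and the documented defaults are kept:

.. code-block:: python

    import torch
    from transformer_engine.pytorch import checkpoint

    block = torch.nn.Linear(1024, 1024).cuda()
    x = torch.randn(8, 1024, device="cuda", requires_grad=True)
    y = checkpoint(block, x)  # forward activations are recomputed in backward
    y.sum().backward()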
......@@ -778,8 +778,8 @@ class CudaRNGStatesTracker:
For model parallelism, multiple RNG states need to simultaneously exist in order
to execute operations in or out of the model parallel region. This class keeps
track of the various RNG states and provides utility methods to maintain them and
execute parts of the model under a given RNG setting. Using the :meth:`add` method, a
cuda rng state is initialized based on the input ``seed`` and is assigned to ``name``.
Later, by forking the rng state, we can perform operations and return to our starting
cuda state.
"""
......@@ -812,7 +812,9 @@ class CudaRNGStatesTracker:
Set the rng states. For efficiency purposes, we do not
check the size of seed for compatibility.
Parameters
----------
states : Dict[str, torch.Tensor]
A mapping from string names to RNG states.
"""
self.states_ = states
......@@ -821,9 +823,11 @@ class CudaRNGStatesTracker:
"""
Adds a new RNG state.
Parameters
----------
name : str
string identifier for the RNG state.
seed : int
PyTorch seed for the RNG state.
"""
# Check seed is not already used.
......@@ -857,7 +861,9 @@ class CudaRNGStatesTracker:
Fork the cuda rng state, perform operations, and exit with
the original state.
Parameters
----------
name : str
string identifier for the RNG state.
"""
# Check if we have added the state
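Taken together, :meth:`add` and :meth:`fork` are used as in this sketch; the state name and seed are illustrative, and the import path is an assumption:

.. code-block:: python

    import torch
    from transformer_engine.pytorch.distributed import CudaRNGStatesTracker

    tracker = CudaRNGStatesTracker()
    tracker.add("model-parallel-rng", 1234)

    with tracker.fork("model-parallel-rng"):
        # Draws from the named CUDA RNG state; the original state is
        # restored on exit.
        mask = torch.rand(4, 4, device="cuda") > 0.5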
......@@ -2003,7 +2009,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
Parameters
----------
fsdp_root : torch.nn.Module
FSDP-wrapped root module that may contain FSDP-wrapped TE modules.
"""
assert isinstance(fsdp_root, FSDP), "Root module must be FSDP-wrapped."
......
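Usage is a single call on the FSDP root after wrapping; a sketch, with ``model`` (a module containing TE layers) assumed to be defined elsewhere:

.. code-block:: python

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp

    fsdp_model = FSDP(model)  # `model` contains TE modules
    prepare_te_modules_for_fsdp(fsdp_model)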
......@@ -28,7 +28,7 @@ def onnx_export(enabled: bool = False) -> Generator[None, None, None]:
Parameters
----------
enabled : bool, default = False
whether or not to enable export
"""
......
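The context manager is wrapped around the export call itself; a sketch, with ``model`` and ``sample_input`` assumed to be defined elsewhere and the import path an assumption:

.. code-block:: python

    import torch
    from transformer_engine.pytorch.export import onnx_export

    with onnx_export(enabled=True):
        torch.onnx.export(model, (sample_input,), "model.onnx")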