"""Allows the model to jointly attend to information from different
r"""Allows the model to jointly attend to information from different
representation subspaces as described in the paper:
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
.. note::
Argument :attr:`attention_mask` in the ``forward`` call is only used when
:attr:`attn_mask_type` includes ``"padding"`` or ``"arbitrary"``.
.. warning::
FlashAttention uses a non-deterministic algorithm for optimal performance. To observe
deterministic behavior at the cost of performance, use FlashAttention version >= ``2.4.1``
and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
to disable ``flash-attn`` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
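A minimal sketch of the two environment variables above; this assumes they are
read when the attention backend is first selected, so they are set before any
attention call:

.. code-block:: python

    import os

    # Request deterministic FlashAttention (requires flash-attn >= 2.4.1).
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
    # Alternatively, disable flash-attn entirely.
    # os.environ["NVTE_FLASH_ATTN"] = "0"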
.. note::
Transformer Engine stores the FP8 metadata under a ``._extra_state`` key when checkpointing.
As the FP8 attention support expands from one backend to multiple backends, the location
of that key has also shifted (see `FP8 checkpoint compatibility <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/faq.html#fp8-checkpoint-compatibility>`_).
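A hedged sketch of locating that key in a checkpoint; ``model`` is assumed to be
a module containing Transformer Engine attention layers:

.. code-block:: python

    # Find where the FP8 metadata landed in this checkpoint; the exact
    # location of these keys can differ across Transformer Engine versions.
    extra_state_keys = [k for k in model.state_dict() if k.endswith("._extra_state")]
    print(extra_state_keys)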
...
@@ -182,118 +182,137 @@ class DotProductAttention(TransformerEngineBaseModule):
kv_channels : Union[int, Tuple[int, int]]
the head size in key and value tensors. If the same, :attr:`kv_channels` can be
an integer; if not, :attr:`kv_channels` should be a tuple of two integers.
num_gqa_groups : Optional[int], default = None
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the queries.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. ``num_gqa_groups = num_attention_heads``.
attention_dropout: float, default = 0.0
dropout probability for the dropout op during multi-head attention.
attn_mask_type: str, default = "causal"
type of attention mask passed into softmax operation, options are ``"no_mask"``,
sliding window size for local attention, where query at position i attends to keys
in ``[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]]``
inclusive. Special cases ``(-1, -1)`` and ``(-1, 0)`` mean no sliding
window and causal mask specifically. Both ``"causal"`` and ``"causal_bottom_right"`` masks
map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them based on
``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can
be overridden by :attr:`window_size` in :meth:`forward` as well.
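A minimal usage sketch combining the parameters documented above; the value of
``num_attention_heads``, the ``(seq, batch, heads, head_dim)`` tensor layout, and
the CUDA/bfloat16 inputs are assumptions, not taken from this docstring:

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    attn = te.DotProductAttention(
        num_attention_heads=16,
        kv_channels=64,           # same head size for keys and values
        num_gqa_groups=4,         # 4 key/value heads shared by the 16 query heads
        attn_mask_type="causal",  # maps to window_size = (-1, 0)
    )

    # Queries carry all 16 heads; keys/values carry only the 4 GQA groups.
    q = torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(128, 2, 4, 64, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(128, 2, 4, 64, dtype=torch.bfloat16, device="cuda")

    # attention_mask is only consulted for "padding"/"arbitrary" mask types.
    out = attn(q, k, v)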
num_gqa_groups : int, default = None
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the queries.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. ``num_gqa_groups = num_attention_heads``.
return_layernorm_output : bool, default = False
if set to ``True``, output of layernorm is returned from the :meth:`forward` method
together with the output of the linear transformation.
Example use case: residual connection for transformer module is
taken post layernorm.
input_layernorm: bool, default = False
if set to ``True``, layer normalization to the input is applied.
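A hedged sketch, assuming these parameters belong to
``transformer_engine.pytorch.MultiheadAttention`` and that ``hidden_size`` and
``num_attention_heads`` (not documented in this fragment) take illustrative
values:

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    mha = te.MultiheadAttention(
        hidden_size=1024,
        num_attention_heads=16,
        num_gqa_groups=4,              # GQA: 4 key/value groups
        input_layernorm=True,          # normalize the input first
        return_layernorm_output=True,  # also return the layernorm output
    )

    x = torch.randn(128, 2, 1024, dtype=torch.bfloat16, device="cuda")
    # Unpacking assumes no extra outputs (e.g. a returned bias) are enabled;
    # ln_out lets the residual connection be taken post-layernorm.
    out, ln_out = mha(x)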