"""Allows the model to jointly attend to information from different
r"""Allows the model to jointly attend to information from different
representation subspaces as described in the paper:
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
.. note::
Argument :attr:`attention_mask` in the ``forward`` call is only used when
:attr:`attn_mask_type` includes ``"padding"`` or ``"arbitrary"``.
.. warning::
FlashAttention uses a non-deterministic algorithm for optimal performance. To observe
deterministic behavior at the cost of performance, use FlashAttention version >= ``2.4.1``
and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
to disable ``flash-attn`` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
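A minimal sketch of setting these variables (an illustrative assumption: they are set in the
process environment before the attention module is first used):

.. code-block:: python

    import os

    # Request deterministic FlashAttention (requires flash-attn >= 2.4.1).
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"

    # Alternatively, disable flash-attn entirely.
    # os.environ["NVTE_FLASH_ATTN"] = "0"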
.. note::
Transformer Engine stores the FP8 metadata under a ``._extra_state`` key when checkpointing.
As the FP8 attention support expands from one backend to multiple backends, the location
of that key has also shifted (see `FP8 checkpoint compatibility <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/faq.html#fp8-checkpoint-compatibility>`_).
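One way to see where that key ends up in a given checkpoint (a hypothetical inspection
snippet, assuming the file holds a flat ``state_dict``):

.. code-block:: python

    import torch

    state_dict = torch.load("checkpoint.pt", map_location="cpu")  # hypothetical path
    fp8_meta_keys = [k for k in state_dict if k.endswith("._extra_state")]
    print(fp8_meta_keys)  # shows which modules the FP8 metadata is attached to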
...
...
kv_channels : Union[int, Tuple[int, int]]
the head size in key and value tensors. If the same, :attr:`kv_channels` can be
an integer; if not, :attr:`kv_channels` should be a tuple of two integers.
num_gqa_groups : Optional[int], default = None
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the queries.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. ``num_gqa_groups = num_attention_heads``.
attention_dropout: float, default = 0.0
dropout probability for the dropout op during multi-head attention.
attn_mask_type: str, default = ``"causal"``
type of attention mask passed into softmax operation; options include ``"no_mask"``,
``"padding"``, ``"causal"``, ``"causal_bottom_right"``, and ``"arbitrary"``.
window_size: Optional[Tuple[int, int]], default = None
sliding window size for local attention, where query at position i attends to keys
in ``[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]]``
inclusive. Special cases ``(-1, -1)`` and ``(-1, 0)`` mean no sliding
window and causal mask specifically. Both ``"causal"`` and ``"causal_bottom_right"`` masks
map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them based on
``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can
be overridden by :attr:`window_size` in :meth:`forward` as well (see the illustrative
sketch below).
num_gqa_groups : int, default = None
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the queries.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. ``num_gqa_groups = num_attention_heads``.
return_layernorm_output : bool, default = False
if set to ``True``, output of layernorm is returned from the :meth:`forward` method
together with the output of the linear transformation.
Example use case: residual connection for transformer module is
taken post layernorm.
input_layernorm: bool, default = False
if set to ``True``, layer normalization is applied to the input.
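The sliding-window interval in the :attr:`window_size` description above can be made concrete
with a small, self-contained sketch (the helper below is hypothetical and not part of the
Transformer Engine API; it only evaluates that interval):

.. code-block:: python

    def attended_key_range(i, seqlen_q, seqlen_k, window_size):
        """Hypothetical helper: inclusive range of key positions attended by query i."""
        left, right = window_size
        # -1 means "unbounded" on that side, as in the (-1, -1) and (-1, 0) special cases.
        lo = 0 if left == -1 else i + seqlen_k - seqlen_q - left
        hi = seqlen_k - 1 if right == -1 else i + seqlen_k - seqlen_q + right
        return max(lo, 0), min(hi, seqlen_k - 1)

    # window_size = (-1, 0), i.e. a causal mask: query 3 of 4 attends to keys 0..7 of 8.
    print(attended_key_range(i=3, seqlen_q=4, seqlen_k=8, window_size=(-1, 0)))  # (0, 7)

    # A genuine sliding window: 2 keys to the left, 0 to the right.
    print(attended_key_range(i=3, seqlen_q=4, seqlen_k=8, window_size=(2, 0)))   # (5, 7)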