Unverified commit df39a7c2 authored by Paweł Gadziński, committed by GitHub

Docs fix (#2301)



* init
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* line lengths
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* subtitle --- fix in many files
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* cross entropy _input -> input rename
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* cross entropy _input -> input rename
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* a lot of small fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* torch_version() change
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add missing module and fix warnings
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* removed trailing whitespace
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Update docs/api/pytorch.rst
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* Fix import
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix more imports
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix NumPy docstring parameter spacing and indentation

- Standardize parameter documentation to use the 'param : type' format (space before and after the colon) per the NumPy style guide
- Fix inconsistent indentation in the cpu_offload.py docstring
- Modified 51 Python files across transformer_engine/pytorch
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent ca468ebe
......@@ -32,9 +32,9 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import (
Quantizer,
QuantizeLayout,
DelayedScaleQuantizer,
ScalingMode,
QuantizeLayout,
)
......
......@@ -39,12 +39,12 @@ from ..quantize import (
Quantizer,
GroupedQuantizer,
QuantizerSet,
QuantizeLayout,
noop_quantizer_set,
is_fp8_gemm_with_all_layouts_supported,
apply_padding_to_scale_inv,
get_quantize_config_with_recipe,
get_global_quantize_recipe,
QuantizeLayout,
)
from .misc import get_padded_spec, is_all_reduce_in_float32
from ..sharding import (
......
......@@ -116,7 +116,7 @@ def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis=-1):
transpose. Note, transpose_axis should be greater than static_axis_boundary
examples:
X in shape (dim0, dim1, dim2, dim3, dim4)
X of shape (dim0, dim1, dim2, dim3, dim4)
static_axis_boundary == -1, transpose_axis == 2
Xt = (dim2, dim3, dim4, dim0, dim1)
......
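For reference, the transpose described in this docstring can be reproduced with plain NumPy (an illustrative sketch using the shapes from the example above, not the library function itself):

.. code-block:: python

    import numpy as np

    # X of shape (dim0, dim1, dim2, dim3, dim4), static_axis_boundary == -1, transpose_axis == 2
    x = np.zeros((2, 3, 4, 5, 6))
    xt = np.transpose(x, (2, 3, 4, 0, 1))   # Xt = (dim2, dim3, dim4, dim0, dim1)
    print(xt.shape)                          # (4, 5, 6, 2, 3)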
......@@ -35,9 +35,9 @@ from ..sharding import (
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import (
Quantizer,
QuantizeLayout,
DelayedScaleQuantizer,
ScalingMode,
QuantizeLayout,
)
......
......@@ -40,11 +40,11 @@ from ..quantize import (
GroupedScaledTensor1x,
Quantizer,
GroupedQuantizer,
QuantizeLayout,
ScalingMode,
compute_scale_from_amax,
NoScaleTensor,
get_rht_matrix,
QuantizeLayout,
)
......
......@@ -21,12 +21,12 @@ from .quantize import (
ScaledTensorFactory,
ScaledTensor,
ScalingMode,
QuantizeLayout,
QuantizerSet,
noop_quantizer_set,
with_sharding_constraint_by_logical_axes,
is_fp8_gemm_with_all_layouts_supported,
TensorUsage,
QuantizeLayout,
)
......
......@@ -279,26 +279,26 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
Indicate the type of layer normalization.
zero_centered_gamma : bool, default = False
If set to `True`, the LayerNorm formula changes to
If set to ``True``, the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} \cdot
(1 + \gamma) + \beta
This parameter is only applicable for 'layernorm'.
The default of `scale_init` will also be changed. See `scale_init`.
This parameter is only applicable for ``'layernorm'``.
The default of ``scale_init`` will also be changed. See ``scale_init``.
scale_init : Initializer, default = None
Used for initializing scale factors :math:`\gamma`.
If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
Otherwise, scale_init is `flax.linen.initializers.ones`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
If ``None`` is provided, scale_init is set according to the value of zero_centered_gamma.
If zero_centered_gamma is set to ``True``, then scale_init is ``flax.linen.initializers.zeros``.
Otherwise, scale_init is ``flax.linen.initializers.ones``.
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
scale_axes : Tuple[str, ...], default = ('embed', )
The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
bias_init : Initializer, default = flax.linen.initializers.zeros
Used for initializing shift factors :math:`\beta`,
only used when :attr:`layernorm_type='layernorm'`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
bias_axes : Tuple[str, ...], default = ('embed', )
The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
only used when :attr:`layernorm_type='layernorm'`.
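For reference, a minimal NumPy sketch of the zero-centered formula above (illustrative only; it is not the Flax module, and the shapes are arbitrary):

.. code-block:: python

    import numpy as np

    def zero_centered_layernorm(x, gamma, beta, eps=1e-6):
        # y = (x - E[x]) / sqrt(Var[x] + eps) * (1 + gamma) + beta
        mean = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        return (x - mean) / np.sqrt(var + eps) * (1.0 + gamma) + beta

    x = np.random.randn(4, 8)
    gamma = np.zeros(8)   # with zero_centered_gamma=True the scale is initialized to zeros
    beta = np.zeros(8)
    print(zero_centered_layernorm(x, gamma, beta).shape)   # (4, 8)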
......@@ -424,15 +424,15 @@ class DenseGeneral(TransformerEngineBase):
kernel_init : Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
Used for initializing weights.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
kernel_axes : Tuple[str, ...], default = ()
The name of axes used to shard the weights with a corresponding mesh.
use_bias: bool, default = False
Indicate whether to enable bias shifting.
If set to False, the layer will not learn an additive bias.
If set to ``False``, the layer will not learn an additive bias.
bias_init: Initializer, default = flax.linen.initializers.zeros
Used for initializing bias, only used when :attr:`use_bias=True`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
bias_axes: Tuple[str, ...], default = ()
The name of axes used to shard bias with a corresponding mesh,
only used when :attr:`use_bias=True`.
......@@ -443,12 +443,12 @@ class DenseGeneral(TransformerEngineBase):
:attr:`enable_low_rank_adaptation=True`
low_rank_adaptation_alpha: float, default = None
The alpha for computing the scaling factor of LoRA output.
:math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
:math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on.
input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
sharding constraint.
Optimization parameters
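The LoRA scaling above, :math:`\frac{alpha}{rank} \cdot lora\_output`, amounts to the following (a NumPy sketch under assumed shapes, not the module's code):

.. code-block:: python

    import numpy as np

    in_features, out_features, rank, alpha = 16, 32, 4, 8.0
    x = np.random.randn(2, in_features)
    w = np.random.randn(in_features, out_features)   # frozen base weight
    lora_a = np.random.randn(in_features, rank)      # low-rank adapters
    lora_b = np.random.randn(rank, out_features)

    lora_output = x @ lora_a @ lora_b
    y = x @ w + (alpha / rank) * lora_output          # alpha = None would skip this scaling
    print(y.shape)                                     # (2, 32)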
......@@ -597,48 +597,48 @@ class LayerNormDenseGeneral(TransformerEngineBase):
epsilon : float, default = 1e-6
A value added to the denominator of layer normalization for numerical stability.
zero_centered_gamma : bool, default = False
If set to `True`, the LayerNorm formula changes to
If set to ``True``, the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} \cdot
(1 + \gamma) + \beta
This parameter is only applicable for 'layernorm'.
The default of `scale_init` will also be changed. See `scale_init`
This parameter is only applicable for ``'layernorm'``.
The default of ``scale_init`` will also be changed. See ``scale_init``
scale_init : Initializer, default = None
Used for initializing scale factors :math:`\gamma`.
If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
Otherwise, scale_init is `flax.linen.initializers.ones`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
If ``None`` is provided, scale_init is set according to the value of zero_centered_gamma.
If zero_centered_gamma is set to ``True``, then scale_init is ``flax.linen.initializers.zeros``.
Otherwise, scale_init is ``flax.linen.initializers.ones``.
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
scale_axes : Tuple[str, ...], default = ('embed', )
The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
only used when :attr:`enable_layernorm=True`.
ln_bias_init: Initializer, default = flax.linen.initializers.zeros
Used for initializing shift factors :math:`\beta`,
only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
ln_bias_axes: Tuple[str, ...], default = ('embed', )
The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
kernel_init : Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
Used for initializing weights.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
kernel_axes : Tuple[str, ...], default = ()
The name of axes used to shard the weights with a corresponding mesh.
use_bias: bool, default = False
Indicate whether to enable bias shifting.
If set to False, the layer will not learn an additive bias.
If set to ``False``, the layer will not learn an additive bias.
bias_init: Initializer, default = flax.linen.initializers.zeros
Used for initializing bias, only used when :attr:`use_bias=True`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
bias_axes: Tuple[str, ...], default = ()
The name of axes used to shard bias with a corresponding mesh,
only used when :attr:`use_bias=True`.
return_layernorm_output: bool, default = False
Indicate whether to return the output of layer normalization.
If set False, return None as the second tensor in outputs.
If set ``False``, return ``None`` as the second tensor in outputs.
enable_low_rank_adaptation: bool, default = False
Indicate whether to enable low rank adaptation for each dense layer.
low_rank_adaptation_dim: int, default = 32
......@@ -646,16 +646,16 @@ class LayerNormDenseGeneral(TransformerEngineBase):
:attr:`enable_low_rank_adaptation=True`
low_rank_adaptation_alpha: float, default = None
The alpha for computing the scaling factor of LoRA output.
:math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
:math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on.
layernorm_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input of layernorm, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
sharding constraint.
dot_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input of dot, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
sharding constraint.
Optimization parameters
......@@ -887,34 +887,34 @@ class LayerNormMLP(TransformerEngineBase):
epsilon : float, default = 1e-6
A value added to the denominator of layer normalization for numerical stability.
zero_centered_gamma : bool, default = False
If set to `True`, the LayerNorm formula changes to
If set to ``True``, the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} \cdot
(1 + \gamma) + \beta
This parameter is only applicable for 'layernorm'.
The default of `scale_init` will also be changed. See `scale_init`.
This parameter is only applicable for ``'layernorm'``.
The default of ``scale_init`` will also be changed. See ``scale_init``.
scale_init : Initializer, default = None
Used for initializing scale factors :math:`\gamma`.
If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
Otherwise, scale_init is `flax.linen.initializers.ones`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
If ``None`` is provided, scale_init is set according to the value of zero_centered_gamma.
If zero_centered_gamma is set to ``True``, then scale_init is ``flax.linen.initializers.zeros``.
Otherwise, scale_init is ``flax.linen.initializers.ones``.
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
scale_axes : Tuple[str, ...], default = ('embed', )
The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
only used when :attr:`enable_layernorm=True`.
ln_bias_init: Initializer, default = flax.linen.initializers.zeros
Used for initializing shift factors :math:`\beta`,
only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
ln_bias_axes: Tuple[str, ...], default = ('embed', )
The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
kernel_init : Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
Used for initializing the weights of both dense layer transformations.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
The name of axes used to shard the weights with a corresponding mesh for
the weight of the first dense layer transformation.
......@@ -923,10 +923,10 @@ class LayerNormMLP(TransformerEngineBase):
the weight of the second dense layer transformation.
use_bias: bool, default = False
Indicate whether to enable bias shifting.
If set to False, the layer will not learn an additive bias.
If set to ``False``, the layer will not learn an additive bias.
bias_init: Initializer, default = flax.linen.initializers.zeros
Used for initializing bias, only used when :attr:`use_bias=True`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
bias_axes_1: Tuple[str, ...], default = ('mlp',)
The name of axes used to shard bias with a corresponding mesh for
the weight of the first dense layer transformation.
......@@ -937,7 +937,7 @@ class LayerNormMLP(TransformerEngineBase):
Only used when :attr:`use_bias=True`.
return_layernorm_output: bool, default = False
Indicate whether to return the output of layer normalization.
If set False, return None as the second tensor in outputs.
If set ``False``, return ``None`` as the second tensor in outputs.
activations: Sequence[Union[str, Callable]], default = ('gelu',)
The sequence of activation functions to apply after the first dense layer transformation.
Each activation has its own transformation layer.
......@@ -958,20 +958,20 @@ class LayerNormMLP(TransformerEngineBase):
:attr:`enable_low_rank_adaptation=True`.
low_rank_adaptation_alpha: float, default = None
The alpha for computing the scaling factor of LoRA output.
:math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
:math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on.
layernorm_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input of layernorm, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
sharding constraint.
dot_1_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input of 1st dot, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
sharding constraint.
dot_2_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input of 2nd dot, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
sharding constraint.
ffn1_ckpt_name: str = "ffn1"
Checkpoint name for the output of the first fully-connected layer in the MLP block.
......

......@@ -7,22 +7,14 @@
# pylint: disable=wrong-import-position
import functools
from packaging.version import Version as PkgVersion
import torch
from transformer_engine.common import load_framework_extension
@functools.lru_cache(maxsize=None)
def torch_version() -> tuple[int, ...]:
"""Get PyTorch version"""
return PkgVersion(str(torch.__version__)).release
from transformer_engine.pytorch.torch_version import torch_version
assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}."
load_framework_extension("torch")
from transformer_engine.pytorch.module import LayerNormLinear
from transformer_engine.pytorch.module import Linear
......
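The removed inline helper returned the release tuple of the installed PyTorch, which is what makes the tuple comparison above work; the imported replacement is presumably equivalent. A small sketch of that behaviour using packaging directly:

.. code-block:: python

    from packaging.version import Version as PkgVersion

    # .release drops local/dev suffixes and yields a tuple of ints, so ordinary
    # tuple comparison implements the minimum-version check.
    print(PkgVersion("2.3.1+cu121").release)             # (2, 3, 1)
    print(PkgVersion("2.3.1+cu121").release >= (2, 1))   # True
    print(PkgVersion("2.0.0").release >= (2, 1))         # False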
......@@ -175,65 +175,65 @@ class AttentionParams:
Parameters
----------
qkv_type: Union[torch.Tensor, Float8Tensor], default = `torch.Tensor`
qkv_type : Union[torch.Tensor, Float8Tensor], default = torch.Tensor
Type of query/key/value tensors, {`torch.Tensor`, `Float8Tensor`}.
qkv_dtype: torch.dtype, default = `torch.bfloat16`
qkv_dtype : torch.dtype, default = torch.bfloat16
Data type of query/key/value tensors.
qkv_layout: str, default = "sbh3d"
qkv_layout : str, default = "sbh3d"
Query/key/value tensor memory layout.
batch_size: int, default = 1
batch_size : int, default = 1
Batch size.
num_heads: int, default = 16
num_heads : int, default = 16
Number of attention heads in the query tensor.
num_gqa_groups: int, default = 16
num_gqa_groups : int, default = 16
Number of attention heads in key and value tensors.
max_seqlen_q: int, default = 128
max_seqlen_q : int, default = 128
Maximum sequence length of the query tensor.
max_seqlen_kv: int, default = 128
max_seqlen_kv : int, default = 128
Maximum sequence length of the key and value tensors.
head_dim_qk: int, default = 64
head_dim_qk : int, default = 64
The size of each attention head in query and key tensors.
head_dim_v: int, default = 64
head_dim_v : int, default = 64
The size of each attention head in the value tensor.
attn_mask_type: str, default = `no_mask`
attn_mask_type : str, default = no_mask
Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`,
`causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}
window_size: Tuple[int, int], default = None
window_size : Tuple[int, int], default = None
Sliding window attention size.
alibi_slopes_shape: Optional[Union[torch.Size, List]], default = `None`
alibi_slopes_shape : Optional[Union[torch.Size, List]], default = None
Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`.
core_attention_bias_type: str, default = `no_bias`
core_attention_bias_type : str, default = no_bias
Attention bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}.
core_attention_bias_shape: str, default = `1hss`
core_attention_bias_shape : str, default = 1hss
Attention bias shape, {`1hss`, `b1ss`, `bhss`}.
core_attention_bias_requires_grad: bool, default = `True`
core_attention_bias_requires_grad : bool, default = True
Whether attention bias requires gradient.
pad_between_seqs: bool, default = `False`
pad_between_seqs : bool, default = False
Whether there is padding between sequences in a batch.
This only applies to `qkv_format=thd`.
attention_dropout: float, default = 0.0
attention_dropout : float, default = 0.0
Attention dropout.
context_parallel: bool, default = `False`
context_parallel : bool, default = False
Whether context parallelism is used or not.
cp_comm_type: str, default = "p2p"
cp_comm_type : str, default = "p2p"
The communication type of context parallelism.
deterministic: bool, default = `False`
deterministic : bool, default = False
Whether to run `DotProductAttention` with determinism or not.
is_training: bool, default = `True`
is_training : bool, default = True
Whether in training mode (`True`) or inference mode (`False`)
fp8: bool, default = `False`
fp8 : bool, default = False
Whether `DotProductAttention` is in an `autocast` region.
fp8_meta: Optional[Dict[str Any]], default = `None`
fp8_meta : Optional[Dict[str Any]], default = None
The FP8 metadata tensor of `DotProductAttention`.
inference_params: Optional[InferenceParams], default = `None`
inference_params : Optional[InferenceParams], default = None
Inference-related parameters. See InferenceParams for details.
softmax_type: str, default = "vanilla"
softmax_type : str, default = "vanilla"
The type of softmax operation. See DotProductAttention for details.
return_max_logit: bool, default = `False`
return_max_logit : bool, default = False
Whether to output max_logit.
cuda_graph: bool, default = `False`
cuda_graph : bool, default = `False`
Whether support for cuda graph capture is needed or not.
num_splits: int, default = 1
num_splits : int, default = 1
The number of kernels to split attention to.
"""
......@@ -298,15 +298,15 @@ def get_attention_backend(
Returns
----------
use_flash_attention: bool
use_flash_attention : bool
Whether the `FlashAttention` backend has been selected.
use_fused_attention: bool
use_fused_attention : bool
Whether the `FusedAttention` backend has been selected.
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
fused_attention_backend : tex.NVTE_Fused_Attn_Backend
If `use_fused_attention = True`, one of `FusedAttention` three sub-backends, else `None`.
use_unfused_attention: bool
use_unfused_attention : bool
Whether the `UnfusedDotProductAttention` backend has been selected.
available_backends: List[bool]
available_backends : List[bool]
All available backends that could support the provided input. A list of Booleans
in the form of [use_flash_attention, use_fused_attention, use_unfused_attention].
"""
......@@ -835,8 +835,8 @@ def get_attention_backend(
# ----------------------------------------------------------------------------------------
# no_mask | None | All
# padding | | All
# self-attention | One tensor in shape [b, 1, 1, sq] |
# cross-attention | Tuple of two tensors in shapes |
# self-attention | One tensor of shape [b, 1, 1, sq] |
# cross-attention | Tuple of two tensors of shapes |
# | [b, 1, 1, sq] and [b, 1, 1, skv] |
# causal | None |
# self-attention | | All
......@@ -846,7 +846,7 @@ def get_attention_backend(
# cross-attention | | FusedAttention, UnfusedDotProductAttention
# causal_bottom_right | None | All
# padding_causal_bottom_right | Same as "padding" | All
# arbitrary | One tensor in shape broadcastable to | UnfusedDotProductAttention
# arbitrary | One tensor of shape broadcastable to | UnfusedDotProductAttention
# | [b, h, sq, skv] |
if attn_mask_type == "arbitrary":
if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or (
......@@ -1271,42 +1271,42 @@ def get_full_mask(
Parameters
----------
max_seqlen_q: int
max_seqlen_q : int
Maximum sequence length for queries.
max_seqlen_kv: int
max_seqlen_kv : int
Maximum sequence length for keys and values.
attn_mask_type: str, default = `no_mask`
Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`",
"`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"}
attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
default = `None`
attn_mask_type : str, default = no_mask
Attention mask type, {``"no_mask"``, ``"padding"``, ``"causal"``, ``"padding_causal"``,
``"causal_bottom_right"``, ``"padding_causal_bottom_right"``, ``"arbitrary"``}
attention_mask : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
default = None
Boolean tensor(s) used to mask out attention softmax input. Please see DotProductAttention
for the requirements of `attention_mask` for different `attn_mask_type`s.
window_size: Tuple[int, int], default = `None`
window_size : Tuple[int, int], default = None
Sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
`attn_mask_type`.
attention_type: str, default = "self"
attention_type : str, default = "self"
Attention type, {"self", "cross"}
bottom_right_alignment: bool, default = `True`
bottom_right_alignment : bool, default = True
Whether to align the diagonal of the sliding window attention to the bottom right (`True`)
or top left (`False`) corner of the softmax matrix. Ignored if `attn_mask_type` explicitly
specifies "causal" or "causal_bottom_right".
Returns
----------
attn_mask_type: str
attn_mask_type : str
For sliding window attention (>=0, >0), "arbitrary"; otherwise, the same as input `attn_mask_type`
attention_mask: torch.Tensor
attention_mask : torch.Tensor
The full attention mask based on `attn_mask_type`, `attention_mask` and `window_size`
actual_seqlens_q: torch.Tensor
For padding masks, the actual sequence lengths for queries, in shape [batch_size].
actual_seqlens_q : torch.Tensor
For padding masks, the actual sequence lengths for queries, of shape [batch_size].
For other masks, `None`.
actual_seqlens_kv: Optional[torch.Tensor], default = `None`
For padding masks, the actual sequence lengths for keys and values, in shape [batch_size].
actual_seqlens_kv : Optional[torch.Tensor], default = None
For padding masks, the actual sequence lengths for keys and values, of shape [batch_size].
For other masks, `None`.
"""
# perform basic checks
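A small NumPy sketch of the inclusive sliding-window rule described for `window_size` (illustrative only; the real function also handles padding and the other mask types):

.. code-block:: python

    import numpy as np

    def sliding_window_mask(seqlen_q, seqlen_kv, window_size):
        """True marks positions that are masked out; -1 means an unbounded side."""
        left, right = window_size
        offset = seqlen_kv - seqlen_q                      # bottom-right diagonal alignment
        mask = np.ones((seqlen_q, seqlen_kv), dtype=bool)
        for i in range(seqlen_q):
            lo = 0 if left < 0 else max(i + offset - left, 0)
            hi = seqlen_kv - 1 if right < 0 else i + offset + right
            mask[i, lo:hi + 1] = False                     # keep the in-window positions
        return mask

    print(sliding_window_mask(3, 3, (-1, 0)).astype(int))  # (-1, 0) reproduces a causal mask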
......@@ -1392,29 +1392,29 @@ def get_alibi(
"""
Parameters
----------
num_heads: int
num_heads : int
Number of heads.
max_seqlen_q: int
max_seqlen_q : int
Maximum sequence length for queries.
max_seqlen_kv: int
max_seqlen_kv : int
Maximum sequence length for keys and values.
actual_seqlens_q: Optional[torch.Tensor], default = `None`
Actual sequence lengths for queries, in shape [batch_size].
actual_seqlens_kv: Optional[torch.Tensor], default = `None`
Actual sequence lengths for keys and values, in shape [batch_size].
alibi_slopes: Optional[torch.Tensor], default = `None`
Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads].
bias_dtype: Optional[torch.dtype], default = `None`
actual_seqlens_q : Optional[torch.Tensor], default = None
Actual sequence lengths for queries, of shape [batch_size].
actual_seqlens_kv : Optional[torch.Tensor], default = None
Actual sequence lengths for keys and values, of shape [batch_size].
alibi_slopes : Optional[torch.Tensor], default = None
Custom ALiBi slopes, FP32, CUDA tensor, of shape [num_heads] or [batch_size, num_heads].
bias_dtype : Optional[torch.dtype], default = None
Dtype of the generated ALiBi bias. If None, use torch.float32.
bottom_right_alignment: bool, default = `True`
bottom_right_alignment : bool, default = True
Whether to align the diagonal of the ALiBi bias to the bottom right corner of
the matrix (`True`) or top left (`False`).
Returns
----------
alibi_slopes: torch.Tensor
alibi_slopes : torch.Tensor
ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads].
alibi_bias: torch.Tensor
alibi_bias : torch.Tensor
ALiBi bias in FP32 or `bias_dtype`. Its shape is
(1) [1, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in [num_heads] shape,
and `actual_seqlens_q` and `actual_seqlens_kv` are `None`; or
......@@ -1818,18 +1818,18 @@ def get_qkv_format(
Parameters
----------
qkv_layout: str
qkv_layout : str
Memory layout of `q`, `k` and `v`. See get_qkv_layout() for more details.
inference_params: InferenceParams, default = `None`
inference_params : InferenceParams, default = None
InferenceParams related to KV caching.
Returns
----------
qkv_format: str, default = `sbhd`
qkv_format : str, default = sbhd
Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}.
q_format: str
q_format : str
Format of the `q` tensor, {`bshd`, `sbhd`, `thd`}.
kv_format: str
kv_format : str
Format of the `k` and `v` tensors, {`bshd`, `sbhd`, `thd`}.
"""
splited = qkv_layout.replace("paged_kv_", "").split("_")
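For illustration, the per-tensor formats can be read off a layout string by stripping the packing digits, roughly as below (a simplified sketch; the real function also accounts for `inference_params`):

.. code-block:: python

    def qkv_formats(qkv_layout: str):
        parts = qkv_layout.replace("paged_kv_", "").split("_")
        fmts = ["".join(c for c in p if not c.isdigit()) for p in parts]
        return fmts[0], fmts[-1]                    # (q_format, kv_format)

    print(qkv_formats("sbh3d"))                     # ('sbhd', 'sbhd')
    print(qkv_formats("paged_kv_thd_bshd_bshd"))    # ('thd', 'bshd')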
......@@ -1855,23 +1855,23 @@ def get_qkv_layout(
Parameters
----------
q: torch.Tensor
q : torch.Tensor
Query tensor.
k: torch.Tensor
k : torch.Tensor
Key tensor.
v: torch.Tensor
v : torch.Tensor
Value tensor.
qkv_format: str, default = `sbhd`
qkv_format : str, default = sbhd
Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for
the sequence length dimension, `b` batch size, `h` the number of attention heads,
`d` head size, and `t` the total number of tokens in a batch, i.e.
`t = sum(s_i) for i = 0...b-1`.
inference_params: InferenceParams, default = `None`
inference_params : InferenceParams, default = None
InferenceParams related to KV caching.
Returns
----------
qkv_layout: str
qkv_layout : str
Memory layout of `q`, `k` and `v`. Each `qkv_layout` maps to a pair of `q_format` and
`kv_format` in {`bshd`, `sbhd`, `thd`}. The `paged_kv_` prefix is used to indicate that
paged KV caching is in play. A few examples of the layouts are as follows.
......@@ -1893,18 +1893,18 @@ def get_qkv_layout(
`thd_2bshd`: {`thd_bshd_bshd`, `paged_kv_thd_bshd_bshd`}
`thd_2sbhd`: {`thd_sbhd_sbhd`, `paged_kv_thd_sbhd_sbhd`}
q: torch.Tensor
q : torch.Tensor
Query tensor. It may be different from input `q` as we try to fit tensors to
a supported layout.
k: torch.Tensor
k : torch.Tensor
Key tensor. It may be different from input `k` as we try to fit tensors to
a supported layout.
v: torch.Tensor
v : torch.Tensor
Value tensor. It may be different from input `v` as we try to fit tensors to
a supported layout.
q_format: str
q_format : str
Format of the query tensor, {`bshd`, `sbhd`, `thd`}.
kv_format: str
kv_format : str
Format of the key and value tensors, {`bshd`, `sbhd`, `thd`}.
"""
......
......@@ -98,29 +98,29 @@ class InferenceParams:
Parameters
----------
max_batch_size: int
max_batch_size : int
Maximum batch size in inference
max_sequence_length: int
max_sequence_length : int
Maximum sequence length in inference
num_heads_kv: int
num_heads_kv : int
Number of attention heads in keys and values
head_dim_k: int
head_dim_k : int
Head size for keys
dtype: torch.dtype
dtype : torch.dtype
Data type of the KV cache
head_dim_v: int, default = None
head_dim_v : int, default = None
Head size for values. If None, initialized as head_dim_k.
is_paged: bool, default = False
is_paged : bool, default = False
Whether the KV cache is paged (True) or non-paged (False)
total_num_pages: int, default = None
total_num_pages : int, default = None
Total number of pages in the KV cache. Required for is_paged = True.
page_size: int, default = None
page_size : int, default = None
Page size of the KV cache. Required for is_paged = True.
max_ctx_len: int, default = None
max_ctx_len : int, default = None
Maximum context length in inference. 1 <= max_ctx_len <= max_sequence_length.
qkv_format: str, default = "bshd"
qkv_format : str, default = "bshd"
Format of the incoming query/key/value tensors in current iteration
custom_cache_manager: KVCacheManager, default = None
custom_cache_manager : KVCacheManager, default = None
Custom cache manager, with KVCacheManager as the base class.
"""
......@@ -525,9 +525,9 @@ class NonPagedKVCacheManager(KVCacheManager):
new_v: torch.Tensor
New value tokens for layer_number in current inference iteration
cu_new_seqlens: torch.Tensor
Cumulative sequence lengths for new_k and new_v, in shape [batch_size + 1]
Cumulative sequence lengths for new_k and new_v, of shape [batch_size + 1]
cu_cached_seqlens: torch.Tensor
Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), in shape [batch_size + 1]
Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), of shape [batch_size + 1]
qkv_format: str
Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'}
......@@ -701,7 +701,7 @@ class PagedKVCacheManager(KVCacheManager):
return [x.page_id for x in self.allocated_pages[seq]]
def get_page_table(self, sequences: List[int]):
"""Get the page table, in shape [batch_size, max_pages_per_seq]"""
"""Get the page table, of shape [batch_size, max_pages_per_seq]"""
page_table = torch.Tensor(
[
self.get_page_list(seq) + [0] * (self.max_pages_per_seq - self.get_page_count(seq))
......@@ -783,9 +783,9 @@ class PagedKVCacheManager(KVCacheManager):
new_v: torch.Tensor
New value tokens for layer_number in current inference iteration
cu_new_seqlens: torch.Tensor
Cumulative sequence lengths for new_k and new_v, in shape [batch_size + 1]
Cumulative sequence lengths for new_k and new_v, of shape [batch_size + 1]
cu_cached_seqlens: torch.Tensor
Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), in shape [batch_size + 1]
Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), of shape [batch_size + 1]
qkv_format: str
Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'}
......
......@@ -287,16 +287,16 @@ def _apply_rotary_pos_emb_base(
Parameters
----------
t: torch.Tensor
t : torch.Tensor
Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which rotary positional
embedding will be applied.
freqs: torch.Tensor
freqs : torch.Tensor
Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` or `[s2, b, 1, d2]`
and dtype 'float', with `s2 >= s` and `d2 <= d`.
tensor_format: {'sbhd', 'bshd'}, default = 'sbhd'
tensor_format : {'sbhd', 'bshd'}, default = 'sbhd'
Should be `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape
`[seq, bs, ...]`.
interleaved: bool, default = False
interleaved : bool, default = False
Whether to use interleaved rotary position embedding.
"""
# [seq, 1, 1, dim] -> [1, seq, 1, dim] or
......@@ -324,7 +324,7 @@ def _get_freqs_on_this_cp_rank(
"""Get the position embedding on the current context parallel rank.
Args:
freqs: torch.Tensor. Positional embedding tensor in shape `[s2, 1, 1, d2]`.
freqs: torch.Tensor. Positional embedding tensor of shape `[s2, 1, 1, d2]`.
seqlen: int. Length of the current sequence.
cp_size: int. Context parallel world size.
cp_rank: int. Context parallel rank.
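Before the `apply_rotary_pos_emb` changes below, a NumPy reference of the non-interleaved rotation these helpers implement (a sketch assuming the usual rotate-half convention; the fused kernels handle more layouts plus the context-parallel slicing above):

.. code-block:: python

    import numpy as np

    def rope_reference(t, freqs):
        """t: [s, b, h, d]; freqs: [s2, 1, 1, d2] with s2 >= s and d2 <= d."""
        s, d2 = t.shape[0], freqs.shape[-1]
        angles = freqs[:s]                                  # broadcast over b and h
        rot, rest = t[..., :d2], t[..., d2:]
        x1, x2 = rot[..., : d2 // 2], rot[..., d2 // 2 :]
        rotate_half = np.concatenate((-x2, x1), axis=-1)
        out = rot * np.cos(angles) + rotate_half * np.sin(angles)
        return np.concatenate((out, rest), axis=-1)

    t = np.random.randn(8, 2, 4, 16)
    freqs = np.random.randn(10, 1, 1, 8)                    # s2 = 10 >= s = 8, d2 = 8 <= d = 16
    print(rope_reference(t, freqs).shape)                   # (8, 2, 4, 16)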
......@@ -372,29 +372,29 @@ def apply_rotary_pos_emb(
Parameters
----------
t: torch.Tensor
t : torch.Tensor
Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
rotary positional embedding will be applied.
freqs: torch.Tensor
freqs : torch.Tensor
Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
with `s2 >= s` and `d2 <= d`.
start_positions: torch.Tensor, default = None.
start_positions : torch.Tensor, default = None.
Tokens in a sequence `i` should be applied with position encoding offset by
`start_positions[i]`. If `start_positions=None`, there's no offset.
tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
tensor_format : {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
interleaved: bool, default = False
interleaved : bool, default = False
Whether to use interleaved rotary position embedding.
fused: bool, default = False
fused : bool, default = False
Whether to use a fused applying RoPE implementation.
cu_seqlens: torch.Tensor, default = None.
cu_seqlens : torch.Tensor, default = None.
Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
dtype torch.int32. Only valid when `tensor_format` is 'thd'.
Should be `cu_seqlens_padded` when cp_size > 1.
cp_size: int, default = 1.
cp_size : int, default = 1.
Context parallel world size. Only valid when `tensor_format` is 'thd' and `fused` is True.
cp_rank: int, default = 0.
cp_rank : int, default = 0.
Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True.
"""
assert (
......@@ -492,32 +492,32 @@ def apply_fused_qkv_rotary_pos_emb(
Parameters
----------
qkv: torch.Tensor
qkv : torch.Tensor
Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which
rotary positional embedding will be applied. This tensor has q, k, v concatenated
along the last dimension.
q_freqs: torch.Tensor
q_freqs : torch.Tensor
Rotary positional embedding Q tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
with `s2 >= s` and `d2 <= d`.
k_freqs: torch.Tensor
k_freqs : torch.Tensor
Rotary positional embedding K tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
with `s2 >= s` and `d2 <= d`.
qkv_split_arg_list: List[int]
qkv_split_arg_list : List[int]
List of integers that specify the split of the qkv tensor. The list should have 3 elements,
the first element is the number of elements in the q tensor, the second element is the number
of elements in the k tensor, and the third element is the number of elements in the v tensor.
The sum of the elements in the list should be equal to the last dimension of the qkv tensor.
start_positions: torch.Tensor, default = None.
start_positions : torch.Tensor, default = None.
Tokens in a sequence `i` should be applied with position encoding offset by
`start_positions[i]`. If `start_positions=None`, there's no offset.
tensor_format: {'sbhd', 'bshd'}, default = 'sbhd'
tensor_format : {'sbhd', 'bshd'}, default = 'sbhd'
is `bshd` if `qkv` is of shape `[bs, seq, ...]`, or `sbhd` if `qkv` is
of shape `[seq, bs, ...]`.
interleaved: bool, default = False
interleaved : bool, default = False
Whether to use interleaved rotary position embedding.
cp_size: int, default = 1.
cp_size : int, default = 1.
Context parallel world size.
cp_rank: int, default = 0.
cp_rank : int, default = 0.
Context parallel rank.
"""
......
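A short sketch of how `qkv_split_arg_list` relates to the packed tensor (hypothetical widths; the real call also applies the two frequency tensors):

.. code-block:: python

    import torch

    s, b, h = 8, 2, 16
    d_q, d_k, d_v = 64, 64, 64
    qkv = torch.randn(s, b, h, d_q + d_k + d_v)      # q, k, v packed along the last dimension
    qkv_split_arg_list = [d_q, d_k, d_v]             # widths must sum to qkv.shape[-1]
    q, k, v = torch.split(qkv, qkv_split_arg_list, dim=-1)
    print(q.shape, k.shape, v.shape)                 # each torch.Size([8, 2, 16, 64])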
......@@ -146,89 +146,89 @@ def fused_attn_fwd(
Parameters
----------
is_training: bool
is_training : bool
if True, runs training and produces auxiliary tensors aux_ctx_tensors
for the backward; if False, runs inference and doesn't produce aux_ctx_tensors
max_seqlen_q: int
max_seqlen_q : int
max sequence length for Q, used for padding;
may be larger than max(seqlens_q),
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
max_seqlen_kv: int
max_seqlen_kv : int
max sequence length for K and V, used for padding;
may be larger than max(seqlens_kv),
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
cu_seqlens_q: torch.Tensor
cu_seqlens_q : torch.Tensor
cumulative sequence lengths for Q; shape [batch_size + 1]
cu_seqlens_kv: torch.Tensor
cu_seqlens_kv : torch.Tensor
cumulative sequence lengths for K and V; shape [batch_size + 1]
q: torch.Tensor
q : torch.Tensor
input tensor Q; shape sbhd, bshd or thd (see `qkv_layout` for details)
k: torch.Tensor
k : torch.Tensor
input tensor K; shape sbhd, bshd or thd (see `qkv_layout` for details)
v: torch.Tensor
v : torch.Tensor
input tensor V; shape sbhd, bshd or thd (see `qkv_layout` for details)
fake_dtype: tex.DType
fake_dtype : tex.DType
data type of Q, K and V - in case of high precision, fake dtype in case of FP8;
in torch.dtype
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
fused_attention_backend : tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends.
attn_bias: torch.Tensor, default = None
attn_bias : torch.Tensor, default = None
input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q, k and v
cu_seqlens_q_padded: torch.Tensor, default = None
cu_seqlens_q_padded : torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
cu_seqlens_kv_padded: torch.Tensor, default = None
cu_seqlens_kv_padded : torch.Tensor, default = None
cumulative sequence offsets for KV; shape [batch_size + 1]
page_table_k: torch.Tensor, default = None
page_table_k : torch.Tensor, default = None
page table for K cache; shape [batch_size, max_pages_per_seq_k]
page_table_v: torch.Tensor, default = None
page_table_v : torch.Tensor, default = None
page table for V cache; shape [batch_size, max_pages_per_seq_v]
s_quantizer: Quantizer, default = None
s_quantizer : Quantizer, default = None
Quantizer object for the intermediate value S.
o_quantizer: Quantizer, default = None
o_quantizer : Quantizer, default = None
Quantizer object for the output of the attention.
attn_scale: float, default = None
attn_scale : float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim_qk) as the default
dropout: float, default = 0.0
dropout : float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False
fast_zero_fill: bool, default = True
fast_zero_fill : bool, default = True
if True, initializes the output tensor O to zero using the fast filling method;
if False, uses PyTorch's .fill_() method
qkv_layout: str, default = "sbh3d"
qkv_layout : str, default = "sbh3d"
layout of Q, K and V;
{"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd",
"bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd",
"t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"}
attn_bias_type: str, default = "no_bias"
attn_bias_type : str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
attn_mask_type : str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
softmax_type: str, default = "vanilla"
softmax_type : str, default = "vanilla"
type of the attention softmax; {"vanilla", "off-by-one", "learnable"}
window_size: Tuple[int, int], default = (-1, -1)
window_size : Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically.
rng_gen: torch.Generator, default = None
rng_gen : torch.Generator, default = None
random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
softmax_offset: torch.Tensor, default = None
softmax offset tensor in shape [1, h_q, 1, 1].
softmax_offset : torch.Tensor, default = None
softmax offset tensor of shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details.
return_max_logit: bool, default = False
return_max_logit : bool, default = False
whether to return the maximum attention score
cuda_graph: bool, default = False
cuda_graph : bool, default = False
whether or not cuda graph capture is enabled.
Returns
----------
o: torch.Tensor
o : torch.Tensor
output tensor O, of the attention calculation; same data type as Q, K and V;
same shape as Q
aux_ctx_tensors: List[torch.Tensor]
aux_ctx_tensors : List[torch.Tensor]
auxiliary output tensors used for the backward;
if is_training is True, aux_ctx_tensors = [softmax-related tensors, rng_state]
if is_training is False, aux_ctx_tensors = None
......@@ -252,7 +252,7 @@ def fused_attn_fwd(
rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen
state of the random number generator;
[seed, offset], dtype uint64
max_logit: if return_max_logit = True, shape [h] and same data type as O; otherwise None
max_logit : if return_max_logit = True, shape [h] and same data type as O; otherwise None
"""
if attn_scale is None:
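The `cu_seqlens`/`max_seqlen` relationship stated above, in concrete terms (a tiny example, not part of the function):

.. code-block:: python

    import torch

    cu_seqlens_q = torch.tensor([0, 3, 7, 12], dtype=torch.int32)   # batch_size = 3
    seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]                # tensor([3, 4, 5])
    max_seqlen_q = int(seqlens_q.max())                             # 5; padding may make it larger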
......@@ -377,89 +377,89 @@ def fused_attn_bwd(
Parameters
----------
max_seqlen_q: int
max_seqlen_q : int
max sequence length for Q, used for padding; may be larger than max(seqlens_q),
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
max_seqlen_kv: int
max_seqlen_kv : int
max sequence length for K and V, used for padding;
may be larger than max(seqlens_kv),
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
cu_seqlens_q: torch.Tensor
cu_seqlens_q : torch.Tensor
cumulative sequence lengths for Q; shape [batch_size + 1]
cu_seqlens_kv: torch.Tensor
cu_seqlens_kv : torch.Tensor
cumulative sequence lengths for K and V; shape [batch_size + 1]
q: torch.Tensor
q : torch.Tensor
input tensor Q; shape sbhd, bshd or thd (see `qkv_layout` for details)
k: torch.Tensor
k : torch.Tensor
input tensor K; shape sbhd, bshd or thd (see `qkv_layout` for details)
v: torch.Tensor
v : torch.Tensor
input tensor V; shape sbhd, bshd or thd (see `qkv_layout` for details)
o: torch.Tensor
o : torch.Tensor
input tensor O (output of forward); same data type as Q, K and V;
same shape as Q
d_o: torch.Tensor
d_o : torch.Tensor
input tensor dO (gradient of O); same data type as Q, K and V;
same shape as Q
fake_dtype: tex.DType
fake_dtype : tex.DType
data type of Q, K and V - in case of high precision, fake dtype in case of FP8;
in torch.dtype
dqkv_dtype: tex.DType
dqkv_dtype : tex.DType
data type of dQ, dK and dV; in tex.DType, not torch.dtype
aux_ctx_tensors: List[torch.Tensor]
aux_ctx_tensors : List[torch.Tensor]
auxiliary output tensors of the forward pass when its is_training is True,
e.g. aux_ctx_tensors = [M, ZInv, rng_state]
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
fused_attention_backend : tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends.
cu_seqlens_q_padded: torch.Tensor, default = None
cu_seqlens_q_padded : torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
cu_seqlens_kv_padded: torch.Tensor, default = None
cu_seqlens_kv_padded : torch.Tensor, default = None
cumulative sequence offsets for KV; shape [batch_size + 1]
s_quantizer: Quantizer, default = None
s_quantizer : Quantizer, default = None
Quantizer object for the intermediate value S.
dp_quantizer: Quantizer, default = None
dp_quantizer : Quantizer, default = None
Quantizer object for the intermediate value dP.
dqkv_quantizer: Quantizer, default = None
dqkv_quantizer : Quantizer, default = None
Quantizer object for the output values of the fused_attn_bwd.
dropout: float, default = 0.0
dropout : float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False
fast_zero_fill: bool, default = True
fast_zero_fill : bool, default = True
if True, initializes the output tensor O to zero using the fast filling method;
if False, uses PyTorch's .fill_() method
qkv_layout: str, default = "sbh3d"
qkv_layout : str, default = "sbh3d"
layout of Q, K and V;
{"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd",
"bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd",
"t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"}
attn_bias_type: str, default = "no_bias"
attn_bias_type : str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
attn_mask_type : str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
softmax_type: str, default = "vanilla"
softmax_type : str, default = "vanilla"
type of the attention softmax; {"vanilla", "off-by-one", "learnable"}
window_size: Tuple[int, int], default = (-1, -1)
window_size : Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically.
deterministic: bool, default = False
deterministic : bool, default = False
whether to execute the backward pass with deterministic behaviours.
cuda_graph: bool, default = False
cuda_graph : bool, default = False
whether or not cuda graph capture is enabled.
Returns
----------
d_q: torch.Tensor
d_q : torch.Tensor
gradient tensor of Q; same data type and shape as Q
d_k: torch.Tensor
d_k : torch.Tensor
gradient tensor of K; same data type and shape as K
d_v: torch.Tensor
d_v : torch.Tensor
gradient tensor of V; same data type and shape as V
d_bias: torch.Tensor, optional
d_bias : torch.Tensor, optional
gradient tensor of Bias when attn_bias_type is "pre_scale_bias"
or "post_scale_bias"; same data type and shape as Bias
d_softmax_offset: torch.Tensor, optional
gradient tensor of softmax offset in shape [1, h_q, 1, 1].
d_softmax_offset : torch.Tensor, optional
gradient tensor of softmax offset of shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details.
"""
if attn_scale is None:
......
......@@ -657,60 +657,64 @@ def get_cpu_offload_context(
Parameters
----------
enabled: bool, default = `False`
enabled : bool, default = False
When set to True, CPU Offloading functionality is enabled.
num_layers: int, default = 1
num_layers : int, default = 1
Determines the number of layers
you want to offload activations/weights for.
model_layers: int, default = 1
model_layers : int, default = 1
Number of layers in the model that will be used under this context.
offload_activations: bool, default = `True`
offload_activations : bool, default = True
Deprecated.
offload_weights: bool, default = `True`
offload_weights : bool, default = True
Deprecated.
double_buffering: bool, default = `False`
double_buffering : bool, default = False
Deprecated.
retain_pinned_cpu_buffers: bool, default = `False`
retain_pinned_cpu_buffers : bool, default = False
If True, the pinned CPU buffers are retained after offloading
and reused for the next iteration. It is useful for cuda graphs capture.
manual_synchronization: bool, default = `False`
manual_synchronization : bool, default = False
If True, the synchronization is done manually by the user.
Additional argument manual_controller is returned. See more in manual control section.
offload_stream: torch.cuda.Stream, default = `None`
offload_stream : torch.cuda.Stream, default = None
If provided, the offload stream is used for offloading and reloading.
Otherwise, a new stream is allocated internally. It can be other than None
only if manual_synchronization is True.
Manual synchronization
----------
Notes
-----
**Manual synchronization:**
By default, layers are offloaded/reloaded asynchronously
with respect to the current forward/backward stream with predefined synchronization,
to ensure that activation memory usage is equal to
`(num_layers - num_offloaded_layers) * T`, where `T` is the memory footprint of a layer.
``(num_layers - num_offloaded_layers) * T``, where ``T`` is the memory footprint of a layer.
For more control over the offloading and reloading process, you can set `manual_synchronization=True`.
In this case, an additional argument, `manual_controller`, is returned.
For more control over the offloading and reloading process, you can set ``manual_synchronization=True``.
In this case, an additional argument, ``manual_controller``, is returned.
The `manual_controller` provides the following methods:
- `start_offload_layer(layer_id: int)`
- `release_activation_forward_gpu_memory(layer_id: int)`
- `start_reload_layer(layer_id: int)`
The ``manual_controller`` provides the following methods:
- ``start_offload_layer(layer_id: int)``
- ``release_activation_forward_gpu_memory(layer_id: int)``
- ``start_reload_layer(layer_id: int)``
If none of these methods are invoked for a given layer, that layer will not be offloaded or reloaded.
If `start_offload_layer()` is called for a layer, offload copies for that layer begin asynchronously on the offload stream.
If ``start_offload_layer()`` is called for a layer, offload copies for that layer begin asynchronously on the offload stream.
Since GPU activations must be kept in memory until the copy is finished, pointers to all activations are stored.
To release this memory, you need to call `release_activation_forward_gpu_memory(layer_id)`.
To release this memory, you need to call ``release_activation_forward_gpu_memory(layer_id)``.
This method makes the current stream wait for an event recorded on the offload stream after all tensors from the layer have been offloaded.
The `start_reload_layer()` method is used to start reloading a layer.
Each tensor reload is awaited to finish before `tensor_pop()` for that tensor is called on the current stream.
The ``start_reload_layer()`` method is used to start reloading a layer.
Each tensor reload is awaited to finish before ``tensor_pop()`` for that tensor is called on the current stream.
You can provide an `offload_stream` to be used for offload and reload operations.
You can provide an ``offload_stream`` to be used for offload and reload operations.
This allows for more detailed synchronization, such as delaying the start of offloading.
Example:
**Example:**
.. code-block:: python
offload_stream = torch.cuda.Stream()
cpu_offload_context, sync_function, manual_controller = get_cpu_offload_context(
enabled=True, model_layers=num_layers, manual_synchronization=True, offload_stream=offload_stream)
......@@ -732,10 +736,10 @@ def get_cpu_offload_context(
    for i in range(num_layers):
        out[i].sum().backward()
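As a further illustration (a minimal sketch, not part of the original example: the per-layer
use of ``cpu_offload_context`` and ``sync_function`` and the exact placement of the controller
calls are assumptions), the controller methods might be driven as follows:

.. code-block:: python

    # Forward: start offloading each layer's activations right after its forward pass,
    # then release the GPU copies once the offload copies have been issued.
    x = inp
    for i, layer in enumerate(layers):
        with cpu_offload_context:
            x = layer(x)
        x = sync_function(x)
        manual_controller.start_offload_layer(i)
        manual_controller.release_activation_forward_gpu_memory(i)

    # Backward: start reloading the offloaded activations before running backward.
    for i in reversed(range(len(layers))):
        manual_controller.start_reload_layer(i)
    x.sum().backward()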
**V1 code path:**
If you want to use the v1 code path for offloading,
please set the environment variable ``NVTE_CPU_OFFLOAD_V1`` to 1.
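For example (a minimal sketch, assuming the variable is read when ``transformer_engine``
is first imported):

.. code-block:: python

    import os

    os.environ["NVTE_CPU_OFFLOAD_V1"] = "1"  # set before importing transformer_engine
    import transformer_engine.pytorch as te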
"""
if NVTE_CPU_OFFLOAD_V1:
...
@@ -685,18 +685,18 @@ def get_cpu_offload_context(
Parameters
----------
enabled : bool, default = `False`
When set to True, CPU Offloading functionality is enabled.
num_layers : int, default = 1
Determines the number of transformer layers
you want to offload activations/weights for.
model_layers : int, default = 1
Number of layers in the model that will be used under this context.
offload_activations : bool, default = `True`
When set to `True`, offloads the activations for the TE layer.
offload_weights : bool, default = `True`
When set to `True`, offloads the weights for the TE layer.
double_buffering : bool, default = `False`
When set to `True`, uses double buffering for offloading.
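A minimal usage sketch (illustrative only; the per-layer wrapping pattern and the
``layers``/``inp`` names are assumptions, not part of this docstring):

.. code-block:: python

    cpu_offload_context, sync_function = get_cpu_offload_context(
        enabled=True, num_layers=2, model_layers=len(layers)
    )
    x = inp
    for layer in layers:
        with cpu_offload_context:
            x = layer(x)
        x = sync_function(x)
    x.sum().backward()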
"""
...
@@ -4,6 +4,9 @@
"""Cross Entropy Loss API"""
from typing import Optional
import warnings
import torch
import transformer_engine.pytorch.triton.cross_entropy as triton_cross_entropy
@@ -23,7 +26,7 @@ class CrossEntropyFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
inp,
target,
label_smoothing=0.0,
reduce_loss=False,
@@ -37,7 +40,7 @@ class CrossEntropyFunction(torch.autograd.Function):
Parameters:
ctx : The context object.
inp (tensor): The input tensor of shape (B, SQ, V) or (SQ, B, V) where B is batch size, SQ is sequence length, V is vocab size.
target (tensor): The target tensor of shape (B,SQ) or (SQ, B) where each value is in [0, V-1].
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduce_loss (bool): If true, returns the averaged loss across the B*SQ dimension.
@@ -47,8 +50,8 @@ class CrossEntropyFunction(torch.autograd.Function):
Returns:
tensor: The computed loss.
"""
loss, inp = triton_cross_entropy.cross_entropy_forward(
inp,
target,
label_smoothing,
reduce_loss,
@@ -56,7 +59,7 @@
ignore_idx,
)
ctx.save_for_backward(inp.detach())
ctx.is_cg_capturable = is_cg_capturable
return loss
@@ -72,12 +75,10 @@ class CrossEntropyFunction(torch.autograd.Function):
Returns:
tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
"""
(inp,) = ctx.saved_tensors
inp = triton_cross_entropy.cross_entropy_backward(inp, grad_output, ctx.is_cg_capturable)
return (
inp,
None,
None,
None,
@@ -87,4 +88,65 @@ class CrossEntropyFunction(torch.autograd.Function):
)
def parallel_cross_entropy(
inp: torch.Tensor,
target: torch.Tensor,
label_smoothing: float = 0.0,
reduce_loss: bool = False,
dist_process_group: Optional[torch.distributed.ProcessGroup] = None,
ignore_idx: int = -100,
is_cg_capturable: bool = False,
*,
_input: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Cross Entropy loss with optional distributed reduction.
The input tensor can be in BF16/FP32; the loss and gradient calculation happens in
FP32 only. The returned loss is always in FP32, and the input gradients are cast
back to the datatype of the input.
If ``dist_process_group`` is passed for distributed loss calculation, the input to each
distributed rank should be ``(*, V/world_size)``. Note that each of the ranks should
get equal shards along the V dimension.
Parameters
----------
inp : torch.Tensor
The input tensor of shape ``(B, SQ, V)`` or ``(SQ, B, V)`` where B is batch size,
SQ is sequence length, V is vocab size.
target : torch.Tensor
The target tensor of shape ``(B, SQ)`` or ``(SQ, B)`` where each value is in ``[0, V-1]``.
label_smoothing : float, default = 0.0
The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduce_loss : bool, default = False
If True, returns the averaged loss across the B*SQ dimension.
dist_process_group : torch.distributed.ProcessGroup, default = None
The distributed process group the loss computation is split across, None if on 1 device.
ignore_idx : int, default = -100
The target index for which the loss and gradients are set to zero.
is_cg_capturable : bool, default = False
Whether the operation is CUDA graph capturable.
Returns
-------
torch.Tensor
The computed loss.
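For example (an illustrative sketch; the shapes, dtypes, and device below are arbitrary):

.. code-block:: python

    logits = torch.randn(
        8, 128, 32000, device="cuda", dtype=torch.bfloat16, requires_grad=True
    )
    labels = torch.randint(0, 32000, (8, 128), device="cuda")
    loss = parallel_cross_entropy(logits, labels, reduce_loss=True)
    loss.backward()  # input gradients come back in the dtype of logits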
"""
# Handle backward compatibility with _input parameter
if _input is not None:
warnings.warn(
"The '_input' parameter is deprecated. Please use 'inp' instead.",
FutureWarning,
)
inp = _input
return CrossEntropyFunction.apply(
inp,
target,
label_smoothing,
reduce_loss,
dist_process_group,
ignore_idx,
is_cg_capturable,
)
@@ -30,7 +30,7 @@ except ImportError:
import transformer_engine_torch as tex
from transformer_engine.pytorch.triton.pad import pad_columnwise_scale_inv
from .torch_version import torch_version
from .utils import (
is_non_tn_fp8_gemm_supported,
safely_set_viewless_tensor_data,
@@ -642,18 +642,18 @@ def checkpoint(
Parameters
----------
function : Callable
pytorch module used to run the forward and backward passes using
the specified :attr:`args` and :attr:`kwargs`.
distribute_saved_activations : bool, default = False
if set to ``True`` and ``use_reentrant=True``, first tensor argument is distributed
across the specified tensor parallel group (``tp_group``) before saving it for the
backward pass. This has no effect when ``use_reentrant=False``.
get_rng_state_tracker : Callable, default = None
python callable which returns an instance of :class:`CudaRNGStatesTracker`.
tp_group : ProcessGroup, default = None
tensor parallel process group. Used only when ``distribute_saved_activations=True``
and ``use_reentrant=True``. If ``None``, it falls back to the default group.
use_reentrant : bool, default = True
perform checkpointing in reentrant mode.
args : tuple
@@ -778,8 +778,8 @@ class CudaRNGStatesTracker:
For model parallelism, multiple RNG states need to simultaneously exist in order
to execute operations in or out of the model parallel region. This class keeps
track of the various RNG states and provides utility methods to maintain them and
execute parts of the model under a given RNG setting. Using the :meth:`add` method, a
cuda rng state is initialized based on the input ``seed`` and is assigned to ``name``.
Later, by forking the rng state, we can perform operations and return to our starting
cuda state.
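For instance (a minimal sketch; the state name and seed below are arbitrary):

.. code-block:: python

    tracker = CudaRNGStatesTracker()
    tracker.add("model-parallel-rng", 1234)    # register a named CUDA RNG state
    with tracker.fork("model-parallel-rng"):   # run ops under that state
        mask = torch.rand(16, device="cuda")
    # outside the block, the previous CUDA RNG state is restored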
"""
@@ -812,7 +812,9 @@ class CudaRNGStatesTracker:
Set the rng states. For efficiency purposes, we do not
check the size of seed for compatibility.
Parameters
----------
states : Dict[str, torch.Tensor]
A mapping from string names to RNG states.
"""
self.states_ = states
@@ -821,9 +823,11 @@
"""
Adds a new RNG state.
Parameters
----------
name : str
string identifier for the RNG state.
seed : int
PyTorch seed for the RNG state.
"""
# Check seed is not already used.
@@ -857,7 +861,9 @@
Fork the cuda rng state, perform operations, and exit with
the original state.
Parameters
----------
name : str
string identifier for the RNG state.
"""
# Check if we have added the state
@@ -2003,7 +2009,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
Parameters
----------
fsdp_root : torch.nn.Module
FSDP-wrapped root module that may contain FSDP-wrapped TE modules.
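For example (an illustrative sketch; ``MyTransformerModel`` and the plain FSDP wrapping
below are placeholders):

.. code-block:: python

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    model = MyTransformerModel()              # contains Transformer Engine modules
    fsdp_root = FSDP(model)                   # wrap with FSDP first
    prepare_te_modules_for_fsdp(fsdp_root)    # then prepare the contained TE modules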
"""
assert isinstance(fsdp_root, FSDP), "Root module must be FSDP-wrapped."
...
@@ -28,7 +28,7 @@ def onnx_export(enabled: bool = False) -> Generator[None, None, None]:
Parameters
----------
enabled : bool, default = False
whether or not to enable export
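For example (a minimal sketch; ``model``, ``sample_input``, and the output path are
placeholders):

.. code-block:: python

    with onnx_export(enabled=True):
        torch.onnx.export(model, (sample_input,), "model.onnx")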
"""
...