"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "c2c3d540b1eca9ccbcc0fa7cb871688814a536f9"
Unverified Commit df39a7c2 authored by Paweł Gadziński's avatar Paweł Gadziński Committed by GitHub
Browse files

Docs fix (#2301)



* init
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* lines lenght
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* subtitle --- fix in many files:
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* cross entropy _input -> input rename
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* cross entropy _input -> input rename
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* a lot of small fixes
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* torch_version() change
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add missing module and fix warnings
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* removed training whitespace:
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* Update docs/api/pytorch.rst
Co-authored-by: default avatargreptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: default avatarPaweł Gadziński <62263673+pggPL@users.noreply.github.com>

* Fix import
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix more imports
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix NumPy docstring parameter spacing and indentation

- Standardize parameter documentation to use 'param : type' format (space before and after colon) per NumPy style guide
- Fix inconsistent indentation in cpu_offload.py docstring
- Modified 51 Python files across transformer_engine/pytorch
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: default avatarPaweł Gadziński <62263673+pggPL@users.noreply.github.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatargreptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent ca468ebe
...@@ -32,9 +32,9 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a ...@@ -32,9 +32,9 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import ( from ..quantize import (
Quantizer, Quantizer,
QuantizeLayout,
DelayedScaleQuantizer, DelayedScaleQuantizer,
ScalingMode, ScalingMode,
QuantizeLayout,
) )
......
...@@ -39,12 +39,12 @@ from ..quantize import ( ...@@ -39,12 +39,12 @@ from ..quantize import (
Quantizer, Quantizer,
GroupedQuantizer, GroupedQuantizer,
QuantizerSet, QuantizerSet,
QuantizeLayout,
noop_quantizer_set, noop_quantizer_set,
is_fp8_gemm_with_all_layouts_supported, is_fp8_gemm_with_all_layouts_supported,
apply_padding_to_scale_inv, apply_padding_to_scale_inv,
get_quantize_config_with_recipe, get_quantize_config_with_recipe,
get_global_quantize_recipe, get_global_quantize_recipe,
QuantizeLayout,
) )
from .misc import get_padded_spec, is_all_reduce_in_float32 from .misc import get_padded_spec, is_all_reduce_in_float32
from ..sharding import ( from ..sharding import (
......
...@@ -116,7 +116,7 @@ def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis=-1): ...@@ -116,7 +116,7 @@ def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis=-1):
transpose. Note, transpose_axis should be greater than static_axis_boundary transpose. Note, transpose_axis should be greater than static_axis_boundary
examples: examples:
X in shape (dim0, dim1, dim2, dim3, dim4) X of shape (dim0, dim1, dim2, dim3, dim4)
static_axis_boundary == -1, transpose_axis == 2 static_axis_boundary == -1, transpose_axis == 2
Xt = (dim2, dim3, dim4, dim0, dim1) Xt = (dim2, dim3, dim4, dim0, dim1)
......
...@@ -35,9 +35,9 @@ from ..sharding import ( ...@@ -35,9 +35,9 @@ from ..sharding import (
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import ( from ..quantize import (
Quantizer, Quantizer,
QuantizeLayout,
DelayedScaleQuantizer, DelayedScaleQuantizer,
ScalingMode, ScalingMode,
QuantizeLayout,
) )
......
...@@ -40,11 +40,11 @@ from ..quantize import ( ...@@ -40,11 +40,11 @@ from ..quantize import (
GroupedScaledTensor1x, GroupedScaledTensor1x,
Quantizer, Quantizer,
GroupedQuantizer, GroupedQuantizer,
QuantizeLayout,
ScalingMode, ScalingMode,
compute_scale_from_amax, compute_scale_from_amax,
NoScaleTensor, NoScaleTensor,
get_rht_matrix, get_rht_matrix,
QuantizeLayout,
) )
......
...@@ -21,12 +21,12 @@ from .quantize import ( ...@@ -21,12 +21,12 @@ from .quantize import (
ScaledTensorFactory, ScaledTensorFactory,
ScaledTensor, ScaledTensor,
ScalingMode, ScalingMode,
QuantizeLayout,
QuantizerSet, QuantizerSet,
noop_quantizer_set, noop_quantizer_set,
with_sharding_constraint_by_logical_axes, with_sharding_constraint_by_logical_axes,
is_fp8_gemm_with_all_layouts_supported, is_fp8_gemm_with_all_layouts_supported,
TensorUsage, TensorUsage,
QuantizeLayout,
) )
......
...@@ -279,26 +279,26 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods ...@@ -279,26 +279,26 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm' layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
Indicate the type of layer normalization. Indicate the type of layer normalization.
zero_centered_gamma : bool, default = False zero_centered_gamma : bool, default = False
If set to `True`, the LayerNorm formula changes to If set to ``True``, the LayerNorm formula changes to
.. math:: .. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} \cdot
(1 + \gamma) + \beta (1 + \gamma) + \beta
This parameter is only applicable for 'layernorm'. This parameter is only applicable for ``'layernorm'``.
The default of `scale_init` will also be changed. See `scale_init`. The default of ``scale_init`` will also be changed. See ``scale_init``.
scale_init : Initializer, default = None scale_init : Initializer, default = None
Used for initializing scale factors :math:`\gamma`. Used for initializing scale factors :math:`\gamma`.
If `None` is provided, scale_init is set according to the value of zero_centered_gamma. If ``None`` is provided, scale_init is set according to the value of zero_centered_gamma.
If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`. If zero_centered_gamma is set to ``True``, then scale_init is ``flax.linen.initializers.zeros``.
Otherwise, scale_init is `flax.linen.initializers.ones`. Otherwise, scale_init is ``flax.linen.initializers.ones``.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
scale_axes : Tuple[str, ...], default = ('embed', ) scale_axes : Tuple[str, ...], default = ('embed', )
The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh. The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
bias_init : Initializer, default = flax.linen.initializers.zeros bias_init : Initializer, default = flax.linen.initializers.zeros
Used for initializing shift factors :math:`\beta`, Used for initializing shift factors :math:`\beta`,
only used when :attr:`layernorm_type='layernorm'`. only used when :attr:`layernorm_type='layernorm'`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
bias_axes : Tuple[str, ...], default = ('embed', ) bias_axes : Tuple[str, ...], default = ('embed', )
The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh. The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
only used when :attr:`layernorm_type='layernorm'`. only used when :attr:`layernorm_type='layernorm'`.
...@@ -424,15 +424,15 @@ class DenseGeneral(TransformerEngineBase): ...@@ -424,15 +424,15 @@ class DenseGeneral(TransformerEngineBase):
kernel_init : Initializer, default = kernel_init : Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
Used for initializing weights. Used for initializing weights.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
kernel_axes : Tuple[str, ...], default = () kernel_axes : Tuple[str, ...], default = ()
The name of axes used to shard the weights with a corresponding mesh. The name of axes used to shard the weights with a corresponding mesh.
use_bias: bool, default = False use_bias: bool, default = False
Indicate whether to enable bias shifting. Indicate whether to enable bias shifting.
If set to False, the layer will not learn an additive bias. If set to ``False``, the layer will not learn an additive bias.
bias_init: Initializer, default = flax.linen.initializers.zeros bias_init: Initializer, default = flax.linen.initializers.zeros
Used for initializing bias, only used when :attr:`use_bias=True`. Used for initializing bias, only used when :attr:`use_bias=True`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
bias_axes: Tuple[str, ...], default = () bias_axes: Tuple[str, ...], default = ()
The name of axes used to shard bias with a corresponding mesh, The name of axes used to shard bias with a corresponding mesh,
only used when :attr:`use_bias=True`. only used when :attr:`use_bias=True`.
...@@ -443,12 +443,12 @@ class DenseGeneral(TransformerEngineBase): ...@@ -443,12 +443,12 @@ class DenseGeneral(TransformerEngineBase):
:attr:`enable_low_rank_adaptation=True` :attr:`enable_low_rank_adaptation=True`
low_rank_adaptation_alpha: float, default = None low_rank_adaptation_alpha: float, default = None
The alpha for computing the scaling factor of LoRA output. The alpha for computing the scaling factor of LoRA output.
:math:`\frac{alpha}{rank} * lora_output`. None means no scaling. :math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
axis: Union[Iterable[int], int], default = -1 axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on. An integer tuple with axes to apply the transformation on.
input_axes: Tuple[str, ...], default = None input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input, like Indicate the logical axes of sharding constraint to the input, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert ``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
sharding constraint. sharding constraint.
Optimization parameters Optimization parameters
...@@ -597,48 +597,48 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -597,48 +597,48 @@ class LayerNormDenseGeneral(TransformerEngineBase):
epsilon : float, default = 1e-6 epsilon : float, default = 1e-6
A value added to the denominator of layer normalization for numerical stability. A value added to the denominator of layer normalization for numerical stability.
zero_centered_gamma : bool, default = False zero_centered_gamma : bool, default = False
If set to `True`, the LayerNorm formula changes to If set to ``True``, the LayerNorm formula changes to
.. math:: .. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} \cdot
(1 + \gamma) + \beta (1 + \gamma) + \beta
This parameter is only applicable for 'layernorm'. This parameter is only applicable for ``'layernorm'``.
The default of `scale_init` will also be changed. See `scale_init` The default of ``scale_init`` will also be changed. See ``scale_init``
scale_init : Initializer, default = None scale_init : Initializer, default = None
Used for initializing scale factors :math:`\gamma`. Used for initializing scale factors :math:`\gamma`.
If `None` is provided, scale_init is set according to the value of zero_centered_gamma. If ``None`` is provided, scale_init is set according to the value of zero_centered_gamma.
If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`. If zero_centered_gamma is set to ``True``, then scale_init is ``flax.linen.initializers.zeros``.
Otherwise, scale_init is `flax.linen.initializers.ones`. Otherwise, scale_init is ``flax.linen.initializers.ones``.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
scale_axes : Tuple[str, ...], default = ('embed', ) scale_axes : Tuple[str, ...], default = ('embed', )
The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh, The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
only used when :attr:`enable_layernorm=True`. only used when :attr:`enable_layernorm=True`.
ln_bias_init: Initializer, default = flax.linen.initializers.zeros ln_bias_init: Initializer, default = flax.linen.initializers.zeros
Used for initializing shift factors :math:`\beta`, Used for initializing shift factors :math:`\beta`,
only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`. only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
ln_bias_axes: Tuple[str, ...], default = ('embed', ) ln_bias_axes: Tuple[str, ...], default = ('embed', )
The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh. The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`. It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
kernel_init : Initializer, default = kernel_init : Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
Used for initializing weights. Used for initializing weights.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
kernel_axes : Tuple[str, ...], default = () kernel_axes : Tuple[str, ...], default = ()
The name of axes used to shard the weights with a corresponding mesh. The name of axes used to shard the weights with a corresponding mesh.
use_bias: bool, default = False use_bias: bool, default = False
Indicate whether to enable bias shifting. Indicate whether to enable bias shifting.
If set to False, the layer will not learn an additive bias. If set to ``False``, the layer will not learn an additive bias.
bias_init: Initializer, default = flax.linen.initializers.zeros bias_init: Initializer, default = flax.linen.initializers.zeros
Used for initializing bias, only used when :attr:`use_bias=True`. Used for initializing bias, only used when :attr:`use_bias=True`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
bias_axes: Tuple[str, ...], default = () bias_axes: Tuple[str, ...], default = ()
The name of axes used to shard bias with a corresponding mesh, The name of axes used to shard bias with a corresponding mesh,
only used when :attr:`use_bias=True`. only used when :attr:`use_bias=True`.
return_layernorm_output: bool, default = False return_layernorm_output: bool, default = False
Indicate whether to return the output of layer normalization. Indicate whether to return the output of layer normalization.
If set False, return None as the second tensor in outputs. If set ``False``, return ``None`` as the second tensor in outputs.
enable_low_rank_adaptation: bool, default = False enable_low_rank_adaptation: bool, default = False
Indicate whether to enable low rank adaptation for each dense layer. Indicate whether to enable low rank adaptation for each dense layer.
low_rank_adaptation_dim: int, default = 32 low_rank_adaptation_dim: int, default = 32
...@@ -646,16 +646,16 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -646,16 +646,16 @@ class LayerNormDenseGeneral(TransformerEngineBase):
:attr:`enable_low_rank_adaptation=True` :attr:`enable_low_rank_adaptation=True`
low_rank_adaptation_alpha: float, default = None low_rank_adaptation_alpha: float, default = None
The alpha for computing the scaling factor of LoRA output. The alpha for computing the scaling factor of LoRA output.
:math:`\frac{alpha}{rank} * lora_output`. None means no scaling. :math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
axis: Union[Iterable[int], int], default = -1 axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on. An integer tuple with axes to apply the transformation on.
layernorm_input_axes: Tuple[str, ...], default = None layernorm_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input of layernorm, like Indicate the logical axes of sharding constraint to the input of layernorm, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert ``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
sharding constraint. sharding constraint.
dot_input_axes: Tuple[str, ...], default = None dot_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input of dot, like Indicate the logical axes of sharding constraint to the input of dot, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert ``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
sharding constraint. sharding constraint.
Optimization parameters Optimization parameters
...@@ -887,34 +887,34 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -887,34 +887,34 @@ class LayerNormMLP(TransformerEngineBase):
epsilon : float, default = 1e-6 epsilon : float, default = 1e-6
A value added to the denominator of layer normalization for numerical stability. A value added to the denominator of layer normalization for numerical stability.
zero_centered_gamma : bool, default = False zero_centered_gamma : bool, default = False
If set to `True`, the LayerNorm formula changes to If set to ``True``, the LayerNorm formula changes to
.. math:: .. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} \cdot
(1 + \gamma) + \beta (1 + \gamma) + \beta
This parameter is only applicable for 'layernorm'. This parameter is only applicable for ``'layernorm'``.
The default of `scale_init` will also be changed. See `scale_init`. The default of ``scale_init`` will also be changed. See ``scale_init``.
scale_init : Initializer, default = None scale_init : Initializer, default = None
Used for initializing scale factors :math:`\gamma`. Used for initializing scale factors :math:`\gamma`.
If `None` is provided, scale_init is set according to the value of zero_centered_gamma. If ``None`` is provided, scale_init is set according to the value of zero_centered_gamma.
If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`. If zero_centered_gamma is set to ``True``, then scale_init is ``flax.linen.initializers.zeros``.
Otherwise, scale_init is `flax.linen.initializers.ones`. Otherwise, scale_init is ``flax.linen.initializers.ones``.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
scale_axes : Tuple[str, ...], default = ('embed', ) scale_axes : Tuple[str, ...], default = ('embed', )
The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh, The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
only used when :attr:`enable_layernorm=True`. only used when :attr:`enable_layernorm=True`.
ln_bias_init: Initializer, default = flax.linen.initializers.zeros ln_bias_init: Initializer, default = flax.linen.initializers.zeros
Used for initializing shift factors :math:`\beta`, Used for initializing shift factors :math:`\beta`,
only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`. only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
ln_bias_axes: Tuple[str, ...], default = ('embed', ) ln_bias_axes: Tuple[str, ...], default = ('embed', )
The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh. The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`. Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
kernel_init : Initializer, default = kernel_init : Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
Used for initializing the weights of both dense layer transformations. Used for initializing the weights of both dense layer transformations.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp') kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
The name of axes used to shard the weights with a corresponding mesh for The name of axes used to shard the weights with a corresponding mesh for
the weight of the first dense layer transformation. the weight of the first dense layer transformation.
...@@ -923,10 +923,10 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -923,10 +923,10 @@ class LayerNormMLP(TransformerEngineBase):
the weight of the second dense layer transformation. the weight of the second dense layer transformation.
use_bias: bool, default = False use_bias: bool, default = False
Indicate whether to enable bias shifting. Indicate whether to enable bias shifting.
If set to False, the layer will not learn an additive bias. If set to ``False``, the layer will not learn an additive bias.
bias_init: Initializer, default = flax.linen.initializers.zeros bias_init: Initializer, default = flax.linen.initializers.zeros
Used for initializing bias, only used when :attr:`use_bias=True`. Used for initializing bias, only used when :attr:`use_bias=True`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
bias_axes_1: Tuple[str, ...], default = ('mlp',) bias_axes_1: Tuple[str, ...], default = ('mlp',)
The name of axes used to shard bias with a corresponding mesh for The name of axes used to shard bias with a corresponding mesh for
the weight of the first dense layer transformation. the weight of the first dense layer transformation.
...@@ -937,7 +937,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -937,7 +937,7 @@ class LayerNormMLP(TransformerEngineBase):
Only used when :attr:`use_bias=True`. Only used when :attr:`use_bias=True`.
return_layernorm_output: bool, default = False return_layernorm_output: bool, default = False
Indicate whether to return the output of layer normalization. Indicate whether to return the output of layer normalization.
If set False, return None as the second tensor in outputs. If set ``False``, return ``None`` as the second tensor in outputs.
activations: Sequence[Union[str, Callable]], default = ('gelu',) activations: Sequence[Union[str, Callable]], default = ('gelu',)
The sequence of activation functions to apply after the first dense layer transformation. The sequence of activation functions to apply after the first dense layer transformation.
Each activation has its own transformation layer. Each activation has its own transformation layer.
...@@ -958,20 +958,20 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -958,20 +958,20 @@ class LayerNormMLP(TransformerEngineBase):
:attr:`enable_low_rank_adaptation=True`. :attr:`enable_low_rank_adaptation=True`.
low_rank_adaptation_alpha: float, default = None low_rank_adaptation_alpha: float, default = None
The alpha for computing the scaling factor of LoRA output. The alpha for computing the scaling factor of LoRA output.
:math:`\frac{alpha}{rank} * lora_output`. None means no scaling. :math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling.
axis: Union[Iterable[int], int], default = -1 axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on. An integer tuple with axes to apply the transformation on.
layernorm_input_axes: Tuple[str, ...], default = None layernorm_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input of layernorm, like Indicate the logical axes of sharding constraint to the input of layernorm, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert ``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
sharding constraint. sharding constraint.
dot_1_input_axes: Tuple[str, ...], default = None dot_1_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input of 1st dot, like Indicate the logical axes of sharding constraint to the input of 1st dot, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert ``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
sharding constraint. sharding constraint.
dot_2_input_axes: Tuple[str, ...], default = None dot_2_input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input of 2nd dot, like Indicate the logical axes of sharding constraint to the input of 2nd dot, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert ``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert
sharding constraint. sharding constraint.
ffn1_ckpt_name: str = "ffn1" ffn1_ckpt_name: str = "ffn1"
Checkpoint name for the output of the first fully-connected layer in the MLP block. Checkpoint name for the output of the first fully-connected layer in the MLP block.
......
This diff is collapsed.
...@@ -7,22 +7,14 @@ ...@@ -7,22 +7,14 @@
# pylint: disable=wrong-import-position # pylint: disable=wrong-import-position
import functools import functools
from packaging.version import Version as PkgVersion
import torch import torch
from transformer_engine.common import load_framework_extension from transformer_engine.common import load_framework_extension
from transformer_engine.pytorch.torch_version import torch_version
@functools.lru_cache(maxsize=None)
def torch_version() -> tuple[int, ...]:
"""Get PyTorch version"""
return PkgVersion(str(torch.__version__)).release
assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}." assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}."
load_framework_extension("torch") load_framework_extension("torch")
from transformer_engine.pytorch.module import LayerNormLinear from transformer_engine.pytorch.module import LayerNormLinear
from transformer_engine.pytorch.module import Linear from transformer_engine.pytorch.module import Linear
......
...@@ -175,65 +175,65 @@ class AttentionParams: ...@@ -175,65 +175,65 @@ class AttentionParams:
Parameters Parameters
---------- ----------
qkv_type: Union[torch.Tensor, Float8Tensor], default = `torch.Tensor` qkv_type : Union[torch.Tensor, Float8Tensor], default = torch.Tensor
Type of query/key/value tensors, {`torch.Tensor`, `Float8Tensor`}. Type of query/key/value tensors, {`torch.Tensor`, `Float8Tensor`}.
qkv_dtype: torch.dtype, default = `torch.bfloat16` qkv_dtype : torch.dtype, default = torch.bfloat16
Data type of query/key/value tensors. Data type of query/key/value tensors.
qkv_layout: str, default = "sbh3d" qkv_layout : str, default = "sbh3d"
Query/key/value tensor memory layout. Query/key/value tensor memory layout.
batch_size: int, default = 1 batch_size : int, default = 1
Batch size. Batch size.
num_heads: int, default = 16 num_heads : int, default = 16
Number of attention heads in the query tensor. Number of attention heads in the query tensor.
num_gqa_groups: int, default = 16 num_gqa_groups : int, default = 16
Number of attention heads in key and value tensors. Number of attention heads in key and value tensors.
max_seqlen_q: int, default = 128 max_seqlen_q : int, default = 128
Maximum sequence length of the query tensor. Maximum sequence length of the query tensor.
max_seqlen_kv: int, default = 128 max_seqlen_kv : int, default = 128
Maximum sequence length of the key and value tensors. Maximum sequence length of the key and value tensors.
head_dim_qk: int, default = 64 head_dim_qk : int, default = 64
The size of each attention head in query and key tensors. The size of each attention head in query and key tensors.
head_dim_v: int, default = 64 head_dim_v : int, default = 64
The size of each attention head in the value tensor. The size of each attention head in the value tensor.
attn_mask_type: str, default = `no_mask` attn_mask_type : str, default = no_mask
Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`, Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`,
`causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`} `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}
window_size: Tuple[int, int], default = None window_size : Tuple[int, int], default = None
Sliding window attention size. Sliding window attention size.
alibi_slopes_shape: Optional[Union[torch.Size, List]], default = `None` alibi_slopes_shape : Optional[Union[torch.Size, List]], default = None
Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`. Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`.
core_attention_bias_type: str, default = `no_bias` core_attention_bias_type : str, default = no_bias
Attention bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}. Attention bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}.
core_attention_bias_shape: str, default = `1hss` core_attention_bias_shape : str, default = 1hss
Attention bias shape, {`1hss`, `b1ss`, `bhss`}. Attention bias shape, {`1hss`, `b1ss`, `bhss`}.
core_attention_bias_requires_grad: bool, default = `True` core_attention_bias_requires_grad : bool, default = True
Whether attention bias requires gradient. Whether attention bias requires gradient.
pad_between_seqs: bool, default = `False` pad_between_seqs : bool, default = False
Whether there is padding between sequences in a batch. Whether there is padding between sequences in a batch.
This only applies to `qkv_format=thd`. This only applies to `qkv_format=thd`.
attention_dropout: float, default = 0.0 attention_dropout : float, default = 0.0
Attention dropout. Attention dropout.
context_parallel: bool, default = `False` context_parallel : bool, default = False
Whether context parallelism is used or not. Whether context parallelism is used or not.
cp_comm_type: str, default = "p2p" cp_comm_type : str, default = "p2p"
The communication type of context parallelism. The communication type of context parallelism.
deterministic: bool, default = `False` deterministic : bool, default = False
Whether to run `DotProductAttention` with determinism or not. Whether to run `DotProductAttention` with determinism or not.
is_training: bool, default = `True` is_training : bool, default = True
Whether in training mode (`True`) or inference mode (`False`) Whether in training mode (`True`) or inference mode (`False`)
fp8: bool, default = `False` fp8 : bool, default = False
Whether `DotProductAttention` is in an `autocast` region. Whether `DotProductAttention` is in an `autocast` region.
fp8_meta: Optional[Dict[str Any]], default = `None` fp8_meta : Optional[Dict[str Any]], default = None
The FP8 metadata tensor of `DotProductAttention`. The FP8 metadata tensor of `DotProductAttention`.
inference_params: Optional[InferenceParams], default = `None` inference_params : Optional[InferenceParams], default = None
Inference-related parameters. See InferenceParams for details. Inference-related parameters. See InferenceParams for details.
softmax_type: str, default = "vanilla" softmax_type : str, default = "vanilla"
The type of softmax operation. See DotProductAttention for details. The type of softmax operation. See DotProductAttention for details.
return_max_logit: bool, default = `False` return_max_logit : bool, default = False
Whether to output max_logit. Whether to output max_logit.
cuda_graph: bool, default = `False` cuda_graph : bool, default = `False`
Whether support for cuda graph capture is needed or not. Whether support for cuda graph capture is needed or not.
num_splits: int, default = 1 num_splits : int, default = 1
The number of kernels to split attention to. The number of kernels to split attention to.
""" """
...@@ -298,15 +298,15 @@ def get_attention_backend( ...@@ -298,15 +298,15 @@ def get_attention_backend(
Returns Returns
---------- ----------
use_flash_attention: bool use_flash_attention : bool
Whether the `FlashAttention` backend has been selected. Whether the `FlashAttention` backend has been selected.
use_fused_attention: bool use_fused_attention : bool
Whether the `FusedAttention` backend has been selected. Whether the `FusedAttention` backend has been selected.
fused_attention_backend: tex.NVTE_Fused_Attn_Backend fused_attention_backend : tex.NVTE_Fused_Attn_Backend
If `use_fused_attention = True`, one of `FusedAttention` three sub-backends, else `None`. If `use_fused_attention = True`, one of `FusedAttention` three sub-backends, else `None`.
use_unfused_attention: bool use_unfused_attention : bool
Whether the `UnfusedDotProductAttention` backend has been selected. Whether the `UnfusedDotProductAttention` backend has been selected.
available_backends: List[bool] available_backends : List[bool]
All available backends that could support the provided input. A list of Booleans All available backends that could support the provided input. A list of Booleans
in the form of [use_flash_attention, use_fused_attention, use_unfused_attention]. in the form of [use_flash_attention, use_fused_attention, use_unfused_attention].
""" """
...@@ -835,8 +835,8 @@ def get_attention_backend( ...@@ -835,8 +835,8 @@ def get_attention_backend(
# ---------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------
# no_mask | None | All # no_mask | None | All
# padding | | All # padding | | All
# self-attention | One tensor in shape [b, 1, 1, sq] | # self-attention | One tensor of shape [b, 1, 1, sq] |
# cross-attention | Tuple of two tensors in shapes | # cross-attention | Tuple of two tensors of shapes |
# | [b, 1, 1, sq] and [b, 1, 1, skv] | # | [b, 1, 1, sq] and [b, 1, 1, skv] |
# causal | None | # causal | None |
# self-attention | | All # self-attention | | All
...@@ -846,7 +846,7 @@ def get_attention_backend( ...@@ -846,7 +846,7 @@ def get_attention_backend(
# cross-attention | | FusedAttention, UnfusedDotProductAttention # cross-attention | | FusedAttention, UnfusedDotProductAttention
# causal_bottom_right | None | All # causal_bottom_right | None | All
# padding_causal_bottom_right | Same as "padding" | All # padding_causal_bottom_right | Same as "padding" | All
# arbitrary | One tensor in shape broadcastable to | UnfusedDotProductAttention # arbitrary | One tensor of shape broadcastable to | UnfusedDotProductAttention
# | [b, h, sq, skv] | # | [b, h, sq, skv] |
if attn_mask_type == "arbitrary": if attn_mask_type == "arbitrary":
if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or ( if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or (
...@@ -1271,42 +1271,42 @@ def get_full_mask( ...@@ -1271,42 +1271,42 @@ def get_full_mask(
Parameters Parameters
---------- ----------
max_seqlen_q: int max_seqlen_q : int
Maximum sequence length for queries. Maximum sequence length for queries.
max_seqlen_kv: int max_seqlen_kv : int
Maximum sequence length for keys and values. Maximum sequence length for keys and values.
attn_mask_type: str, default = `no_mask` attn_mask_type : str, default = no_mask
Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`", Attention mask type, {``"no_mask"``, ``"padding"``, ``"causal"``, ``"padding_causal"``,
"`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"} ``"causal_bottom_right"``, ``"padding_causal_bottom_right"``, ``"arbitrary"``}
attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], attention_mask : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
default = `None` default = None
Boolean tensor(s) used to mask out attention softmax input. Please see DotProductAttention Boolean tensor(s) used to mask out attention softmax input. Please see DotProductAttention
for the requirements of `attention_mask` for different `attn_mask_type`s. for the requirements of `attention_mask` for different `attn_mask_type`s.
window_size: Tuple[int, int], default = `None` window_size : Tuple[int, int], default = None
Sliding window size for local attention, where query at position i attends to keys Sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically. Both `causal` and `causal_bottom_right` masks window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
`attn_mask_type`. `attn_mask_type`.
attention_type: str, default = "self" attention_type : str, default = "self"
Attention type, {"self", "cross"} Attention type, {"self", "cross"}
bottom_right_alignment: bool, default = `True` bottom_right_alignment : bool, default = True
Whether to align the diagonal of the sliding window attention to the bottom right (`True`) Whether to align the diagonal of the sliding window attention to the bottom right (`True`)
or top left (`False`) corner of the softmax matrix. Ignored if `attn_mask_type` explicitly or top left (`False`) corner of the softmax matrix. Ignored if `attn_mask_type` explicitly
specifies "causal" or "causal_bottom_right". specifies "causal" or "causal_bottom_right".
Returns Returns
---------- ----------
attn_mask_type: str attn_mask_type : str
For sliding window attention (>=0, >0), "arbitrary"; otherwise, the same as input `attn_mask_type` For sliding window attention (>=0, >0), "arbitrary"; otherwise, the same as input `attn_mask_type`
attention_mask: torch.Tensor attention_mask : torch.Tensor
The full attention mask based on `attn_mask_type`, `attention_mask` and `window_size` The full attention mask based on `attn_mask_type`, `attention_mask` and `window_size`
actual_seqlens_q: torch.Tensor actual_seqlens_q : torch.Tensor
For padding masks, the actual sequence lengths for queries, in shape [batch_size]. For padding masks, the actual sequence lengths for queries, of shape [batch_size].
For other masks, `None`. For other masks, `None`.
actual_seqlens_kv: Optional[torch.Tensor], default = `None` actual_seqlens_kv : Optional[torch.Tensor], default = None
For padding masks, the actual sequence lengths for keys and values, in shape [batch_size]. For padding masks, the actual sequence lengths for keys and values, of shape [batch_size].
For other masks, `None`. For other masks, `None`.
""" """
# perform basic checks # perform basic checks
...@@ -1392,29 +1392,29 @@ def get_alibi( ...@@ -1392,29 +1392,29 @@ def get_alibi(
""" """
Parameters Parameters
---------- ----------
num_heads: int num_heads : int
Number of heads. Number of heads.
max_seqlen_q: int max_seqlen_q : int
Maximum sequence length for queries. Maximum sequence length for queries.
max_seqlen_kv: int max_seqlen_kv : int
Maximum sequence length for keys and values. Maximum sequence length for keys and values.
actual_seqlens_q: Optional[torch.Tensor], default = `None` actual_seqlens_q : Optional[torch.Tensor], default = None
Actual sequence lengths for queries, in shape [batch_size]. Actual sequence lengths for queries, of shape [batch_size].
actual_seqlens_kv: Optional[torch.Tensor], default = `None` actual_seqlens_kv : Optional[torch.Tensor], default = None
Actual sequence lengths for keys and values, in shape [batch_size]. Actual sequence lengths for keys and values, of shape [batch_size].
alibi_slopes: Optional[torch.Tensor], default = `None` alibi_slopes : Optional[torch.Tensor], default = None
Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads]. Custom ALiBi slopes, FP32, CUDA tensor, of shape [num_heads] or [batch_size, num_heads].
bias_dtype: Optional[torch.dtype], default = `None` bias_dtype : Optional[torch.dtype], default = None
Dtype of the generated ALiBi bias. If None, use torch.float32. Dtype of the generated ALiBi bias. If None, use torch.float32.
bottom_right_alignment: bool, default = `True` bottom_right_alignment : bool, default = True
Whether to align the diagonal of the ALiBi bias to the bottom right corner of Whether to align the diagonal of the ALiBi bias to the bottom right corner of
the matrix (`True`) or top left (`False`). the matrix (`True`) or top left (`False`).
Returns Returns
---------- ----------
alibi_slopes: torch.Tensor alibi_slopes : torch.Tensor
ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads]. ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads].
alibi_bias: torch.Tensor alibi_bias : torch.Tensor
ALiBi bias in FP32 or `bias_dtype`. Its shape is ALiBi bias in FP32 or `bias_dtype`. Its shape is
(1) [1, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in [num_heads] shape, (1) [1, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in [num_heads] shape,
and `actual_seqlens_q` and `actual_seqlens_kv` are `None`; or and `actual_seqlens_q` and `actual_seqlens_kv` are `None`; or
...@@ -1818,18 +1818,18 @@ def get_qkv_format( ...@@ -1818,18 +1818,18 @@ def get_qkv_format(
Parameters Parameters
---------- ----------
qkv_layout: str qkv_layout : str
Memory layout of `q`, `k` and `v`. See get_qkv_layout() for more details. Memory layout of `q`, `k` and `v`. See get_qkv_layout() for more details.
inference_params: InferenceParams, default = `None` inference_params : InferenceParams, default = None
InferenceParams related to KV caching. InferenceParams related to KV caching.
Returns Returns
---------- ----------
qkv_format: str, default = `sbhd` qkv_format : str, default = sbhd
Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}.
q_format: str q_format : str
Format of the `q` tensor, {`bshd`, `sbhd`, `thd`}. Format of the `q` tensor, {`bshd`, `sbhd`, `thd`}.
kv_format: str kv_format : str
Format of the `k` and `v` tensors, {`bshd`, `sbhd`, `thd`}. Format of the `k` and `v` tensors, {`bshd`, `sbhd`, `thd`}.
""" """
splited = qkv_layout.replace("paged_kv_", "").split("_") splited = qkv_layout.replace("paged_kv_", "").split("_")
...@@ -1855,23 +1855,23 @@ def get_qkv_layout( ...@@ -1855,23 +1855,23 @@ def get_qkv_layout(
Parameters Parameters
---------- ----------
q: torch.Tensor q : torch.Tensor
Query tensor. Query tensor.
k: torch.Tensor k : torch.Tensor
Key tensor. Key tensor.
v: torch.Tensor v : torch.Tensor
Value tensor. Value tensor.
qkv_format: str, default = `sbhd` qkv_format : str, default = sbhd
Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for
the sequence length dimension, `b` batch size, `h` the number of attention heads, the sequence length dimension, `b` batch size, `h` the number of attention heads,
`d` head size, and `t` the total number of tokens in a batch, i.e. `d` head size, and `t` the total number of tokens in a batch, i.e.
`t = sum(s_i) for i = 0...b-1`. `t = sum(s_i) for i = 0...b-1`.
inference_params: InferenceParams, default = `None` inference_params : InferenceParams, default = None
InferenceParams related to KV caching. InferenceParams related to KV caching.
Returns Returns
---------- ----------
qkv_layout: str qkv_layout : str
Memory layout of `q`, `k` and `v`. Each `qkv_layout` maps to a pair of `q_format` and Memory layout of `q`, `k` and `v`. Each `qkv_layout` maps to a pair of `q_format` and
`kv_format` in {`bshd`, `sbhd`, `thd`}. The `paged_kv_` prefix is used to indicate that `kv_format` in {`bshd`, `sbhd`, `thd`}. The `paged_kv_` prefix is used to indicate that
paged KV caching is in play. A few examples of the layouts are as follows. paged KV caching is in play. A few examples of the layouts are as follows.
...@@ -1893,18 +1893,18 @@ def get_qkv_layout( ...@@ -1893,18 +1893,18 @@ def get_qkv_layout(
`thd_2bshd`: {`thd_bshd_bshd`, `paged_kv_thd_bshd_bshd`} `thd_2bshd`: {`thd_bshd_bshd`, `paged_kv_thd_bshd_bshd`}
`thd_2sbhd`: {`thd_sbhd_sbhd`, `paged_kv_thd_sbhd_sbhd`} `thd_2sbhd`: {`thd_sbhd_sbhd`, `paged_kv_thd_sbhd_sbhd`}
q: torch.Tensor q : torch.Tensor
Query tensor. It may be different from input `q` as we try to fit tensors to Query tensor. It may be different from input `q` as we try to fit tensors to
a supported layout. a supported layout.
k: torch.Tensor k : torch.Tensor
Key tensor. It may be different from input `k` as we try to fit tensors to Key tensor. It may be different from input `k` as we try to fit tensors to
a supported layout. a supported layout.
v: torch.Tensor v : torch.Tensor
Value tensor. It may be different from input `v` as we try to fit tensors to Value tensor. It may be different from input `v` as we try to fit tensors to
a supported layout. a supported layout.
q_format: str q_format : str
Format of the query tensor, {`bshd`, `sbhd`, `thd`}. Format of the query tensor, {`bshd`, `sbhd`, `thd`}.
kv_format: str kv_format : str
Format of the key and value tensors, {`bshd`, `sbhd`, `thd`}. Format of the key and value tensors, {`bshd`, `sbhd`, `thd`}.
""" """
......
...@@ -98,29 +98,29 @@ class InferenceParams: ...@@ -98,29 +98,29 @@ class InferenceParams:
Parameters Parameters
---------- ----------
max_batch_size: int max_batch_size : int
Maximum batch size in inference Maximum batch size in inference
max_sequence_length: int max_sequence_length : int
Maximum sequence length in inference Maximum sequence length in inference
num_heads_kv: int num_heads_kv : int
Number of attention heads in keys and values Number of attention heads in keys and values
head_dim_k: int head_dim_k : int
Head size for keys Head size for keys
dtype: torch.dtype dtype : torch.dtype
Data type of the KV cache Data type of the KV cache
head_dim_v: int, default = None head_dim_v : int, default = None
Head size for values. If None, initialized as head_dim_k. Head size for values. If None, initialized as head_dim_k.
is_paged: bool, default = False is_paged : bool, default = False
Whether the KV cache is paged (True) or non-paged (False) Whether the KV cache is paged (True) or non-paged (False)
total_num_pages: int, default = None total_num_pages : int, default = None
Total number of pages in the KV cache. Required for is_paged = True. Total number of pages in the KV cache. Required for is_paged = True.
page_size: int, default = None page_size : int, default = None
Page size of the KV cache. Required for is_paged = True. Page size of the KV cache. Required for is_paged = True.
max_ctx_len: int, default = None max_ctx_len : int, default = None
Maximum context length in inference. 1 <= max_ctx_len <= max_sequence_length. Maximum context length in inference. 1 <= max_ctx_len <= max_sequence_length.
qkv_format: str, default = "bshd" qkv_format : str, default = "bshd"
Format of the incoming query/key/value tensors in current iteration Format of the incoming query/key/value tensors in current iteration
custom_cache_manager: KVCacheManager, default = None custom_cache_manager : KVCacheManager, default = None
Custom cache manager, with KVCacheManager as the base class. Custom cache manager, with KVCacheManager as the base class.
""" """
...@@ -525,9 +525,9 @@ class NonPagedKVCacheManager(KVCacheManager): ...@@ -525,9 +525,9 @@ class NonPagedKVCacheManager(KVCacheManager):
new_v: torch.Tensor new_v: torch.Tensor
New value tokens for layer_number in current inference iteration New value tokens for layer_number in current inference iteration
cu_new_seqlens: torch.Tensor cu_new_seqlens: torch.Tensor
Cumulative sequence lengths for new_k and new_v, in shape [batch_size + 1] Cumulative sequence lengths for new_k and new_v, of shape [batch_size + 1]
cu_cached_seqlens: torch.Tensor cu_cached_seqlens: torch.Tensor
Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), in shape [batch_size + 1] Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), of shape [batch_size + 1]
qkv_format: str qkv_format: str
Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'} Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'}
...@@ -701,7 +701,7 @@ class PagedKVCacheManager(KVCacheManager): ...@@ -701,7 +701,7 @@ class PagedKVCacheManager(KVCacheManager):
return [x.page_id for x in self.allocated_pages[seq]] return [x.page_id for x in self.allocated_pages[seq]]
def get_page_table(self, sequences: List[int]): def get_page_table(self, sequences: List[int]):
"""Get the page table, in shape [batch_size, max_pages_per_seq]""" """Get the page table, of shape [batch_size, max_pages_per_seq]"""
page_table = torch.Tensor( page_table = torch.Tensor(
[ [
self.get_page_list(seq) + [0] * (self.max_pages_per_seq - self.get_page_count(seq)) self.get_page_list(seq) + [0] * (self.max_pages_per_seq - self.get_page_count(seq))
...@@ -783,9 +783,9 @@ class PagedKVCacheManager(KVCacheManager): ...@@ -783,9 +783,9 @@ class PagedKVCacheManager(KVCacheManager):
new_v: torch.Tensor new_v: torch.Tensor
New value tokens for layer_number in current inference iteration New value tokens for layer_number in current inference iteration
cu_new_seqlens: torch.Tensor cu_new_seqlens: torch.Tensor
Cumulative sequence lengths for new_k and new_v, in shape [batch_size + 1] Cumulative sequence lengths for new_k and new_v, of shape [batch_size + 1]
cu_cached_seqlens: torch.Tensor cu_cached_seqlens: torch.Tensor
Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), in shape [batch_size + 1] Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), of shape [batch_size + 1]
qkv_format: str qkv_format: str
Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'} Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'}
......
...@@ -287,16 +287,16 @@ def _apply_rotary_pos_emb_base( ...@@ -287,16 +287,16 @@ def _apply_rotary_pos_emb_base(
Parameters Parameters
---------- ----------
t: torch.Tensor t : torch.Tensor
Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which rotary positional Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which rotary positional
embedding will be applied. embedding will be applied.
freqs: torch.Tensor freqs : torch.Tensor
Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` or `[s2, b, 1, d2]` Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` or `[s2, b, 1, d2]`
and dtype 'float', with `s2 >= s` and `d2 <= d`. and dtype 'float', with `s2 >= s` and `d2 <= d`.
tensor_format: {'sbhd', 'bshd'}, default = 'sbhd' tensor_format : {'sbhd', 'bshd'}, default = 'sbhd'
Should be `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape Should be `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape
`[seq, bs, ...]`. `[seq, bs, ...]`.
interleaved: bool, default = False interleaved : bool, default = False
Whether to use interleaved rotary position embedding. Whether to use interleaved rotary position embedding.
""" """
# [seq, 1, 1, dim] -> [1, seq, 1, dim] or # [seq, 1, 1, dim] -> [1, seq, 1, dim] or
...@@ -324,7 +324,7 @@ def _get_freqs_on_this_cp_rank( ...@@ -324,7 +324,7 @@ def _get_freqs_on_this_cp_rank(
"""Get the position embedding on the current context parallel rank. """Get the position embedding on the current context parallel rank.
Args: Args:
freqs: torch.Tensor. Positional embedding tensor in shape `[s2, 1, 1, d2]`. freqs: torch.Tensor. Positional embedding tensor of shape `[s2, 1, 1, d2]`.
seqlen: int. Length of the current sequence. seqlen: int. Length of the current sequence.
cp_size: int. Context parallel world size. cp_size: int. Context parallel world size.
cp_rank: int. Context parallel rank. cp_rank: int. Context parallel rank.
...@@ -372,29 +372,29 @@ def apply_rotary_pos_emb( ...@@ -372,29 +372,29 @@ def apply_rotary_pos_emb(
Parameters Parameters
---------- ----------
t: torch.Tensor t : torch.Tensor
Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
rotary positional embedding will be applied. rotary positional embedding will be applied.
freqs: torch.Tensor freqs : torch.Tensor
Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float', Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
with `s2 >= s` and `d2 <= d`. with `s2 >= s` and `d2 <= d`.
start_positions: torch.Tensor, default = None. start_positions : torch.Tensor, default = None.
Tokens in a sequence `i` should be applied with position encoding offset by Tokens in a sequence `i` should be applied with position encoding offset by
`start_positions[i]`. If `start_positions=None`, there's no offset. `start_positions[i]`. If `start_positions=None`, there's no offset.
tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd' tensor_format : {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True. of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
interleaved: bool, default = False interleaved : bool, default = False
Whether to use interleaved rotary position embedding. Whether to use interleaved rotary position embedding.
fused: bool, default = False fused : bool, default = False
Whether to use a fused applying RoPE implementation. Whether to use a fused applying RoPE implementation.
cu_seqlens: torch.Tensor, default = None. cu_seqlens : torch.Tensor, default = None.
Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
dtype torch.int32. Only valid when `tensor_format` is 'thd'. dtype torch.int32. Only valid when `tensor_format` is 'thd'.
Should be `cu_seqlens_padded` when cp_size > 1. Should be `cu_seqlens_padded` when cp_size > 1.
cp_size: int, default = 1. cp_size : int, default = 1.
Context parallel world size. Only valid when `tensor_format` is 'thd' and `fused` is True. Context parallel world size. Only valid when `tensor_format` is 'thd' and `fused` is True.
cp_rank: int, default = 0. cp_rank : int, default = 0.
Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True. Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True.
""" """
assert ( assert (
...@@ -492,32 +492,32 @@ def apply_fused_qkv_rotary_pos_emb( ...@@ -492,32 +492,32 @@ def apply_fused_qkv_rotary_pos_emb(
Parameters Parameters
---------- ----------
qkv: torch.Tensor qkv : torch.Tensor
Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which
rotary positional embedding will be applied. This tensor has q, k, v concatenated rotary positional embedding will be applied. This tensor has q, k, v concatenated
along the last dimension. along the last dimension.
q_freqs: torch.Tensor q_freqs : torch.Tensor
Rotary positional embedding Q tensor of shape `[s2, 1, 1, d2]` and dtype 'float', Rotary positional embedding Q tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
with `s2 >= s` and `d2 <= d`. with `s2 >= s` and `d2 <= d`.
k_freqs: torch.Tensor k_freqs : torch.Tensor
Rotary positional embedding K tensor of shape `[s2, 1, 1, d2]` and dtype 'float', Rotary positional embedding K tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
with `s2 >= s` and `d2 <= d`. with `s2 >= s` and `d2 <= d`.
qkv_split_arg_list: List[int] qkv_split_arg_list : List[int]
List of integers that specify the split of the qkv tensor. The list should have 3 elements, List of integers that specify the split of the qkv tensor. The list should have 3 elements,
the first element is the number of elements in the q tensor, the second element is the number the first element is the number of elements in the q tensor, the second element is the number
of elements in the k tensor, and the third element is the number of elements in the v tensor. of elements in the k tensor, and the third element is the number of elements in the v tensor.
The sum of the elements in the list should be equal to the last dimension of the qkv tensor. The sum of the elements in the list should be equal to the last dimension of the qkv tensor.
start_positions: torch.Tensor, default = None. start_positions : torch.Tensor, default = None.
Tokens in a sequence `i` should be applied with position encoding offset by Tokens in a sequence `i` should be applied with position encoding offset by
`start_positions[i]`. If `start_positions=None`, there's no offset. `start_positions[i]`. If `start_positions=None`, there's no offset.
tensor_format: {'sbhd', 'bshd'}, default = 'sbhd' tensor_format : {'sbhd', 'bshd'}, default = 'sbhd'
is `bshd` if `qkv` is of shape `[bs, seq, ...]`, or `sbhd` if `qkv` is is `bshd` if `qkv` is of shape `[bs, seq, ...]`, or `sbhd` if `qkv` is
of shape `[seq, bs, ...]`. of shape `[seq, bs, ...]`.
interleaved: bool, default = False interleaved : bool, default = False
Whether to use interleaved rotary position embedding. Whether to use interleaved rotary position embedding.
cp_size: int, default = 1. cp_size : int, default = 1.
Context parallel world size. Context parallel world size.
cp_rank: int, default = 0. cp_rank : int, default = 0.
Context parallel rank. Context parallel rank.
""" """
......
...@@ -146,89 +146,89 @@ def fused_attn_fwd( ...@@ -146,89 +146,89 @@ def fused_attn_fwd(
Parameters Parameters
---------- ----------
is_training: bool is_training : bool
if True, runs training and produces auxiliary tensors aux_ctx_tensors if True, runs training and produces auxiliary tensors aux_ctx_tensors
for the backward; if False, runs inference and doesn't produce aux_ctx_tensors for the backward; if False, runs inference and doesn't produce aux_ctx_tensors
max_seqlen_q: int max_seqlen_q : int
max sequence length for Q, used for padding; max sequence length for Q, used for padding;
may be larger than max(seqlens_q), may be larger than max(seqlens_q),
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
max_seqlen_kv: int max_seqlen_kv : int
max sequence length for K and V, used for padding; max sequence length for K and V, used for padding;
may be larger than max(seqlens_kv), may be larger than max(seqlens_kv),
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
cu_seqlens_q: torch.Tensor cu_seqlens_q : torch.Tensor
cumulative sequence lengths for Q; shape [batch_size + 1] cumulative sequence lengths for Q; shape [batch_size + 1]
cu_seqlens_kv: torch.Tensor cu_seqlens_kv : torch.Tensor
cumulative sequence lengths for K and V; shape [batch_size + 1] cumulative sequence lengths for K and V; shape [batch_size + 1]
q: torch.Tensor q : torch.Tensor
input tensor Q; shape sbhd, bshd or thd (see `qkv_layout` for details) input tensor Q; shape sbhd, bshd or thd (see `qkv_layout` for details)
k: torch.Tensor k : torch.Tensor
input tensor K; shape sbhd, bshd or thd (see `qkv_layout` for details) input tensor K; shape sbhd, bshd or thd (see `qkv_layout` for details)
v: torch.Tensor v : torch.Tensor
input tensor V; shape sbhd, bshd or thd (see `qkv_layout` for details) input tensor V; shape sbhd, bshd or thd (see `qkv_layout` for details)
fake_dtype: tex.DType fake_dtype : tex.DType
data type of Q, K and V - in case of high precision, fake dtype in case of FP8; data type of Q, K and V - in case of high precision, fake dtype in case of FP8;
in torch.dtype in torch.dtype
fused_attention_backend: tex.NVTE_Fused_Attn_Backend fused_attention_backend : tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends. please see FusedAttention module for details on supported backends.
attn_bias: torch.Tensor, default = None attn_bias : torch.Tensor, default = None
input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q, k and v shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q, k and v
cu_seqlens_q_padded: torch.Tensor, default = None cu_seqlens_q_padded : torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1] cumulative sequence offsets for Q; shape [batch_size + 1]
cu_seqlens_kv_padded: torch.Tensor, default = None cu_seqlens_kv_padded : torch.Tensor, default = None
cumulative sequence offsets for KV; shape [batch_size + 1] cumulative sequence offsets for KV; shape [batch_size + 1]
page_table_k: torch.Tensor, default = None page_table_k : torch.Tensor, default = None
page table for K cache; shape [batch_size, max_pages_per_seq_k] page table for K cache; shape [batch_size, max_pages_per_seq_k]
page_table_v: torch.Tensor, default = None page_table_v : torch.Tensor, default = None
page table for V cache; shape [batch_size, max_pages_per_seq_v] page table for V cache; shape [batch_size, max_pages_per_seq_v]
s_quantizer: Quantizer, default = None s_quantizer : Quantizer, default = None
Quantizer object for the intermediate value S. Quantizer object for the intermediate value S.
o_quantizer: Quantizer, default = None o_quantizer : Quantizer, default = None
Quantizer object for the output of the attention. Quantizer object for the output of the attention.
attn_scale: float, default = None attn_scale : float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM; if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim_qk) as the default if None, use 1.0/sqrt(head_dim_qk) as the default
dropout: float, default = 0.0 dropout : float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output; dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False dropout must be 0.0 if is_training is False
fast_zero_fill: bool, default = True fast_zero_fill : bool, default = True
if True, initializes the output tensor O to zero using the fast filling method; if True, initializes the output tensor O to zero using the fast filling method;
if False, uses PyTorch's .fill_() method if False, uses PyTorch's .fill_() method
qkv_layout: str, default = "sbh3d" qkv_layout : str, default = "sbh3d"
layout of Q, K and V; layout of Q, K and V;
{"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd", {"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd",
"bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd", "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd",
"t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"} "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"}
attn_bias_type: str, default = "no_bias" attn_bias_type : str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding" attn_mask_type : str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"} type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
softmax_type: str, default = "vanilla" softmax_type : str, default = "vanilla"
type of the attention softmax; {"vanilla", "off-by-one", "learnable"} type of the attention softmax; {"vanilla", "off-by-one", "learnable"}
window_size: Tuple[int, int], default = (-1, -1) window_size : Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically. window and causal mask specifically.
rng_gen: torch.Generator, default = None rng_gen : torch.Generator, default = None
random number generator; random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
softmax_offset: torch.Tensor, default = None softmax_offset : torch.Tensor, default = None
softmax offset tensor in shape [1, h_q, 1, 1]. softmax offset tensor of shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details. See softmax_type in DotProductAttention for details.
return_max_logit: bool, default = False return_max_logit : bool, default = False
whether to return the maximum attention score whether to return the maximum attention score
cuda_graph: bool, default = False cuda_graph : bool, default = False
whether or not cuda graph capture is enabled. whether or not cuda graph capture is enabled.
Returns Returns
---------- ----------
o: torch.Tensor o : torch.Tensor
output tensor O, of the attention calculation; same data type as Q, K and V; output tensor O, of the attention calculation; same data type as Q, K and V;
same shape as Q same shape as Q
aux_ctx_tensors: List[torch.Tensor] aux_ctx_tensors : List[torch.Tensor]
auxiliary output tensors used for the backward; auxiliary output tensors used for the backward;
if is_training is True, aux_ctx_tensors = [softmax-related tensors, rng_state] if is_training is True, aux_ctx_tensors = [softmax-related tensors, rng_state]
if is_training is False, aux_ctx_tensors = None if is_training is False, aux_ctx_tensors = None
...@@ -252,7 +252,7 @@ def fused_attn_fwd( ...@@ -252,7 +252,7 @@ def fused_attn_fwd(
rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen
state of the random number generator; state of the random number generator;
[seed, offset], dtype uint64 [seed, offset], dtype uint64
max_logit: if return_max_logit = True, shape [h] and same data type as O; otherwise None max_logit : if return_max_logit = True, shape [h] and same data type as O; otherwise None
""" """
if attn_scale is None: if attn_scale is None:
...@@ -377,89 +377,89 @@ def fused_attn_bwd( ...@@ -377,89 +377,89 @@ def fused_attn_bwd(
Parameters Parameters
---------- ----------
max_seqlen_q: int max_seqlen_q : int
max sequence length for Q, used for padding; may be larger than max(seqlens_q), max sequence length for Q, used for padding; may be larger than max(seqlens_q),
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
max_seqlen_kv: int max_seqlen_kv : int
max sequence length for K and V, used for padding; max sequence length for K and V, used for padding;
may be larger than max(seqlens_kv), may be larger than max(seqlens_kv),
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
cu_seqlens_q: torch.Tensor cu_seqlens_q : torch.Tensor
cumulative sequence lengths for Q; shape [batch_size + 1] cumulative sequence lengths for Q; shape [batch_size + 1]
cu_seqlens_kv: torch.Tensor cu_seqlens_kv : torch.Tensor
cumulative sequence lengths for K and V; shape [batch_size + 1] cumulative sequence lengths for K and V; shape [batch_size + 1]
q: torch.Tensor q : torch.Tensor
input tensor Q; shape sbhd, bshd or thd (see `qkv_layout` for details) input tensor Q; shape sbhd, bshd or thd (see `qkv_layout` for details)
k: torch.Tensor k : torch.Tensor
input tensor K; shape sbhd, bshd or thd (see `qkv_layout` for details) input tensor K; shape sbhd, bshd or thd (see `qkv_layout` for details)
v: torch.Tensor v : torch.Tensor
input tensor V; shape sbhd, bshd or thd (see `qkv_layout` for details) input tensor V; shape sbhd, bshd or thd (see `qkv_layout` for details)
o: torch.Tensor o : torch.Tensor
input tensor O (output of forward); same data type as Q, K and V; input tensor O (output of forward); same data type as Q, K and V;
same shape as Q same shape as Q
d_o: torch.Tensor d_o : torch.Tensor
input tensor dO (gradient of O); same data type as Q, K and V; input tensor dO (gradient of O); same data type as Q, K and V;
same shape as Q same shape as Q
fake_dtype: tex.DType fake_dtype : tex.DType
data type of Q, K and V - in case of high precision, fake dtype in case of FP8; data type of Q, K and V - in case of high precision, fake dtype in case of FP8;
in torch.dtype in torch.dtype
dqkv_dtype: tex.DType dqkv_dtype : tex.DType
data type of dQ, dK and dV; in tex.DType, not torch.dtype data type of dQ, dK and dV; in tex.DType, not torch.dtype
aux_ctx_tensors: List[torch.Tensor] aux_ctx_tensors : List[torch.Tensor]
auxiliary output tensors of the forward pass when its is_training is True, auxiliary output tensors of the forward pass when its is_training is True,
e.g. aux_ctx_tensors = [M, ZInv, rng_state] e.g. aux_ctx_tensors = [M, ZInv, rng_state]
fused_attention_backend: tex.NVTE_Fused_Attn_Backend fused_attention_backend : tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends. please see FusedAttention module for details on supported backends.
cu_seqlens_q_padded: torch.Tensor, default = None cu_seqlens_q_padded : torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1] cumulative sequence offsets for Q; shape [batch_size + 1]
cu_seqlens_kv_padded: torch.Tensor, default = None cu_seqlens_kv_padded : torch.Tensor, default = None
cumulative sequence offsets for KV; shape [batch_size + 1] cumulative sequence offsets for KV; shape [batch_size + 1]
s_quantizer: Quantizer, default = None s_quantizer : Quantizer, default = None
Quantizer object for the intermediate value S. Quantizer object for the intermediate value S.
dp_quantizer: Quantizer, default = None dp_quantizer : Quantizer, default = None
Quantizer object for the intermediate value dP. Quantizer object for the intermediate value dP.
dqkv_quantizer: Quantizer, default = None dqkv_quantizer : Quantizer, default = None
Quantizer object for the output values of the fused_attn_bwd. Quantizer object for the output values of the fused_attn_bwd.
dropout: float, default = 0.0 dropout : float, default = 0.0
dropout probability, 0.0 means no dropout, 1.0 means no output; dropout probability, 0.0 means no dropout, 1.0 means no output;
dropout must be 0.0 if is_training is False dropout must be 0.0 if is_training is False
fast_zero_fill: bool, default = True fast_zero_fill : bool, default = True
if True, initializes the output tensor O to zero using the fast filling method; if True, initializes the output tensor O to zero using the fast filling method;
if False, uses PyTorch's .fill_() method if False, uses PyTorch's .fill_() method
qkv_layout: str, default = "sbh3d" qkv_layout : str, default = "sbh3d"
layout of Q, K and V; layout of Q, K and V;
{"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd", {"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd",
"bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd", "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd",
"t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"} "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"}
attn_bias_type: str, default = "no_bias" attn_bias_type : str, default = "no_bias"
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding" attn_mask_type : str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"} type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
softmax_type: str, default = "vanilla" softmax_type : str, default = "vanilla"
type of the attention softmax; {"vanilla", "off-by-one", "learnable"} type of the attention softmax; {"vanilla", "off-by-one", "learnable"}
window_size: Tuple[int, int], default = (-1, -1) window_size : Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically. window and causal mask specifically.
deterministic: bool, default = False deterministic : bool, default = False
whether to execute the backward pass with deterministic behaviours. whether to execute the backward pass with deterministic behaviours.
cuda_graph: bool, default = False cuda_graph : bool, default = False
whether or not cuda graph capture is enabled. whether or not cuda graph capture is enabled.
Returns Returns
---------- ----------
d_q: torch.Tensor d_q : torch.Tensor
gradient tensor of Q; same data type and shape as Q gradient tensor of Q; same data type and shape as Q
d_k: torch.Tensor d_k : torch.Tensor
gradient tensor of K; same data type and shape as K gradient tensor of K; same data type and shape as K
d_v: torch.Tensor d_v : torch.Tensor
gradient tensor of V; same data type and shape as V gradient tensor of V; same data type and shape as V
d_bias: torch.Tensor, optional d_bias : torch.Tensor, optional
gradient tensor of Bias when attn_bias_type is "pre_scale_bias" gradient tensor of Bias when attn_bias_type is "pre_scale_bias"
or "post_scale_bias"; same data type and shape as Bias or "post_scale_bias"; same data type and shape as Bias
d_softmax_offset: torch.Tensor, optional d_softmax_offset : torch.Tensor, optional
gradient tensor of softmax offset in shape [1, h_q, 1, 1]. gradient tensor of softmax offset of shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details. See softmax_type in DotProductAttention for details.
""" """
if attn_scale is None: if attn_scale is None:
......
...@@ -657,60 +657,64 @@ def get_cpu_offload_context( ...@@ -657,60 +657,64 @@ def get_cpu_offload_context(
Parameters Parameters
---------- ----------
enabled: bool, default = `False` enabled : bool, default = False
When set to True, CPU Offloading functionality is enabled. When set to True, CPU Offloading functionality is enabled.
num_layers: int, default = 1 num_layers : int, default = 1
Determines the number of layers Determines the number of layers
you want to offload activations/weights for. you want to offload activations/weights for.
model_layers: int, default = 1 model_layers : int, default = 1
Number of layers in the model that will be used under this context. Number of layers in the model that will be used under this context.
offload_activations: bool, default = `True` offload_activations : bool, default = True
Deprecated. Deprecated.
offload_weights: bool, default = `True` offload_weights : bool, default = True
Deprecated. Deprecated.
double_buffering: bool, default = `False` double_buffering : bool, default = False
Deprecated. Deprecated.
retain_pinned_cpu_buffers: bool, default = `False` retain_pinned_cpu_buffers : bool, default = False
If True, the pinned CPU buffers are retained after offloading If True, the pinned CPU buffers are retained after offloading
and reused for the next iteration. It is useful for cuda graphs capture. and reused for the next iteration. It is useful for cuda graphs capture.
manual_synchronization: bool, default = `False` manual_synchronization : bool, default = False
If True, the synchronization is done manually by the user. If True, the synchronization is done manually by the user.
Additional argument manual_controller is returned. See more in manual control section. Additional argument manual_controller is returned. See more in manual control section.
offload_stream: torch.cuda.Stream, default = `None` offload_stream : torch.cuda.Stream, default = None
If provided, the offload stream is used for offloading and reloading. If provided, the offload stream is used for offloading and reloading.
Otherwise, a new stream is allocated internally. It can be other than None Otherwise, a new stream is allocated internally. It can be other than None
only if manual_synchronization is True. only if manual_synchronization is True.
Manual synchronization Notes
---------- -----
**Manual synchronization:**
By default, layers are offloaded/reloaded asynchronously By default, layers are offloaded/reloaded asynchronously
with respect to the current forward/backward stream with predefined synchronization, with respect to the current forward/backward stream with predefined synchronization,
to ensure that activation memory usage is equal to to ensure that activation memory usage is equal to
`(num_layers - num_offloaded_layers) * T`, where `T` is the memory footprint of a layer. ``(num_layers - num_offloaded_layers) * T``, where ``T`` is the memory footprint of a layer.
For more control over the offloading and reloading process, you can set `manual_synchronization=True`. For more control over the offloading and reloading process, you can set ``manual_synchronization=True``.
In this case, an additional argument, `manual_controller`, is returned. In this case, an additional argument, ``manual_controller``, is returned.
The `manual_controller` provides the following methods: The ``manual_controller`` provides the following methods:
- `start_offload_layer(layer_id: int)` - ``start_offload_layer(layer_id: int)``
- `release_activation_forward_gpu_memory(layer_id: int)` - ``release_activation_forward_gpu_memory(layer_id: int)``
- `start_reload_layer(layer_id: int)` - ``start_reload_layer(layer_id: int)``
If none of these methods are invoked for a given layer, that layer will not be offloaded or reloaded. If none of these methods are invoked for a given layer, that layer will not be offloaded or reloaded.
If `start_offload_layer()` is called for a layer, offload copies for that layer begin asynchronously on the offload stream. If ``start_offload_layer()`` is called for a layer, offload copies for that layer begin asynchronously on the offload stream.
Since GPU activations must be kept in memory until the copy is finished, pointers to all activations are stored. Since GPU activations must be kept in memory until the copy is finished, pointers to all activations are stored.
To release this memory, you need to call `release_activation_forward_gpu_memory(layer_id)`. To release this memory, you need to call ``release_activation_forward_gpu_memory(layer_id)``.
This method makes the current stream wait for an event recorded on the offload stream after all tensors from the layer have been offloaded. This method makes the current stream wait for an event recorded on the offload stream after all tensors from the layer have been offloaded.
The `start_reload_layer()` method is used to start reloading a layer. The ``start_reload_layer()`` method is used to start reloading a layer.
Each tensor reload is awaited to finish before `tensor_pop()` for that tensor is called on the current stream. Each tensor reload is awaited to finish before ``tensor_pop()`` for that tensor is called on the current stream.
You can provide an `offload_stream` to be used for offload and reload operations. You can provide an ``offload_stream`` to be used for offload and reload operations.
This allows for more detailed synchronization, such as delaying the start of offloading. This allows for more detailed synchronization, such as delaying the start of offloading.
Example: **Example:**
.. code-block:: python .. code-block:: python
offload_stream = torch.cuda.Stream() offload_stream = torch.cuda.Stream()
cpu_offload_context, sync_function, manual_controller = get_cpu_offload_context( cpu_offload_context, sync_function, manual_controller = get_cpu_offload_context(
enabled=True, model_layers=num_layers, manual_synchronization=True, offload_stream=offload_stream) enabled=True, model_layers=num_layers, manual_synchronization=True, offload_stream=offload_stream)
...@@ -732,10 +736,10 @@ def get_cpu_offload_context( ...@@ -732,10 +736,10 @@ def get_cpu_offload_context(
for i in range(num_layers): for i in range(num_layers):
out[i].sum().backward() out[i].sum().backward()
V1 code path **V1 code path:**
----------
If you want to use the v1 code path for offloading, If you want to use the v1 code path for offloading,
please set the environment variable NVTE_CPU_OFFLOAD_V1 to 1. please set the environment variable ``NVTE_CPU_OFFLOAD_V1`` to 1.
""" """
if NVTE_CPU_OFFLOAD_V1: if NVTE_CPU_OFFLOAD_V1:
......
...@@ -685,18 +685,18 @@ def get_cpu_offload_context( ...@@ -685,18 +685,18 @@ def get_cpu_offload_context(
Parameters Parameters
---------- ----------
enabled: bool, default = `False` enabled : bool, default = `False`
When set to True, CPU Offloading functionality is enabled. When set to True, CPU Offloading functionality is enabled.
num_layers: int, default = 1 num_layers : int, default = 1
Determines the number of transformer layers Determines the number of transformer layers
you want to offload activations/weights for. you want to offload activations/weights for.
model_layers: int, default = 1 model_layers : int, default = 1
Number of layers in the model that will be used under this context. Number of layers in the model that will be used under this context.
offload_activations: bool, default = `True` offload_activations : bool, default = `True`
When set to `True`, offloads the activations for the TE layer. When set to `True`, offloads the activations for the TE layer.
offload_weights: bool, default = `True` offload_weights : bool, default = `True`
When set to `True`, offloads the weights for the TE layer. When set to `True`, offloads the weights for the TE layer.
double_buffering: bool, default = `False` double_buffering : bool, default = `False`
When set to `True`, uses double buffering for offloading. When set to `True`, uses double buffering for offloading.
""" """
......
...@@ -4,6 +4,9 @@ ...@@ -4,6 +4,9 @@
"""Cross Entropy Loss API""" """Cross Entropy Loss API"""
from typing import Optional
import warnings
import torch import torch
import transformer_engine.pytorch.triton.cross_entropy as triton_cross_entropy import transformer_engine.pytorch.triton.cross_entropy as triton_cross_entropy
...@@ -23,7 +26,7 @@ class CrossEntropyFunction(torch.autograd.Function): ...@@ -23,7 +26,7 @@ class CrossEntropyFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward( def forward(
ctx, ctx,
_input, inp,
target, target,
label_smoothing=0.0, label_smoothing=0.0,
reduce_loss=False, reduce_loss=False,
...@@ -37,7 +40,7 @@ class CrossEntropyFunction(torch.autograd.Function): ...@@ -37,7 +40,7 @@ class CrossEntropyFunction(torch.autograd.Function):
Parameters: Parameters:
ctx : The context object. ctx : The context object.
_input (tensor): The input tensor of shape (B, SQ, V) or (SQ, B, V) where B is batch size, SQ is sequence length, V is vocab size. inp (tensor): The input tensor of shape (B, SQ, V) or (SQ, B, V) where B is batch size, SQ is sequence length, V is vocab size.
target (tensor): The target tensor of shape (B,SQ) or (SQ, B) where each value is in [0, V-1]. target (tensor): The target tensor of shape (B,SQ) or (SQ, B) where each value is in [0, V-1].
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduce_loss (bool): If true, returns the averaged loss across the B*SQ dimension. reduce_loss (bool): If true, returns the averaged loss across the B*SQ dimension.
...@@ -47,8 +50,8 @@ class CrossEntropyFunction(torch.autograd.Function): ...@@ -47,8 +50,8 @@ class CrossEntropyFunction(torch.autograd.Function):
Returns: Returns:
tensor: The computed loss. tensor: The computed loss.
""" """
loss, _input = triton_cross_entropy.cross_entropy_forward( loss, inp = triton_cross_entropy.cross_entropy_forward(
_input, inp,
target, target,
label_smoothing, label_smoothing,
reduce_loss, reduce_loss,
...@@ -56,7 +59,7 @@ class CrossEntropyFunction(torch.autograd.Function): ...@@ -56,7 +59,7 @@ class CrossEntropyFunction(torch.autograd.Function):
ignore_idx, ignore_idx,
) )
ctx.save_for_backward(_input.detach()) ctx.save_for_backward(inp.detach())
ctx.is_cg_capturable = is_cg_capturable ctx.is_cg_capturable = is_cg_capturable
return loss return loss
...@@ -72,12 +75,10 @@ class CrossEntropyFunction(torch.autograd.Function): ...@@ -72,12 +75,10 @@ class CrossEntropyFunction(torch.autograd.Function):
Returns: Returns:
tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None. tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
""" """
(_input,) = ctx.saved_tensors (inp,) = ctx.saved_tensors
_input = triton_cross_entropy.cross_entropy_backward( inp = triton_cross_entropy.cross_entropy_backward(inp, grad_output, ctx.is_cg_capturable)
_input, grad_output, ctx.is_cg_capturable
)
return ( return (
_input, inp,
None, None,
None, None,
None, None,
...@@ -87,4 +88,65 @@ class CrossEntropyFunction(torch.autograd.Function): ...@@ -87,4 +88,65 @@ class CrossEntropyFunction(torch.autograd.Function):
) )
parallel_cross_entropy = CrossEntropyFunction.apply def parallel_cross_entropy(
inp: torch.Tensor,
target: torch.Tensor,
label_smoothing: float = 0.0,
reduce_loss: bool = False,
dist_process_group: Optional[torch.distributed.ProcessGroup] = None,
ignore_idx: int = -100,
is_cg_capturable: bool = False,
*,
_input: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Cross Entropy loss with optional distributed reduction.
The input tensor can be in BF16/FP32, the loss and gradient calculation happens in
FP32 only. The returned loss is always in FP32, the input gradients are upcasted
to the datatype of the input.
If ``dist_process_group`` is passed for distributed loss calculation, the input to each
distributed rank should be ``(*, V/world_size)``. Note that each of the ranks should
get equal shards along the V dimension.
Parameters
----------
inp : torch.Tensor
The input tensor of shape ``(B, SQ, V)`` or ``(SQ, B, V)`` where B is batch size,
SQ is sequence length, V is vocab size.
target : torch.Tensor
The target tensor of shape ``(B, SQ)`` or ``(SQ, B)`` where each value is in ``[0, V-1]``.
label_smoothing : float, default = 0.0
The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduce_loss : bool, default = False
If True, returns the averaged loss across the B*SQ dimension.
dist_process_group : torch.distributed.ProcessGroup, default = None
The distributed process group the loss computation is split across, None if on 1 device.
ignore_idx : int, default = -100
The index for which loss and gradients are made to zero.
is_cg_capturable : bool, default = False
Whether the operation is CUDA graph capturable.
Returns
-------
torch.Tensor
The computed loss.
"""
# Handle backward compatibility with _input parameter
if _input is not None:
warnings.warn(
"The '_input' parameter is deprecated. Please use 'inp' instead.",
FutureWarning,
)
inp = _input
return CrossEntropyFunction.apply(
inp,
target,
label_smoothing,
reduce_loss,
dist_process_group,
ignore_idx,
is_cg_capturable,
)
...@@ -30,7 +30,7 @@ except ImportError: ...@@ -30,7 +30,7 @@ except ImportError:
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.triton.pad import pad_columnwise_scale_inv from transformer_engine.pytorch.triton.pad import pad_columnwise_scale_inv
from . import torch_version from .torch_version import torch_version
from .utils import ( from .utils import (
is_non_tn_fp8_gemm_supported, is_non_tn_fp8_gemm_supported,
safely_set_viewless_tensor_data, safely_set_viewless_tensor_data,
...@@ -642,18 +642,18 @@ def checkpoint( ...@@ -642,18 +642,18 @@ def checkpoint(
Parameters Parameters
---------- ----------
function: Callable function : Callable
pytorch module used to run the forward and backward passes using pytorch module used to run the forward and backward passes using
the specified :attr:`args` and :attr:`kwargs`. the specified :attr:`args` and :attr:`kwargs`.
distribute_saved_activations: bool, default = False distribute_saved_activations : bool, default = False
if set to `True` and `use_reentrant=True`, first tensor argument is distributed if set to ``True`` and ``use_reentrant=True``, first tensor argument is distributed
across the specified tensor parallel group (`tp_group`) before saving it for the across the specified tensor parallel group (``tp_group``) before saving it for the
backward pass. This has no effect when `use_reentrant=False`. backward pass. This has no effect when ``use_reentrant=False``.
get_rng_state_tracker: `Callable`, default = None get_rng_state_tracker : Callable, default = None
python callable which returns an instance of :func:`CudaRNGStatesTracker`. python callable which returns an instance of :class:`CudaRNGStatesTracker`.
tp_group : ProcessGroup, default = None tp_group : ProcessGroup, default = None
tensor parallel process group. Used only when `distribute_saved_activations=True` tensor parallel process group. Used only when ``distribute_saved_activations=True``
and `use_reentrant=True`. If `None`, it falls back to the default group. and ``use_reentrant=True``. If ``None``, it falls back to the default group.
use_reentrant : bool, default = True use_reentrant : bool, default = True
perform checkpointing in reentrant mode. perform checkpointing in reentrant mode.
args : tuple args : tuple
...@@ -778,8 +778,8 @@ class CudaRNGStatesTracker: ...@@ -778,8 +778,8 @@ class CudaRNGStatesTracker:
For model parallelism, multiple RNG states need to simultaneously exist in order For model parallelism, multiple RNG states need to simultaneously exist in order
to execute operations in or out of the model parallel region. This class keeps to execute operations in or out of the model parallel region. This class keeps
track of the various RNG states and provides utility methods to maintain them and track of the various RNG states and provides utility methods to maintain them and
execute parts of the model under a given RNG setting. Using the `add` method, a execute parts of the model under a given RNG setting. Using the :meth:`add` method, a
cuda rng state is initialized based on the input `seed` and is assigned to `name`. cuda rng state is initialized based on the input ``seed`` and is assigned to ``name``.
Later, by forking the rng state, we can perform operations and return to our starting Later, by forking the rng state, we can perform operations and return to our starting
cuda state. cuda state.
""" """
...@@ -812,7 +812,9 @@ class CudaRNGStatesTracker: ...@@ -812,7 +812,9 @@ class CudaRNGStatesTracker:
Set the rng states. For efficiency purposes, we do not Set the rng states. For efficiency purposes, we do not
check the size of seed for compatibility. check the size of seed for compatibility.
states: Dict[str, torch.Tensor] Parameters
----------
states : Dict[str, torch.Tensor]
A mapping from string names to RNG states. A mapping from string names to RNG states.
""" """
self.states_ = states self.states_ = states
...@@ -821,9 +823,11 @@ class CudaRNGStatesTracker: ...@@ -821,9 +823,11 @@ class CudaRNGStatesTracker:
""" """
Adds a new RNG state. Adds a new RNG state.
name: str Parameters
----------
name : str
string identifier for the RNG state. string identifier for the RNG state.
seed: int seed : int
PyTorch seed for the RNG state. PyTorch seed for the RNG state.
""" """
# Check seed is not already used. # Check seed is not already used.
...@@ -857,7 +861,9 @@ class CudaRNGStatesTracker: ...@@ -857,7 +861,9 @@ class CudaRNGStatesTracker:
Fork the cuda rng state, perform operations, and exit with Fork the cuda rng state, perform operations, and exit with
the original state. the original state.
name: str Parameters
----------
name : str
string identifier for the RNG state. string identifier for the RNG state.
""" """
# Check if we have added the state # Check if we have added the state
...@@ -2003,7 +2009,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None: ...@@ -2003,7 +2009,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
Parameters Parameters
---------- ----------
fsdp_root: torch.nn.Module fsdp_root : torch.nn.Module
FSDP-wrapped root module that may contain FSDP-wrapped TE modules. FSDP-wrapped root module that may contain FSDP-wrapped TE modules.
""" """
assert isinstance(fsdp_root, FSDP), "Root module must be FSDP-wrapped." assert isinstance(fsdp_root, FSDP), "Root module must be FSDP-wrapped."
......
...@@ -28,7 +28,7 @@ def onnx_export(enabled: bool = False) -> Generator[None, None, None]: ...@@ -28,7 +28,7 @@ def onnx_export(enabled: bool = False) -> Generator[None, None, None]:
Parameters Parameters
---------- ----------
enabled: bool, default = `False` enabled : bool, default = False
whether or not to enable export whether or not to enable export
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment