Unverified Commit 4d1f92df authored by Kirthi Shankar Sivamani, committed by GitHub

[paddle] add documentation (#489)



* paddle documentation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* minor fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 6b311da2
......@@ -10,3 +10,4 @@ Framework-specific API
pytorch
jax
paddle
..
Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
paddle
======
.. autoapiclass:: transformer_engine.paddle.Linear(in_features, out_features, **kwargs)
:members: forward
.. autoapiclass:: transformer_engine.paddle.LayerNorm(hidden_size, eps=1e-5, **kwargs)
.. autoapiclass:: transformer_engine.paddle.LayerNormLinear(in_features, out_features, eps=1e-5, **kwargs)
:members: forward
.. autoapiclass:: transformer_engine.paddle.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, **kwargs)
:members: forward
.. autoapiclass:: transformer_engine.paddle.FusedScaleMaskSoftmax(attn_mask_type, mask_func, **kwargs)
:members: forward
.. autoapiclass:: transformer_engine.paddle.DotProductAttention(num_attention_heads, kv_channels, **kwargs)
:members: forward
.. autoapiclass:: transformer_engine.paddle.MultiHeadAttention(hidden_size, num_attention_heads, **kwargs)
:members: forward
.. autoapiclass:: transformer_engine.paddle.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)
:members: forward
.. autoapifunction:: transformer_engine.paddle.fp8_autocast
.. autoapifunction:: transformer_engine.paddle.recompute
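For orientation, the listing above maps onto the following minimal usage sketch. It assumes an FP8-capable GPU; the hidden sizes, the ``[batch, sequence, hidden]`` input layout, and the single-tensor call of ``TransformerLayer`` are illustrative assumptions, not part of the documented signatures.

.. code-block:: python

    # Minimal sketch: build a TransformerLayer and run one FP8 step.
    # Hidden sizes, batch/sequence dims and the input layout are illustrative.
    import paddle
    import transformer_engine.paddle as te

    layer = te.TransformerLayer(hidden_size=1024,
                                ffn_hidden_size=4096,
                                num_attention_heads=16)

    # Assumed layout: [batch, sequence, hidden]; dimensions feeding the FP8
    # GEMMs should be divisible by 16 (see the fp8_autocast notes).
    inp = paddle.rand([4, 128, 1024], dtype='float32')

    with te.fp8_autocast(enabled=True):
        out = layer(inp)

    out.mean().backward()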
......@@ -15,6 +15,10 @@ from transformer_engine.common.recipe import DelayedScaling, Format
from .constants import dist_group_type
from .fp8_buffer import FP8MetaFwdBuffer, FP8MetaBwdBuffer, FP8RecomputeBuffer
__all__ = ['fp8_autocast']
# FP8 support
_is_fp8_available = None
_reason_for_no_fp8 = ""
......@@ -166,6 +170,40 @@ def fp8_autocast(
) -> None:
"""
Context manager for FP8 usage.
.. code-block:: python

    with fp8_autocast(enabled=True):
        out = model(inp)
.. note::
Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors
with shapes where both dimensions are divisible by 16. In terms of the input to the full
Transformer network, this typically requires padding the sequence length to a multiple of 16.
.. note::
When :attr:`fp8_recipe.reduce_amax==True`, a given module must not be invoked more than once
inside a single `fp8_autocast` region. This is unsupported behavior because the amax
reduction is handled during the exit of the `fp8_autocast` context. Calling the same
module more than once inside an `fp8_autocast` region overwrites the amax tensors
before the reduction can occur.
Parameters
----------
enabled: bool, default = `False`
whether or not to enable fp8
calibrating: bool, default = `False`
calibration mode allows collecting statistics such as amax and scale
data of fp8 tensors even when executing without fp8 enabled. This is
useful for saving an inference-ready fp8 checkpoint while training
in higher precision.
fp8_recipe: recipe.DelayedScaling, default = `None`
recipe used for FP8 training.
fp8_group: paddle.distributed.collective.Group, default = `None`
distributed group over which amaxes for the fp8 tensors
are reduced at the end of each training step.
"""
try:
_global_fp8_state.enter(enabled, calibrating, fp8_recipe, fp8_group)
......
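To make the parameters above concrete, here is a sketch of an explicit recipe plus a calibration pass; the ``DelayedScaling`` settings shown are illustrative, and ``model`` stands in for any Transformer Engine paddle layer.

.. code-block:: python

    import paddle
    import transformer_engine.paddle as te
    from transformer_engine.common.recipe import DelayedScaling, Format

    # Illustrative recipe settings; see DelayedScaling for the full set.
    recipe = DelayedScaling(fp8_format=Format.HYBRID,
                            amax_history_len=16,
                            amax_compute_algo="max")

    model = te.Linear(768, 768)
    inp = paddle.rand([16, 768])

    # FP8 training step with an explicit recipe.
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        out = model(inp)
    out.mean().backward()

    # Calibration: run without FP8 while still collecting amax/scale
    # statistics, e.g. to later save an inference-ready FP8 checkpoint.
    with te.fp8_autocast(enabled=False, calibrating=True):
        out = model(inp)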
......@@ -29,6 +29,9 @@ from ..utils import attention_mask_func, divide
from ..recompute import recompute
__all__ = ["DotProductAttention", "MultiHeadAttention"]
class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer):
"""Function for FusedAttention with packed QKV input"""
......@@ -129,7 +132,7 @@ class FusedAttnFuncPackedKV(paddle.autograd.PyLayer):
class DotProductAttention(paddle.nn.Layer):
"""Dot Product Attention Layer
"""
Allows the model to jointly attend to information from different
representation subspaces as described in the paper:
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
......@@ -151,7 +154,6 @@ class DotProductAttention(paddle.nn.Layer):
type of attention operation.
backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
backend to use for attention operation.
"""
def __init__(self,
......@@ -224,7 +226,7 @@ class DotProductAttention(paddle.nn.Layer):
only the `no_bias` type is currently supported, {`no_bias`}
core_attention_bias: Optional[paddle.Tensor], default = `None`
Bias tensor for Q * K.T
set_zero: bool, defautl = `True`
set_zero: bool, default = `True`
Whether to use the fast path to set output tensors to 0 or not.
"""
......@@ -358,7 +360,9 @@ class DotProductAttention(paddle.nn.Layer):
class MultiHeadAttention(paddle.nn.Layer):
"""Attention w/ QKV and Proj Gemms
"""
Multi-head Attention (MHA), including Query,
Key, Value and Output projection.
Parameters
----------
......@@ -387,7 +391,8 @@ class MultiHeadAttention(paddle.nn.Layer):
zero_centered_gamma: bool, default = `False`
whether to zero initialize the gamma of the layernorm operation.
backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
backend to use for attention operation.
backend to use for attention operation. If set to 'paddle', a framework-only,
non-FP8 path is executed with limited optimizations.
Parallelism parameters
----------------------
......@@ -542,7 +547,6 @@ class MultiHeadAttention(paddle.nn.Layer):
"""
MultiHeadAttention Layer.
Parameters
----------
hidden_states : paddle.Tensor
......@@ -555,7 +559,7 @@ class MultiHeadAttention(paddle.nn.Layer):
only the `no_bias` type is currently supported, {`no_bias`}
core_attention_bias: Optional[paddle.Tensor], default = `None`
Bias tensor for Q * K.T
set_zero: bool, defautl = `True`
set_zero: bool, default = `True`
Whether to use the fast path to set output tensors to 0 or not.
recompute_core_attention: bool, default = `False`
If true, forward activations for core attention are recomputed
......
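For the attention module above, a construction sketch; the ``[batch, sequence, hidden]`` input layout and the single-tensor return of ``MultiHeadAttention`` are assumptions.

.. code-block:: python

    import paddle
    import transformer_engine.paddle as te

    mha = te.MultiHeadAttention(hidden_size=1024, num_attention_heads=16)

    # Assumed layout: [batch, sequence, hidden]; self-attention with the
    # default mask type, no explicit attention mask passed.
    hidden_states = paddle.rand([2, 128, 1024])
    out = mha(hidden_states)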
......@@ -63,7 +63,33 @@ class _LayerNorm(paddle.autograd.PyLayer):
class LayerNorm(paddle.nn.Layer):
r"""
Applies Layer Normalization over a mini-batch of inputs as described in
the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`
the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta
:math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
size :attr:`hidden_size`
Parameters
----------
hidden_size : int
size of each input sample.
eps : float, default = 1e-5
a value added to the denominator of layer normalization for numerical stability.
weight_attr: Union[paddle.ParamAttr, None], default = None
optional `paddle.ParamAttr` for weight.
bias_attr: Union[paddle.ParamAttr, None, bool], default = None
optional `paddle.ParamAttr` for bias.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
backend to use for softmax operation.
"""
def __init__(
......
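A short sketch of the layer above; the three-dimensional input shape is illustrative (normalization acts over the trailing ``hidden_size`` dimension).

.. code-block:: python

    import paddle
    import transformer_engine.paddle as te

    ln = te.LayerNorm(hidden_size=1024)
    out = ln(paddle.rand([2, 128, 1024]))

    # With zero_centered_gamma the learnable gamma starts at zero and the
    # formula uses (1 + gamma), matching the default at initialization.
    ln_zc = te.LayerNorm(hidden_size=1024, zero_centered_gamma=True)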
......@@ -40,7 +40,7 @@ from ..utils import (
saved_tensor_allow_none,
)
__all__ = ["LayerNormLinear", "_layernorm_fwd_fp8_cast", "_layernorm_bwd"]
__all__ = ["LayerNormLinear"]
def _layernorm_fwd_fp8_cast(
......@@ -331,6 +331,42 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
class LayerNormLinear(TransformerEngineBaseLayer):
r"""
Applies layer normalization followed by linear transformation to the incoming data.
Parameters
----------
in_features : int
size of each input sample.
out_features : int
size of each output sample.
eps : float, default = 1e-5
a value added to the denominator of layer normalization for numerical stability.
weight_attr: Union[paddle.ParamAttr, None], default = None
optional `paddle.ParamAttr` for weight.
bias_attr: Union[paddle.ParamAttr, None, bool], default = None
optional `paddle.ParamAttr` for bias.
return_layernorm_output : bool, default = `False`
if set to `True`, output of layernorm is returned from the forward
together with the output of the linear transformation.
Example use case: residual connection for transformer module is
taken post layernorm.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine'
if set to 'paddle', a framework-only, non-FP8 path is executed with limited optimizations.
Parallelism parameters
----------------------
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
parallel_mode : {None, 'Column', 'Row'}, default = `None`
used to decide whether this Linear layer is Column Parallel Linear or Row
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
When set to `None`, no communication is performed.
"""
def __init__(
......@@ -503,7 +539,14 @@ class LayerNormLinear(TransformerEngineBaseLayer):
return out
def forward(self, *args, **kwargs):
"""forward"""
"""
Apply layer normalization to the input followed by a linear transformation.
Parameters
----------
inp : paddle.Tensor
Input tensor.
"""
if self.backend == 'transformer_engine':
return self._te_forward(*args, **kwargs)
if self.backend == 'paddle':
......
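A sketch of LayerNormLinear as documented above; the tuple return relies on ``return_layernorm_output=True``, and the input shape is an assumption.

.. code-block:: python

    import paddle
    import transformer_engine.paddle as te

    ln_linear = te.LayerNormLinear(in_features=1024, out_features=3072,
                                   return_layernorm_output=True)

    inp = paddle.rand([2, 128, 1024])   # assumed [batch, sequence, in_features]
    out, ln_out = ln_linear(inp)        # layernorm output returned only when
                                        # return_layernorm_output=True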
......@@ -39,6 +39,7 @@ from ..utils import (
saved_tensor_allow_none,
)
__all__ = ["LayerNormMLP"]
......@@ -549,7 +550,47 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
class LayerNormMLP(TransformerEngineBaseLayer):
r"""
Applies layer normalization followed by linear transformation to the incoming data.
Applies layer normalization on the input followed by the MLP module, consisting of
2 successive linear transformations, separated by the GeLU activation.
Parameters
----------
hidden_size : int
size of each input sample.
ffn_hidden_size : int
intermediate size to which input samples are projected.
eps : float, default = 1e-5
a value added to the denominator of layer normalization for numerical stability.
weight_attr: Union[paddle.ParamAttr, None], default = None
optional `paddle.ParamAttr` for weight.
bias_attr: Union[paddle.ParamAttr, None, bool], default = None
optional `paddle.ParamAttr` for bias.
activation : str, default = 'gelu'
activation function used.
Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu'.
return_layernorm_output : bool, default = `False`
if set to `True`, output of layernorm is returned from the forward
together with the output of the linear transformation.
Example use case: residual connection for transformer module
is taken post layernorm.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine'
if set to 'paddle', a framework-only, non-FP8 path is executed with limited optimizations.
Parallelism parameters
----------------------
set_parallel_mode : bool, default = `False`
if set to `True`, FC1 is used as Column Parallel and FC2 is used as Row
Parallel as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
tp_group : paddle.distributed.collective.Group, default = `None`
tensor parallel process group.
"""
def __init__(
......@@ -753,7 +794,14 @@ class LayerNormMLP(TransformerEngineBaseLayer):
return out
def forward(self, *args, **kwargs):
"""forward"""
"""
Apply layer normalization to the input followed by a feedforward network (MLP Block).
Parameters
----------
inp : paddle.Tensor
Input tensor.
"""
if self.backend == 'transformer_engine':
return self._te_forward(*args, **kwargs)
if self.backend == 'paddle':
......
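A usage sketch for LayerNormMLP; the sizes and the ``[batch, sequence, hidden]`` layout are illustrative.

.. code-block:: python

    import paddle
    import transformer_engine.paddle as te

    mlp = te.LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096,
                          activation='gelu')

    inp = paddle.rand([2, 128, 1024])   # assumed [batch, sequence, hidden]
    with te.fp8_autocast(enabled=True):
        out = mlp(inp)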
......@@ -38,7 +38,7 @@ from ..utils import (
saved_tensor_allow_none,
)
__all__ = ["Linear", "_linear_fwd", "_linear_fwd_fp8", "_linear_bwd", "_linear_fwd_non_fp8"]
__all__ = ["Linear"]
def _linear_fwd_fp8(
......@@ -541,6 +541,29 @@ class _Linear(paddle.autograd.PyLayer):
class Linear(TransformerEngineBaseLayer):
"""
Applies a linear transformation to the incoming data :math:`y = xA^T + b`
Parameters
----------
in_features : int
size of each input sample.
out_features : int
size of each output sample.
weight_attr: Union[paddle.ParamAttr, None], default = None
optional `paddle.ParamAttr` for weight.
bias_attr: Union[paddle.ParamAttr, None, bool], default = None
optional `paddle.ParamAttr` for bias.
backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine'
if set to 'paddle', a framework-only, non-FP8 path is executed with limited optimizations.
Parallelism parameters
----------------------
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
parallel_mode : {None, 'Column', 'Row'}, default = `None`
used to decide whether this Linear layer is Column Parallel Linear or Row
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
When set to `None`, no communication is performed.
"""
def __init__(
......@@ -658,7 +681,14 @@ class Linear(TransformerEngineBaseLayer):
return out
def forward(self, *args, **kwargs):
"""forward"""
"""
Apply the linear transformation to the input.
Parameters
----------
inp : paddle.Tensor
Input tensor.
"""
if self.backend == 'transformer_engine':
return self._te_forward(*args, **kwargs)
if self.backend == 'paddle':
......
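A sketch of the Linear layer above under FP8; the sizes are illustrative and chosen so both GEMM dimensions are divisible by 16.

.. code-block:: python

    import paddle
    import transformer_engine.paddle as te

    linear = te.Linear(in_features=768, out_features=3072)

    # Both dimensions divisible by 16, as required for the FP8 path.
    inp = paddle.rand([32, 768])
    with te.fp8_autocast(enabled=True):
        out = linear(inp)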
......@@ -18,9 +18,14 @@ from transformer_engine.paddle.cpp_extensions import (
scaled_softmax_backward,
)
__all__ = ["FusedScaleMaskSoftmax"]
THREADS_PER_WARP = 32
THREADS_PER_BLOCK = 128
_default_causal_mask = {}
......@@ -112,12 +117,22 @@ class ScaledSoftmax(paddle.autograd.PyLayer):
class FusedScaleMaskSoftmax(paddle.nn.Layer):
"""
fused operation: scaling + mask + softmax
Arguments:
attn_mask_type: attention mask type (pad or causal)
mask_func: mask function to be applied.
softmax_in_fp32: if true, softmax is performed at fp32 precision.
Scaled and masked softmax module for paddle with fused optimizations.
Parameters
----------
attn_mask_type : str, default = `causal`
type of attention mask, can be 'causal', 'padding', or 'no_mask'.
mask_func : callable
custom callable for applying the mask to the softmax input.
`masked_input=mask_func(inp, mask)`.
softmax_in_fp32 : bool, default = True
perform softmax computation in fp32.
layernorm_epsilon : float, default = 1e-5
a value added to the denominator of layer normalization
for numerical stability.
backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
backend to use for operation.
"""
def __init__(
......
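A hedged sketch of FusedScaleMaskSoftmax; the mask function shown, the (input, mask, scale) call order, and the ``[batch, heads, seq_q, seq_k]`` layout are assumptions for illustration.

.. code-block:: python

    import paddle
    import transformer_engine.paddle as te

    def mask_func(scores, mask):
        # Illustrative mask function: fill masked (boolean True) positions
        # with a large negative value before the softmax.
        return paddle.where(mask, paddle.full_like(scores, -10000.0), scores)

    softmax = te.FusedScaleMaskSoftmax('causal', mask_func)

    # Assumed layout: [batch, heads, seq_q, seq_k]; the call order is assumed.
    scores = paddle.rand([2, 16, 128, 128])
    probs = softmax(scores, None, scale=0.125)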
......@@ -8,9 +8,9 @@ from typing import Optional, Union
import paddle
from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd
from . import LayerNormMLP, LayerNorm, MultiHeadAttention
from ..constants import AttnMaskTypes, LayerTypes, dist_group_type
from ..distributed import get_tp_group_and_world_size, track_rng_state
from transformer_engine.paddle.layer import LayerNormMLP, LayerNorm, MultiHeadAttention
from transformer_engine.paddle.constants import AttnMaskTypes, LayerTypes, dist_group_type
from transformer_engine.paddle.distributed import get_tp_group_and_world_size, track_rng_state
class TransformerLayer(paddle.nn.Layer):
......@@ -33,6 +33,10 @@ class TransformerLayer(paddle.nn.Layer):
dropout probability for the dropout op after FC2 layer.
attention_dropout: float, default = 0.1
dropout probability for the dropout op during multi-head attention.
weight_attr: Union[paddle.ParamAttr, None], default = None
optional `paddle.ParamAttr` for weight.
bias_attr: Union[paddle.ParamAttr, None, bool], default = None
optional `paddle.ParamAttr` for bias.
self_attn_mask_type: {'causal', 'padding'}, default = `causal`
type of attention mask passed into softmax operation.
apply_residual_connection_post_layernorm : bool, default = `False`
......@@ -62,6 +66,8 @@ class TransformerLayer(paddle.nn.Layer):
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine'
if set to 'paddle', a framework-only, non-FP8 path is executed with limited optimizations.
Parallelism parameters
----------------------
......
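To illustrate the newly documented weight_attr/bias_attr and backend options, a construction sketch; the initializer choice is illustrative.

.. code-block:: python

    import paddle
    import transformer_engine.paddle as te

    # ParamAttr controls parameter creation for the documented
    # weight_attr / bias_attr arguments; the initializer is illustrative.
    w_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.XavierNormal())

    layer = te.TransformerLayer(hidden_size=1024,
                                ffn_hidden_size=4096,
                                num_attention_heads=16,
                                weight_attr=w_attr,
                                self_attn_mask_type='causal')

    # backend='paddle' selects the framework-only, non-FP8 fallback path.
    fallback = te.TransformerLayer(1024, 4096, 16, backend='paddle')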
......@@ -11,7 +11,9 @@ from paddle.distributed import fleet
from .constants import RecomputeFunctionNames
from .fp8 import get_global_fp8_state
__all__ = ['recompute', 'is_in_recompute_phase']
__all__ = ['recompute']
_DISABLE_RECOMPUTE = int(os.getenv("NVTE_DISABLE_RECOMPUTE", "0"))
......@@ -35,6 +37,16 @@ def recompute(function, *args, **kwargs):
"""
This is a wrapper of paddle.distributed.fleet.utils.recompute. It provides necessary
state information for fp8 layers.
Parameters
----------
function: Callable
paddle module used to run the forward and backward passes using
the specified :attr:`args` and :attr:`kwargs`.
args : tuple
tuple of paddle tensors for inputs to :attr:`function`.
kwargs : dict
dictionary of string keys for keyword arguments to :attr:`function`.
"""
assert not _DISABLE_RECOMPUTE, "Recompute is disabled. " \
f"Got NVTE_DISABLE_RECOMPUTE={_DISABLE_RECOMPUTE}."
......
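A sketch of recompute wrapping a Transformer Engine layer; the input shape is illustrative.

.. code-block:: python

    import paddle
    import transformer_engine.paddle as te

    block = te.LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096)
    inp = paddle.rand([2, 128, 1024])   # assumed [batch, sequence, hidden]

    with te.fp8_autocast(enabled=True):
        # Forward activations of `block` are recomputed during the backward
        # pass instead of being stored, trading compute for memory.
        out = te.recompute(block, inp)
    out.mean().backward()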