"vscode:/vscode.git/clone" did not exist on "8c4b2592fb953d1a8f880d42ebb1b28eaa94d0a6"
Unverified Commit ed1a3116 authored by Ming-Xu Huang's avatar Ming-Xu Huang Committed by GitHub
Browse files

Adding documents to TE/JAX (#87)



* Updated TE/JAX docs
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding TE/JAX docs' rst files
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Set DType as pybind11::module_local() to avoid generic_type errors.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Updating license and exporting more modules
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adopting autoapi and removing enum_tools.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Fix typo
Signed-off-by: default avatarMing-Xu Huang <mingh@nvidia.com>

* Make jax.rst be style consistent.
Signed-off-by: default avatarMing-Xu Huang <mingh@nvidia.com>

* Fixing doc statements as the suggestion from review.
Signed-off-by: default avatarMing-Xu Huang <mingh@nvidia.com>

* Fixing doc statements as the suggestion from code review.
Signed-off-by: default avatarMing-Xu Huang <mingh@nvidia.com>

* Update the description of Softmax
Signed-off-by: default avatarMing-Xu Huang <mingh@nvidia.com>

* Removed categories in catalog as PyTorch
Signed-off-by: default avatarMing-Xu Huang <mingh@nvidia.com>

---------
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>
Signed-off-by: default avatarMing-Xu Huang <mingh@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d74ee5b5
......@@ -9,3 +9,4 @@ Framework-specific API
.. toctree::
pytorch
jax
..
Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
Jax
=======
.. autoapiclass:: transformer_engine.jax.MajorShardingType
.. autoapiclass:: transformer_engine.jax.ShardingType
.. autoapiclass:: transformer_engine.jax.TransformerLayerType
.. autoapiclass:: transformer_engine.jax.ShardingResource(dp_resource=None, tp_resource=None)
.. autoapiclass:: transformer_engine.jax.LayerNorm(epsilon=1e-6, layernorm_type='layernorm', **kwargs)
:members: __call__
.. autoapiclass:: transformer_engine.jax.DenseGeneral(features, layernorm_type='layernorm', use_bias=False, **kwargs)
:members: __call__
.. autoapiclass:: transformer_engine.jax.LayerNormDenseGeneral(features, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)
:members: __call__
.. autoapiclass:: transformer_engine.jax.LayerNormMLP(intermediate_dim=2048, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)
:members: __call__
.. autoapiclass:: transformer_engine.jax.RelativePositionBiases(num_buckets, max_distance, num_heads, **kwargs)
:members: __call__
.. autoapiclass:: transformer_engine.jax.MultiHeadAttention(head_dim, num_heads, **kwargs)
:members: __call__
.. autoapiclass:: transformer_engine.jax.TransformerLayer(hidden_size=512, mlp_hidden_size=2048, num_attention_heads=8, **kwargs)
:members: __call__
.. autoapifunction:: transformer_engine.jax.extend_logical_axis_rules
.. autoapifunction:: transformer_engine.jax.fp8_autocast
.. autoapifunction:: transformer_engine.jax.update_collections
.. autoapifunction:: transformer_engine.jax.update_fp8_metas
\ No newline at end of file
......@@ -3,7 +3,9 @@
# See LICENSE for license information.
"""Transformer Engine bindings for JAX"""
from .fp8 import fp8_autocast, update_collections, update_fp8_metas
from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase
from .module import DenseGeneral, LayerNorm
from .module import LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase
from .transformer import extend_logical_axis_rules
from .transformer import RelativePositionBiases, TransformerLayer, TransformerLayerType
from .sharding import ShardingResource
from .transformer import MultiHeadAttention, RelativePositionBiases
from .transformer import TransformerLayer, TransformerLayerType
from .sharding import MajorShardingType, ShardingResource, ShardingType
......@@ -53,7 +53,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor);
m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor);
pybind11::enum_<DType>(m, "DType")
pybind11::enum_<DType>(m, "DType", pybind11::module_local())
.value("kByte", DType::kByte)
.value("kInt32", DType::kInt32)
.value("kFloat32", DType::kFloat32)
......
......@@ -302,16 +302,16 @@ def fp8_autocast(enabled: bool = False,
.. note::
We only support :attr:`margin`, :attr:`fp8_format`, :attr:`interval` and
:attr:`amax_history_len` in recipe.DelayedScaling currently. Other parameters
in recipe.DelayedScaling would be ignored, even is set.
in recipe.DelayedScaling would be ignored, even if set.
Parameters
----------
enabled: bool, default = False
whether or not to enable fp8
Whether or not to enable fp8
fp8_recipe: recipe.DelayedScaling, default = None
recipe used for FP8 training.
sharding_resource: ShardingResource, defaule = None
specify the mesh axes for data and tensor parallelism to shard along.
Recipe used for FP8 training.
sharding_resource: ShardingResource, default = None
Specify the mesh axes for data and tensor parallelism to shard along.
If set to None, then ShardingResource() would be created.
"""
if fp8_recipe is None:
......@@ -338,12 +338,11 @@ def fp8_autocast(enabled: bool = False,
# Function Wrappers
def update_collections(new: Collection, original: Collection) -> Collection:
def update_collections(new: Collection, original: Collection) -> FrozenDict:
r"""
A helper to update Flax's Collection. Collection is a union type of dict and
Flax's FrozenDict.
A helper to update Flax's Collection.
Collection = [dict, FrozenDict]
Collection = [dict, flax.core.frozen_dict.FrozenDict]
Parameters
----------
......@@ -364,14 +363,16 @@ def update_fp8_metas(state: Collection) -> Collection:
r"""
Calculate new fp8 scales and its inverse via the followed formula
`exp` = floor(log2(`fp8_max` / `amax`)) - `margin`
`sf` = round(power(2, abs(exp)))
`sf` = `sf` if `amax` > 0.0, else original_scale
`sf` = `sf` if isfinite(`amax`), else original_scale)
`updated_scale` = `1/sf` if exp < 0, else `sf`
`updated_scale_inv` = `1/updated_scale`
.. code-block:: python
exp = floor(log2(fp8_max / amax)) - margin
sf = round(power(2, abs(exp)))
sf = sf if amax > 0.0, else original_scale
sf = sf if isfinite(amax), else original_scale)
updated_scale = 1/sf if exp < 0, else sf
updated_scale_inv = 1/updated_scale
Collection = [dict, FrozenDict]
Collection = [dict, flax.core.frozen_dict.FrozenDict]
Parameters
----------
......
......@@ -98,19 +98,25 @@ def _combine_biases(*masks: List[Array]):
class Softmax(nn.Module):
r"""
Applies softmax over a mini-batch of inputs.
The inputs's shape should be [batch, heads, q_seqlen, k_seqlen].
The input's shape should be [batch, heads, q_seqlen, k_seqlen].
.. code-block:: python
shifted_input = input + bias
masked_scaled = (1 - mask)*(shifted_input * scale_factor)
softmax_mask = mask * -1e-10
output = softmax(masked_scaled + softmax_mask)
Parameters
----------
scale_factor : float, default = 1.0
scale the inputs along the last dimension before running softmax.
softmax_type : SoftmaxType, default = 'layernorm'
indicate the type of softmax.
Scalar for the input to softmax.
softmax_type : SoftmaxType, default = SoftmaxType.SCALED
Indicate the type of softmax.
Optimization parameters
-----------------------
sharding_type : ShardingType, default = ShardingType.SINGLE
indicate the sharding pattern.
Indicate the sharding pattern.
"""
scale_factor: float = 1.0
......@@ -158,7 +164,7 @@ class Softmax(nn.Module):
outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED,
self.sharding_type)
else:
outputs = jax_nn.softmax(logits)
outputs = jax_nn.softmax(logits * self.scale_factor)
return outputs
......@@ -193,30 +199,32 @@ class LayerNorm(nn.Module):
Parameters
----------
epsilon : float, default = 1e-6
a value added to the denominator of layer normalization for numerical stability.
A value added to the denominator of layer normalization for numerical stability.
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
indicate the type of layer normalization.
Indicate the type of layer normalization.
scale_init : Initializer, default = flax.linen.initializers.ones
used for initializing scale factors :math:`\gamma`.
Used for initializing scale factors :math:`\gamma`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
scale_axes : Tuple[str, ...], default = ('embed', )
the name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
bias_init : Initializer, default = flax.linen.initializers.zeros
used for initializing shift factors :math:`\beta`,
only works when :attr:`layernorm_type='layernorm'`.
Used for initializing shift factors :math:`\beta`,
only used when :attr:`layernorm_type='layernorm'`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
bias_axes : Tuple[str, ...], default = ('embed', )
The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
only works when :attr:`layernorm_type='layernorm'`.
only used when :attr:`layernorm_type='layernorm'`.
Optimization parameters
-----------------------
dtype : jax.numpy.dtype, default = jax.numpy.float32
the data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
indicate whether the input tensors were switched axis of batch
and sequence length dimension. if set to True, the input tensors
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
sharding_type : ShardingType, default = ShardingType.SINGLE
indicate the sharding pattern.
Indicate the sharding pattern.
"""
epsilon: float = 1e-6
layernorm_type: str = 'layernorm'
......@@ -316,33 +324,35 @@ class DenseGeneral(TransformerEngineBase):
Parameters
----------
features : Union[Iterable[int], int]
the hidden size of each output sample.
The hidden size of each output sample.
kernel_init : Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
used for initializing weights.
Used for initializing weights.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
kernel_axes : Tuple[str, ...], default = ()
the name of axes used to shard the weights with a corresponding mesh.
The name of axes used to shard the weights with a corresponding mesh.
use_bias: bool, default = False
indicate whether to enable bias shifting.
if set to False, the layer will not learn an additive bias.
Indicate whether to enable bias shifting.
If set to False, the layer will not learn an additive bias.
bias_init: Initializer, default = flax.linen.initializers.zeros
used for initializing bias, only works when :attr:`use_bias=True`.
Used for initializing bias, only used when :attr:`use_bias=True`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
bias_axes: Tuple[str, ...], default = ()
the name of axes used to shard bias with a corresponding mesh,
only works when :attr:`use_bias=True`.
The name of axes used to shard bias with a corresponding mesh,
only used when :attr:`use_bias=True`.
axis: Union[Iterable[int], int], default = -1
a integer of tuple with axes to apply the transformation on.
An integer tuple with axes to apply the transformation on.
Optimization parameters
-----------------------
dtype : jax.numpy.dtype, default = jax.numpy.float32
the data type used to allocate the initial parameters.
The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
indicate whether the input tensors were switched axis of batch
and sequence length dimension. if set to True, the input tensors
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
sharding_type : ShardingType, default = ShardingType.SINGLE
indicate the sharding pattern.
Indicate the sharding pattern.
"""
features: Union[Iterable[int], int]
......@@ -426,56 +436,60 @@ class LayerNormDenseGeneral(TransformerEngineBase):
Parameters
----------
features : Union[Iterable[int], int]
the hidden size of each output sample.
The hidden size of each output sample.
enable_layernorm: bool, default = True
indicate whether to enable layer normalization before linear transformation.
Indicate whether to enable layer normalization before linear transformation.
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
indicate the type of layer normalization.
Indicate the type of layer normalization.
epsilon : float, default = 1e-6
a value added to the denominator of layer normalization for numerical stability.
A value added to the denominator of layer normalization for numerical stability.
scale_init : Initializer, default = flax.linen.initializers.ones
used for initializing scale factors :math:`\gamma`.
Used for initializing scale factors :math:`\gamma`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
scale_axes : Tuple[str, ...], default = ('embed', )
the name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
only works when :attr:`enable_layernorm=True`.
The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
only used when :attr:`enable_layernorm=True`.
ln_bias_init: Initializer, default = flax.linen.initializers.zeros
used for initializing shift factors :math:`\beta`,
only works when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
Used for initializing shift factors :math:`\beta`,
only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
ln_bias_axes: Tuple[str, ...], default = ('embed', )
The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
only works when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
kernel_init : Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
used for initializing weights.
Used for initializing weights.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
kernel_axes : Tuple[str, ...], default = ()
the name of axes used to shard the weights with a corresponding mesh.
The name of axes used to shard the weights with a corresponding mesh.
use_bias: bool, default = False
indicate whether to enable bias shifting.
if set to False, the layer will not learn an additive bias.
Indicate whether to enable bias shifting.
If set to False, the layer will not learn an additive bias.
bias_init: Initializer, default = flax.linen.initializers.zeros
used for initializing bias, only works when :attr:`use_bias=True`.
Used for initializing bias, only used when :attr:`use_bias=True`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
bias_axes: Tuple[str, ...], default = ()
the name of axes used to shard bias with a corresponding mesh,
only works when :attr:`use_bias=True`.
The name of axes used to shard bias with a corresponding mesh,
only used when :attr:`use_bias=True`.
return_layernorm_output: bool, default = True
indicate whether to return the output of layer normalization.
Indicate whether to return the output of layer normalization.
If set False, return None as the second tensor in outputs.
axis: Union[Iterable[int], int], default = -1
a integer of tuple with axes to apply the transformation on.
An integer tuple with axes to apply the transformation on.
Optimization parameters
-----------------------
dtype : jax.numpy.dtype, default = jax.numpy.float32
the data type used to allocate the initial parameters.
The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
indicate whether the input tensors were switched axis of batch
and sequence length dimension. if set to True, the input tensors
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
depth_scaling: float, default = None
the factor to scale the output from `DenseGeneral`. It should be a float
The factor to scale the output from `DenseGeneral`. It should be a float
value or None. When None is set, then no scaling is applied.
sharding_type : ShardingType, default = ShardingType.SINGLE
indicate the sharding pattern.
Indicate the sharding pattern.
"""
features: Union[Iterable[int], int]
......@@ -519,7 +533,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
Output tensors.
ln_outputs: jax.numpy.ndarray
The output tensors of layer normalization.
If :attr:`return_layernorm_output=False`, then this woulb be None.
If :attr:`return_layernorm_output=False`, then this would be None.
"""
ln_output = None
......@@ -617,67 +631,71 @@ class LayerNormMLP(TransformerEngineBase):
Parameters
----------
intermediate_dim: int, default = 2048
intermediate size to which input samples are projected.
Intermediate size to which input samples are projected.
enable_layernorm: bool, default = True
indicate whether to enable layer normalization before linear transformation.
Indicate whether to enable layer normalization before linear transformation.
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
indicate the type of layer normalization.
Indicate the type of layer normalization.
epsilon : float, default = 1e-6
a value added to the denominator of layer normalization for numerical stability.
A value added to the denominator of layer normalization for numerical stability.
scale_init : Initializer, default = flax.linen.initializers.ones
used for initializing scale factors :math:`\gamma`.
Used for initializing scale factors :math:`\gamma`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
scale_axes : Tuple[str, ...], default = ('embed', )
the name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
only works when :attr:`enable_layernorm=True`.
The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
only used when :attr:`enable_layernorm=True`.
ln_bias_init: Initializer, default = flax.linen.initializers.zeros
used for initializing shift factors :math:`\beta`,
only works when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
Used for initializing shift factors :math:`\beta`,
only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
ln_bias_axes: Tuple[str, ...], default = ('embed', )
The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
only works when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
kernel_init : Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
used for initializing weight of both linear transformations.
Used for initializing the weights of both linear transformations.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
the name of axes used to shard the weights with a corresponding mesh for
The name of axes used to shard the weights with a corresponding mesh for
the weight of the first linear transformations.
kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
the name of axes used to shard the weights with a corresponding mesh for
The name of axes used to shard the weights with a corresponding mesh for
the weight of the second linear transformations.
use_bias: bool, default = False
indicate whether to enable bias shifting.
if set to False, the layer will not learn an additive bias.
Indicate whether to enable bias shifting.
If set to False, the layer will not learn an additive bias.
bias_init: Initializer, default = flax.linen.initializers.zeros
used for initializing bias, only works when :attr:`use_bias=True`.
Used for initializing bias, only used when :attr:`use_bias=True`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
bias_axes_1: Tuple[str, ...], default = ('mlp',)
the name of axes used to shard bias with a corresponding mesh for
The name of axes used to shard bias with a corresponding mesh for
the weight of the first linear transformations.
only works when :attr:`use_bias=True`.
Only used when :attr:`use_bias=True`.
bias_axes_2: Tuple[str, ...], default = ('embed',)
the name of axes used to shard bias with a corresponding mesh for
The name of axes used to shard bias with a corresponding mesh for
the weight of the second linear transformations.
only works when :attr:`use_bias=True`.
Only used when :attr:`use_bias=True`.
return_layernorm_output: bool, default = True
indicate whether to return the output of layer normalization.
Indicate whether to return the output of layer normalization.
If set False, return None as the second tensor in outputs.
activations: Sequence[Union[str, Callable]], default = ('relu',)
the sequence of activation functions to apply after the first linear transformation.
The sequence of activation functions to apply after the first linear transformation.
Each activation has its own transformation layer.
intermediate_dropout_rate: float, default = 0.1
dropout probability for the dropout op after the :attr:`activations`.
Dropout probability for the dropout op after the :attr:`activations`.
axis: Union[Iterable[int], int], default = -1
a integer of tuple with axes to apply the transformation on.
An integer tuple with axes to apply the transformation on.
Optimization parameters
-----------------------
dtype : jax.numpy.dtype, default = jax.numpy.float32
the data type used to allocate the initial parameters.
The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
indicate whether the input tensors were switched axis of batch
and sequence length dimension. if set to True, the input tensors
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
major_sharding_type : MajorShardingType, default = MajorShardingType.SINGLE
indicate the sharding pattern.
Indicate the sharding pattern.
"""
intermediate_dim: int = 2048
......@@ -726,7 +744,7 @@ class LayerNormMLP(TransformerEngineBase):
Output tensors.
ln_outputs: jax.numpy.ndarray
The output tensors of layer normalization.
If :attr:`return_layernorm_output=False`, then this woulb be None.
If :attr:`return_layernorm_output=False`, then this would be None.
"""
ln_output = None
......
......@@ -36,11 +36,11 @@ class ShardingResource:
Parameters
----------
dp_resource : str, default = None
axis name in Mesh used to shard batch along.
if it is None, then disabling data parallelism.
The axis name in Mesh used to shard batches along.
If it is None, then data parallelism is disabled.
tp_resource : str, default = None
axis name in Mesh used to split model tensor along.
if it is None, then disabling tensor parallelism.
The axis name in Mesh used to split the hidden dimensions along.
If it is None, then tensor parallelism is disabled.
"""
dp_resource: str = None
tp_resource: str = None
......@@ -71,12 +71,19 @@ def global_shard_resource() -> ShardingResource:
class MajorShardingType(Enum):
"""
r"""
The major sharding type to indicate sharding pattern.
`SINGLE` means single process training.
`DP` means data parallel traiing.
`TP` means tensor parallel traiing.
`DPTP` means data and tensor parallel traiing.
Values
----------
SINGLE:
Single process training.
DP:
Data parallel training.
TP:
Standard tensor parallel training.
DPTP:
Data and Standard tensor parallel training.
"""
SINGLE = 0
DP = 1
......@@ -87,12 +94,21 @@ class MajorShardingType(Enum):
class ShardingType(Enum):
"""
The sharding type to indicate sharding pattern.
`SINGLE` means no sharding.
`DP` means sharding along data parallelism.
`TP_COL` means sharding along column-split tensor parallelism.
`TP_ROW` means sharding along row-split tensor parallelism.
`DP_TP_COL` means sharding along data and column-split tensor parallelism.
`DP_TP_ROW` means sharding along data and row-split tensor parallelism.
Values
----------
SINGLE:
No sharding.
DP:
Sharding along data parallelism.
TP_COL:
Sharding along column-split tensor parallelism.
TP_ROW:
Sharding along row-split tensor parallelism.
DP_TP_COL:
Sharding along data and column-split tensor parallelism.
DP_TP_ROW:
Sharding along data and row-split tensor parallelism.
"""
SINGLE = (MajorShardingType.SINGLE, "single")
DP = (MajorShardingType.DP, "dp")
......
......@@ -41,7 +41,7 @@ def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[
def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
"""
Extend the given Flax logical axis rules with the pre-defined TransformerLayer's
Extend the given Flax logical axis rules with the predefined TransformerLayer's
logical axis rules.
.. note::
......@@ -185,53 +185,55 @@ class MultiHeadAttention(nn.Module):
Parameters
----------
head_dim : int
the hidden dimension of each attention heads.
The hidden dimension of each attention head.
num_heads : int
the number of attention heads
The number of attention heads
dropout_rate : float, default = 0.0
dropout probability for the dropout op during multi-head attention.
Dropout probability for the dropout op during multi-head attention.
dropout_rng_name: str, default = 'dropout'
the key in given RNGs via flax.linen.Module.apply that
for generate Dropout masks in the core attention.
The key in given RNGs via flax.linen.Module.apply that is used
to generate Dropout masks in the core attention.
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
indicate the type of layer normalization.
Indicate the type of layer normalization.
layernorm_epsilon: float, default = 1e-6
a value added to the denominator of layer normalization for numerical stability.
A value added to the denominator of layer normalization for numerical stability.
kernel_init: Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
used for initializing weights of QKV and Output projection weights.
Used for initializing the QKV and Output projection weights.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
use_bias: bool, default = False
indicate whether to enable bias shifting for QKVO projections.
if set to False, the layer will not learn additive biases.
Indicate whether or not to enable bias shifting for QKVO projections.
If set to False, the layer will not learn additive biases.
bias_init: Initializer, default = flax.linen.initializers.zeros
used for initializing bias of QKVO projections, only works when :attr:`use_bias=True`.
Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
apply_residual_connection_post_layernorm : bool, default = False
indicate if apply residual connection with the output of layer normalization.
Indicate if apply residual connection with the output of layer normalization.
output_layernorm : bool, default = False
indicate if apply a layer normalization in the end of MHA.
Indicate if apply a layer normalization at the end of MHA.
attn_type: AttentionType, default = AttentionType.PADDING
indicate the format of the attentino mask in the core attention.
Indicate the format of the attention mask in the core attention.
Optimization parameters
-----------------------
dtype :jax.numpy.dtype, default = jax.numpy.float32
the data type used to allocate the initial parameters.
The data type used to allocate the initial parameters.
fuse_qkv: bool, default = True
if set to True, this module exposes a single fused
If set to True, this module exposes a single fused
parameter for query-key-value for self-attention and key-value for
cross-attention.
transpose_batch_sequence : bool, default = True
indicate whether the input tensors were switched axis of batch
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. if set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
scale_attn_logits: bool, default = False
indicate whether to scale attention logits.
if set to True, :math:`\frac{Q}{\sqrt{head_dim}*K}`,
Indicate whether to scale attention logits.
If set to True, :math:`\frac{Q}{\sqrt{head_dim}*K}`,
else :math:`Q*K`
scaled_query_init: bool, default = `True`
whether to scale WQ on initilization by :math:`\sqrt{head_dim}`
Whether to scale WQ on initialization by :math:`\sqrt{head_dim}`
float32_logits : bool, default = False
whether to compute attention logits in float32.
Whether to compute attention logits in float32.
"""
head_dim: int
......@@ -267,7 +269,31 @@ class MultiHeadAttention(nn.Module):
*,
decode: bool = False,
deterministic: bool = False) -> Array:
"""Applies multi-head dot product attention on the input data."""
"""
MultiHeadAttention Layer:
[Query, Key, Value projection] -> Dot Product Attention -> Output projection.
Parameters
----------
inputs_q : jax.numpy.ndarray
Input tensor for query projection.
inputs_kv : jax.numpy.ndarray
Input tensor for key/value projection.
mask : jax.numpy.ndarray, default = None
Boolean tensor used to mask out self-attention softmax input.
bias : jax.numpy.ndarray, default = None
A tensor used to shift self-attention softmax input.
*
decode : bool, default = False
Indicate whether to prepare and use an autoregressive cache.
deterministic : bool, default = False
Disable dropout layers if set to True.
Returns
-------
outputs : jax.numpy.ndarray
Output tensors.
"""
depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
......@@ -512,21 +538,21 @@ class RelativePositionBiases(nn.Module):
Parameters
----------
num_buckets : int
the number of buckets to bucket distances between key and query positions into.
The number of buckets to bucket distances between key and query positions into.
max_distance : int
the maximum distance before everything is lumped into the last
The maximum distance before everything is lumped into the last
distance bucket.
num_attention_heads : int
number of attention heads in the transformer layer.
Number of attention heads in the transformer layer.
embedding_init : Initializer, default = flax.linen.linear.default_embed_init
used for initializing relative embedding tables.
Used for initializing relative embedding tables.
embedding_axes : Tuple[str, ...], default = ('heads', 'relpos_buckets')
the name of axes used to shard embedding attention bias with a corresponding mesh.
The name of axes used to shard embedding attention bias with a corresponding mesh.
Optimization parameters
-----------------------
dtype : jax.numpy.dtype, default = jax.numpy.float32
the data type used to allocate the initial parameters.
The data type used to allocate the initial parameters.
"""
num_buckets: int
max_distance: int
......@@ -543,11 +569,11 @@ class RelativePositionBiases(nn.Module):
Parameters
----------
q_seqlen : int
the sequence length of query.
The sequence length of query.
k_seqlen : int
the sequence length of key.
The sequence length of key.
bidirectional : bool, default = True
indicate whether to allow positive memory-query relative position
Indicate whether to allow positive memory-query relative position
embeddings.
Returns
......@@ -598,7 +624,16 @@ class RelativePositionBiases(nn.Module):
class TransformerLayerType(Enum):
"""TransformerLayerType."""
r"""
TransformerLayerType is an Enum class to specify a type of TransformerLayer
Values
----------
ENCODER:
Encoder type of TransformerLayer.
DECODER:
Decoder type of TransformerLayer.
"""
ENCODER = "encoder"
DECODER = "decoder"
......@@ -612,56 +647,59 @@ class TransformerLayer(nn.Module):
Parameters
----------
hidden_size: int, default = 512
the hidden size of each input sample.
The hidden size of each input sample.
mlp_hidden_size: int, default = 2048
intermediate size to which input samples are projected.
Intermediate size to which input samples are projected.
num_attention_heads: int, default = 8
number of attention heads in the transformer layer.
Number of attention heads in the transformer layer.
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
indicate the type of layer normalization.
Indicate the type of layer normalization.
layernorm_epsilon: float, default = 1e-6
a value added to the denominator of layer normalization for numerical stability.
A value added to the denominator of layer normalization for numerical stability.
hidden_dropout: float, default = 0.1
dropout probability for the dropout op after FC2 layer.
Dropout probability for the dropout op after FC2 layer.
hidden_dropout_dims: Sequence[int], default = ()
dimensions that will share the same dropout mask for hidden
Dimensions that will share the same dropout mask for hidden
attention_dropout: float, default = 0.1
dropout probability for the dropout op during multi-head attention.
Dropout probability for the dropout op during multi-head attention.
dropout_rng_name: str, default = 'dropout'
the key in given RNGs via flax.linen.Module.apply that for
generate Dropout masks in the Multi-Head Attention.
The key in given RNGs via flax.linen.Module.apply that for
generating Dropout masks in the Multi-Head Attention.
mha_kernel_init: Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
used for initializing weights of QKV and Output projection weights.
Used for initializing weights of QKV and Output projection weights.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
mlp_kernel_init: Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
used for initializing weights of FC1 and FC2 layers.
Used for initializing weights of FC1 and FC2 layers.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
mlp_activations: Sequence[str], default = ('relu', )
the sequence of activation functions to apply after the first linear transformation.
The sequence of activation functions to apply after the first linear transformation.
Each activation has its own transformation layer.
use_bias: bool, default = False
indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
if set to False, the layer will not learn additive biases.
Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
If set to False, the layer will not learn additive biases.
bias_init: Initializer, default = flax.linen.initializers.zeros
used for initializing bias of QKVO projections,
FC1 and FC2, only works when :attr:`use_bias=True`.
Used for initializing bias of QKVO projections,
FC1 and FC2. It is only used when :attr:`use_bias=True`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
apply_residual_connection_post_layernorm: bool, default = False
if set to True, residual connections are taken from the output
If set to True, residual connections are taken from the output
of layer norm (default is taken from input of layer norm)
output_layernorm: bool, default = False
if set to True, layer normalization is applied on the output side,
If set to True, layer normalization is applied on the output side,
after the final dropout-add. default behavior is to apply layer
normalization on the input side, before the QKV transformation.
float32_attention_logits: bool, default = False
if set to True, attention logits are executed in jax.numpy.float32.
If set to True, attention logits are executed in jax.numpy.float32.
layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER
if set to TransformerLayerType.DECODER, an additional cross-attention block
If set to TransformerLayerType.DECODER, an additional cross-attention block
is added after self-attention.this can be used for structures like `T5`
Transformer in conjunction with the TransformerLayerType.ENCODER option.
enable_relative_embedding: bool, default = True
whether to enable relative embedding as shifting of attention logits.
Whether to enable relative embedding as shifting of attention logits.
relative_embedding: flax.linen.Module, default = None
the module for relative embedding execution, only works when
The module for relative embedding execution, only used when
:attr:`enable_relative_embedding=True`. Default is None, which will create
an instance of RelativePositionBiases if :attr:`enable_relative_embedding=True`.
Default: RelativePositionBiases( num_buckets=32, max_distance=128,
......@@ -672,24 +710,24 @@ class TransformerLayer(nn.Module):
Optimization parameters
-----------------------
dtype :jax.numpy.dtype, default = jax.numpy.float32
the data type used to allocate the initial parameters.
The data type used to allocate the initial parameters.
drop_path: float, default = 0.0
when > 0.0, applies stochastic depth per sample in the main
When > 0.0, applies stochastic depth per sample in the main
path of the residual block.
fuse_qkv_params: bool, default = True
if set to True, `TransformerLayer` module exposes a single fused
If set to True, `TransformerLayer` module exposes a single fused
parameter for query-key-value for self-attention and key-value for
cross-attention.
transpose_batch_sequence : bool, default = True
indicate whether the input tensors were switched axis of batch
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. if set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
scale_attn_logits: bool, default = False
indicate whether to scale attention logits.
Indicate whether to scale attention logits.
if set to True, :math:`\frac{Q}{\sqrt{head_dim}*K}`,
else :math:`Q*K`
scaled_query_init: bool, default = `True`
whether to scale WQ on initilization by :math:`\sqrt{head_dim}`
Whether to scale WQ on initialization by :math:`\sqrt{head_dim}`
"""
hidden_size: int = 512
......@@ -752,7 +790,7 @@ class TransformerLayer(nn.Module):
Boolean tensor used to mask out cross-attention softmax input when
:attr:`layer_type=TransformerLayerType.DECODER`.
deterministic: bool, default = False
Disables dropout layers if set to True.
Disable dropout layers if set to True.
decode: bool,default = False
Indicate whether to prepare and use an autoregressive cache
in Multi-head attention (MHA).
......@@ -764,7 +802,7 @@ class TransformerLayer(nn.Module):
Returns
-------
outputs : jax.numpy.ndarray
Output tensors of this transformer block.
Output tensors.
"""
assert self.layer_type in TransformerLayerType, \
"layer_type should be one of TransformerLayerType" \
......
......@@ -869,7 +869,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv)
.def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history);
py::enum_<transformer_engine::DType>(m, "DType")
py::enum_<transformer_engine::DType>(m, "DType", py::module_local())
.value("kByte", transformer_engine::DType::kByte)
.value("kInt32", transformer_engine::DType::kInt32)
.value("kFloat32", transformer_engine::DType::kFloat32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment