Unverified Commit ed1a3116 authored by Ming-Xu Huang, committed by GitHub

Adding documents to TE/JAX (#87)



* Updated TE/JAX docs
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding TE/JAX docs' rst files
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Set DType as pybind11::module_local() to avoid generic_type errors.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Updating license and exporting more modules
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adopting autoapi and removing enum_tools.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix typo
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Make jax.rst style-consistent.
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Fixing doc statements as suggested by review.
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Fixing doc statements as suggested by code review.
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Update the description of Softmax
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Removed categories in catalog, as in PyTorch
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d74ee5b5
@@ -9,3 +9,4 @@ Framework-specific API
 .. toctree::
   pytorch
+  jax
..
    Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

Jax
=======

.. autoapiclass:: transformer_engine.jax.MajorShardingType

.. autoapiclass:: transformer_engine.jax.ShardingType

.. autoapiclass:: transformer_engine.jax.TransformerLayerType

.. autoapiclass:: transformer_engine.jax.ShardingResource(dp_resource=None, tp_resource=None)

.. autoapiclass:: transformer_engine.jax.LayerNorm(epsilon=1e-6, layernorm_type='layernorm', **kwargs)
  :members: __call__

.. autoapiclass:: transformer_engine.jax.DenseGeneral(features, layernorm_type='layernorm', use_bias=False, **kwargs)
  :members: __call__

.. autoapiclass:: transformer_engine.jax.LayerNormDenseGeneral(features, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)
  :members: __call__

.. autoapiclass:: transformer_engine.jax.LayerNormMLP(intermediate_dim=2048, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)
  :members: __call__

.. autoapiclass:: transformer_engine.jax.RelativePositionBiases(num_buckets, max_distance, num_heads, **kwargs)
  :members: __call__

.. autoapiclass:: transformer_engine.jax.MultiHeadAttention(head_dim, num_heads, **kwargs)
  :members: __call__

.. autoapiclass:: transformer_engine.jax.TransformerLayer(hidden_size=512, mlp_hidden_size=2048, num_attention_heads=8, **kwargs)
  :members: __call__

.. autoapifunction:: transformer_engine.jax.extend_logical_axis_rules

.. autoapifunction:: transformer_engine.jax.fp8_autocast

.. autoapifunction:: transformer_engine.jax.update_collections

.. autoapifunction:: transformer_engine.jax.update_fp8_metas
\ No newline at end of file
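
For orientation, a minimal sketch of how the modules listed above fit together. The constructor arguments follow the signatures documented in this file; everything else (the recipe import path, the exact init/apply keywords, and the FP8 metadata handling, which is elided here) is an assumption for illustration, not part of this PR.

.. code-block:: python

    # Reviewer sketch, not part of this commit.
    import jax
    import jax.numpy as jnp
    import transformer_engine.jax as te
    from transformer_engine.common import recipe  # assumed location of DelayedScaling

    layer = te.TransformerLayer(hidden_size=512,
                                mlp_hidden_size=2048,
                                num_attention_heads=8)

    # Default transpose_batch_sequence=True -> inputs are (seqlen, batch, hidden).
    x = jnp.zeros((128, 4, 512), dtype=jnp.float32)
    rngs = {'params': jax.random.PRNGKey(0), 'dropout': jax.random.PRNGKey(1)}

    with te.fp8_autocast(enabled=True, fp8_recipe=recipe.DelayedScaling()):
        variables = layer.init(rngs, x)
        y = layer.apply(variables, x, rngs={'dropout': jax.random.PRNGKey(2)})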
@@ -3,7 +3,9 @@
 # See LICENSE for license information.
 """Transformer Engine bindings for JAX"""
 from .fp8 import fp8_autocast, update_collections, update_fp8_metas
-from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase
+from .module import DenseGeneral, LayerNorm
+from .module import LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase
 from .transformer import extend_logical_axis_rules
-from .transformer import RelativePositionBiases, TransformerLayer, TransformerLayerType
-from .sharding import ShardingResource
+from .transformer import MultiHeadAttention, RelativePositionBiases
+from .transformer import TransformerLayer, TransformerLayerType
+from .sharding import MajorShardingType, ShardingResource, ShardingType
@@ -53,7 +53,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
     m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor);
     m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor);
 
-    pybind11::enum_<DType>(m, "DType")
+    pybind11::enum_<DType>(m, "DType", pybind11::module_local())
         .value("kByte", DType::kByte)
         .value("kInt32", DType::kInt32)
         .value("kFloat32", DType::kFloat32)
...
@@ -302,16 +302,16 @@ def fp8_autocast(enabled: bool = False,
     .. note::
         We only support :attr:`margin`, :attr:`fp8_format`, :attr:`interval` and
         :attr:`amax_history_len` in recipe.DelayedScaling currently. Other parameters
-        in recipe.DelayedScaling would be ignored, even is set.
+        in recipe.DelayedScaling would be ignored, even if set.
 
     Parameters
     ----------
     enabled: bool, default = False
-        whether or not to enable fp8
+        Whether or not to enable fp8
     fp8_recipe: recipe.DelayedScaling, default = None
-        recipe used for FP8 training.
+        Recipe used for FP8 training.
-    sharding_resource: ShardingResource, defaule = None
-        specify the mesh axes for data and tensor parallelism to shard along.
+    sharding_resource: ShardingResource, default = None
+        Specify the mesh axes for data and tensor parallelism to shard along.
         If set to None, then ShardingResource() would be created.
     """
     if fp8_recipe is None:
@@ -338,12 +338,11 @@ def fp8_autocast(enabled: bool = False,
 
 # Function Wrappers
-def update_collections(new: Collection, original: Collection) -> Collection:
+def update_collections(new: Collection, original: Collection) -> FrozenDict:
     r"""
-    A helper to update Flax's Collection. Collection is a union type of dict and
-    Flax's FrozenDict.
+    A helper to update Flax's Collection.
 
-    Collection = [dict, FrozenDict]
+    Collection = [dict, flax.core.frozen_dict.FrozenDict]
 
     Parameters
     ----------
@@ -364,14 +363,16 @@ def update_fp8_metas(state: Collection) -> Collection:
     r"""
     Calculate new fp8 scales and its inverse via the followed formula
-    `exp` = floor(log2(`fp8_max` / `amax`)) - `margin`
-    `sf` = round(power(2, abs(exp)))
-    `sf` = `sf` if `amax` > 0.0, else original_scale
-    `sf` = `sf` if isfinite(`amax`), else original_scale)
-    `updated_scale` = `1/sf` if exp < 0, else `sf`
-    `updated_scale_inv` = `1/updated_scale`
+
+    .. code-block:: python
+
+        exp = floor(log2(fp8_max / amax)) - margin
+        sf = round(power(2, abs(exp)))
+        sf = sf if amax > 0.0, else original_scale
+        sf = sf if isfinite(amax), else original_scale)
+        updated_scale = 1/sf if exp < 0, else sf
+        updated_scale_inv = 1/updated_scale
 
-    Collection = [dict, FrozenDict]
+    Collection = [dict, flax.core.frozen_dict.FrozenDict]
 
     Parameters
     ----------
...
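
The pseudocode in this docstring maps directly onto array operations. A self-contained sketch of the same delayed-scaling update, written as plain JAX for illustration only (variable names are illustrative, not the TE implementation):

.. code-block:: python

    import jax.numpy as jnp

    def compute_scale(amax, scale, fp8_max, margin):
        """Delayed-scaling update sketched from the docstring above."""
        exp = jnp.floor(jnp.log2(fp8_max / amax)) - margin
        sf = jnp.round(jnp.power(2.0, jnp.abs(exp)))
        # Keep the previous scale when amax is zero or non-finite.
        sf = jnp.where(amax > 0.0, sf, scale)
        sf = jnp.where(jnp.isfinite(amax), sf, scale)
        updated_scale = jnp.where(exp < 0, 1.0 / sf, sf)
        updated_scale_inv = 1.0 / updated_scale
        return updated_scale, updated_scale_inv

    # Example: amax observed during training, FP8 E4M3 max is roughly 448.
    scale, scale_inv = compute_scale(jnp.float32(3.5), jnp.float32(1.0),
                                     fp8_max=448.0, margin=0.0)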
@@ -98,19 +98,25 @@ def _combine_biases(*masks: List[Array]):
 class Softmax(nn.Module):
     r"""
     Applies softmax over a mini-batch of inputs.
-    The inputs's shape should be [batch, heads, q_seqlen, k_seqlen].
+    The input's shape should be [batch, heads, q_seqlen, k_seqlen].
+
+    .. code-block:: python
+
+        shifted_input = input + bias
+        masked_scaled = (1 - mask)*(shifted_input * scale_factor)
+        softmax_mask = mask * -1e-10
+        output = softmax(masked_scaled + softmax_mask)
 
     Parameters
     ----------
     scale_factor : float, default = 1.0
-        scale the inputs along the last dimension before running softmax.
+        Scalar for the input to softmax.
-    softmax_type : SoftmaxType, default = 'layernorm'
-        indicate the type of softmax.
+    softmax_type : SoftmaxType, default = SoftmaxType.SCALED
+        Indicate the type of softmax.
 
     Optimization parameters
     -----------------------
     sharding_type : ShardingType, default = ShardingType.SINGLE
-        indicate the sharding pattern.
+        Indicate the sharding pattern.
     """
     scale_factor: float = 1.0
@@ -158,7 +164,7 @@ class Softmax(nn.Module):
             outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED,
                               self.sharding_type)
         else:
-            outputs = jax_nn.softmax(logits)
+            outputs = jax_nn.softmax(logits * self.scale_factor)
         return outputs
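
A reference version of the computation documented in the Softmax docstring above, written as plain JAX for illustration; the function and argument names are assumptions, not TE code:

.. code-block:: python

    import jax.numpy as jnp
    from jax import nn as jax_nn

    def reference_softmax(x, mask, bias, scale_factor=1.0):
        """Mirror of the documented Softmax behaviour (illustrative only).

        x:    [batch, heads, q_seqlen, k_seqlen] logits
        mask: 1 where attention should be masked out, 0 elsewhere
        bias: additive term (e.g. relative position bias), broadcastable to x
        """
        shifted = x + bias
        masked_scaled = (1 - mask) * (shifted * scale_factor)
        softmax_mask = mask * -1e-10   # shift on masked positions, per the docstring above
        return jax_nn.softmax(masked_scaled + softmax_mask)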
@@ -193,30 +199,32 @@ class LayerNorm(nn.Module):
     Parameters
     ----------
     epsilon : float, default = 1e-6
-        a value added to the denominator of layer normalization for numerical stability.
+        A value added to the denominator of layer normalization for numerical stability.
     layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
-        indicate the type of layer normalization.
+        Indicate the type of layer normalization.
     scale_init : Initializer, default = flax.linen.initializers.ones
-        used for initializing scale factors :math:`\gamma`.
+        Used for initializing scale factors :math:`\gamma`.
+        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
     scale_axes : Tuple[str, ...], default = ('embed', )
-        the name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
+        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
     bias_init : Initializer, default = flax.linen.initializers.zeros
-        used for initializing shift factors :math:`\beta`,
-        only works when :attr:`layernorm_type='layernorm'`.
+        Used for initializing shift factors :math:`\beta`,
+        only used when :attr:`layernorm_type='layernorm'`.
+        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
     bias_axes : Tuple[str, ...], default = ('embed', )
         The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
-        only works when :attr:`layernorm_type='layernorm'`.
+        only used when :attr:`layernorm_type='layernorm'`.
 
     Optimization parameters
     -----------------------
     dtype : jax.numpy.dtype, default = jax.numpy.float32
         the data type used to allocate the initial parameters.
     transpose_batch_sequence : bool, default = True
-        indicate whether the input tensors were switched axis of batch
-        and sequence length dimension. if set to True, the input tensors
+        Indicate whether the input tensors were switched axis of batch
+        and sequence length dimension. If set to True, the input tensors
         should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
     sharding_type : ShardingType, default = ShardingType.SINGLE
-        indicate the sharding pattern.
+        Indicate the sharding pattern.
     """
     epsilon: float = 1e-6
     layernorm_type: str = 'layernorm'
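
A minimal usage sketch for the LayerNorm module above, assuming it follows the standard flax.linen init/apply pattern with the documented keyword arguments:

.. code-block:: python

    # Sketch only; keyword names are taken from the docstring above.
    import jax
    import jax.numpy as jnp
    import transformer_engine.jax as te

    ln = te.LayerNorm(epsilon=1e-6, layernorm_type='rmsnorm',
                      transpose_batch_sequence=False)   # inputs as (batch, seqlen, hidden)

    x = jnp.ones((8, 128, 512), dtype=jnp.float32)
    params = ln.init(jax.random.PRNGKey(0), x)
    y = ln.apply(params, x)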
@@ -316,33 +324,35 @@ class DenseGeneral(TransformerEngineBase):
     Parameters
     ----------
     features : Union[Iterable[int], int]
-        the hidden size of each output sample.
+        The hidden size of each output sample.
     kernel_init : Initializer, default =
         flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
-        used for initializing weights.
+        Used for initializing weights.
+        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
     kernel_axes : Tuple[str, ...], default = ()
-        the name of axes used to shard the weights with a corresponding mesh.
+        The name of axes used to shard the weights with a corresponding mesh.
     use_bias: bool, default = False
-        indicate whether to enable bias shifting.
-        if set to False, the layer will not learn an additive bias.
+        Indicate whether to enable bias shifting.
+        If set to False, the layer will not learn an additive bias.
     bias_init: Initializer, default = flax.linen.initializers.zeros
-        used for initializing bias, only works when :attr:`use_bias=True`.
+        Used for initializing bias, only used when :attr:`use_bias=True`.
+        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
     bias_axes: Tuple[str, ...], default = ()
-        the name of axes used to shard bias with a corresponding mesh,
-        only works when :attr:`use_bias=True`.
+        The name of axes used to shard bias with a corresponding mesh,
+        only used when :attr:`use_bias=True`.
     axis: Union[Iterable[int], int], default = -1
-        a integer of tuple with axes to apply the transformation on.
+        An integer tuple with axes to apply the transformation on.
 
     Optimization parameters
     -----------------------
     dtype : jax.numpy.dtype, default = jax.numpy.float32
-        the data type used to allocate the initial parameters.
+        The data type used to allocate the initial parameters.
     transpose_batch_sequence : bool, default = True
-        indicate whether the input tensors were switched axis of batch
-        and sequence length dimension. if set to True, the input tensors
+        Indicate whether the input tensors were switched axis of batch
+        and sequence length dimension. If set to True, the input tensors
         should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
     sharding_type : ShardingType, default = ShardingType.SINGLE
-        indicate the sharding pattern.
+        Indicate the sharding pattern.
     """
     features: Union[Iterable[int], int]
@@ -426,56 +436,60 @@ class LayerNormDenseGeneral(TransformerEngineBase):
     Parameters
     ----------
     features : Union[Iterable[int], int]
-        the hidden size of each output sample.
+        The hidden size of each output sample.
     enable_layernorm: bool, default = True
-        indicate whether to enable layer normalization before linear transformation.
+        Indicate whether to enable layer normalization before linear transformation.
     layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
-        indicate the type of layer normalization.
+        Indicate the type of layer normalization.
     epsilon : float, default = 1e-6
-        a value added to the denominator of layer normalization for numerical stability.
+        A value added to the denominator of layer normalization for numerical stability.
     scale_init : Initializer, default = flax.linen.initializers.ones
-        used for initializing scale factors :math:`\gamma`.
+        Used for initializing scale factors :math:`\gamma`.
+        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
     scale_axes : Tuple[str, ...], default = ('embed', )
-        the name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
-        only works when :attr:`enable_layernorm=True`.
+        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
+        only used when :attr:`enable_layernorm=True`.
     ln_bias_init: Initializer, default = flax.linen.initializers.zeros
-        used for initializing shift factors :math:`\beta`,
-        only works when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
+        Used for initializing shift factors :math:`\beta`,
+        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
+        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
     ln_bias_axes: Tuple[str, ...], default = ('embed', )
         The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
-        only works when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
+        It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
     kernel_init : Initializer, default =
         flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
-        used for initializing weights.
+        Used for initializing weights.
+        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
     kernel_axes : Tuple[str, ...], default = ()
-        the name of axes used to shard the weights with a corresponding mesh.
+        The name of axes used to shard the weights with a corresponding mesh.
     use_bias: bool, default = False
-        indicate whether to enable bias shifting.
-        if set to False, the layer will not learn an additive bias.
+        Indicate whether to enable bias shifting.
+        If set to False, the layer will not learn an additive bias.
     bias_init: Initializer, default = flax.linen.initializers.zeros
-        used for initializing bias, only works when :attr:`use_bias=True`.
+        Used for initializing bias, only used when :attr:`use_bias=True`.
+        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
     bias_axes: Tuple[str, ...], default = ()
-        the name of axes used to shard bias with a corresponding mesh,
-        only works when :attr:`use_bias=True`.
+        The name of axes used to shard bias with a corresponding mesh,
+        only used when :attr:`use_bias=True`.
     return_layernorm_output: bool, default = True
-        indicate whether to return the output of layer normalization.
+        Indicate whether to return the output of layer normalization.
         If set False, return None as the second tensor in outputs.
     axis: Union[Iterable[int], int], default = -1
-        a integer of tuple with axes to apply the transformation on.
+        An integer tuple with axes to apply the transformation on.
 
     Optimization parameters
     -----------------------
     dtype : jax.numpy.dtype, default = jax.numpy.float32
-        the data type used to allocate the initial parameters.
+        The data type used to allocate the initial parameters.
     transpose_batch_sequence : bool, default = True
-        indicate whether the input tensors were switched axis of batch
-        and sequence length dimension. if set to True, the input tensors
+        Indicate whether the input tensors were switched axis of batch
+        and sequence length dimension. If set to True, the input tensors
         should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
     depth_scaling: float, default = None
-        the factor to scale the output from `DenseGeneral`. It should be a float
+        The factor to scale the output from `DenseGeneral`. It should be a float
         value or None. When None is set, then no scaling is applied.
     sharding_type : ShardingType, default = ShardingType.SINGLE
-        indicate the sharding pattern.
+        Indicate the sharding pattern.
     """
     features: Union[Iterable[int], int]
@@ -519,7 +533,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
             Output tensors.
         ln_outputs: jax.numpy.ndarray
             The output tensors of layer normalization.
-            If :attr:`return_layernorm_output=False`, then this woulb be None.
+            If :attr:`return_layernorm_output=False`, then this would be None.
         """
         ln_output = None
@@ -617,67 +631,71 @@ class LayerNormMLP(TransformerEngineBase):
     Parameters
     ----------
     intermediate_dim: int, default = 2048
-        intermediate size to which input samples are projected.
+        Intermediate size to which input samples are projected.
     enable_layernorm: bool, default = True
-        indicate whether to enable layer normalization before linear transformation.
+        Indicate whether to enable layer normalization before linear transformation.
     layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
-        indicate the type of layer normalization.
+        Indicate the type of layer normalization.
     epsilon : float, default = 1e-6
-        a value added to the denominator of layer normalization for numerical stability.
+        A value added to the denominator of layer normalization for numerical stability.
     scale_init : Initializer, default = flax.linen.initializers.ones
-        used for initializing scale factors :math:`\gamma`.
+        Used for initializing scale factors :math:`\gamma`.
+        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
     scale_axes : Tuple[str, ...], default = ('embed', )
-        the name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
-        only works when :attr:`enable_layernorm=True`.
+        The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
+        only used when :attr:`enable_layernorm=True`.
     ln_bias_init: Initializer, default = flax.linen.initializers.zeros
-        used for initializing shift factors :math:`\beta`,
-        only works when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
+        Used for initializing shift factors :math:`\beta`,
+        only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
+        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
     ln_bias_axes: Tuple[str, ...], default = ('embed', )
         The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
-        only works when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
+        Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
     kernel_init : Initializer, default =
         flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
-        used for initializing weight of both linear transformations.
+        Used for initializing the weights of both linear transformations.
+        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
     kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
-        the name of axes used to shard the weights with a corresponding mesh for
+        The name of axes used to shard the weights with a corresponding mesh for
         the weight of the first linear transformations.
     kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
-        the name of axes used to shard the weights with a corresponding mesh for
+        The name of axes used to shard the weights with a corresponding mesh for
         the weight of the second linear transformations.
     use_bias: bool, default = False
-        indicate whether to enable bias shifting.
-        if set to False, the layer will not learn an additive bias.
+        Indicate whether to enable bias shifting.
+        If set to False, the layer will not learn an additive bias.
     bias_init: Initializer, default = flax.linen.initializers.zeros
-        used for initializing bias, only works when :attr:`use_bias=True`.
+        Used for initializing bias, only used when :attr:`use_bias=True`.
+        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
     bias_axes_1: Tuple[str, ...], default = ('mlp',)
-        the name of axes used to shard bias with a corresponding mesh for
+        The name of axes used to shard bias with a corresponding mesh for
         the weight of the first linear transformations.
-        only works when :attr:`use_bias=True`.
+        Only used when :attr:`use_bias=True`.
     bias_axes_2: Tuple[str, ...], default = ('embed',)
-        the name of axes used to shard bias with a corresponding mesh for
+        The name of axes used to shard bias with a corresponding mesh for
         the weight of the second linear transformations.
-        only works when :attr:`use_bias=True`.
+        Only used when :attr:`use_bias=True`.
     return_layernorm_output: bool, default = True
-        indicate whether to return the output of layer normalization.
+        Indicate whether to return the output of layer normalization.
         If set False, return None as the second tensor in outputs.
     activations: Sequence[Union[str, Callable]], default = ('relu',)
-        the sequence of activation functions to apply after the first linear transformation.
+        The sequence of activation functions to apply after the first linear transformation.
         Each activation has its own transformation layer.
     intermediate_dropout_rate: float, default = 0.1
-        dropout probability for the dropout op after the :attr:`activations`.
+        Dropout probability for the dropout op after the :attr:`activations`.
     axis: Union[Iterable[int], int], default = -1
-        a integer of tuple with axes to apply the transformation on.
+        An integer tuple with axes to apply the transformation on.
 
     Optimization parameters
     -----------------------
     dtype : jax.numpy.dtype, default = jax.numpy.float32
-        the data type used to allocate the initial parameters.
+        The data type used to allocate the initial parameters.
     transpose_batch_sequence : bool, default = True
-        indicate whether the input tensors were switched axis of batch
-        and sequence length dimension. if set to True, the input tensors
+        Indicate whether the input tensors were switched axis of batch
+        and sequence length dimension. If set to True, the input tensors
         should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
     major_sharding_type : MajorShardingType, default = MajorShardingType.SINGLE
-        indicate the sharding pattern.
+        Indicate the sharding pattern.
     """
     intermediate_dim: int = 2048
@@ -726,7 +744,7 @@ class LayerNormMLP(TransformerEngineBase):
             Output tensors.
         ln_outputs: jax.numpy.ndarray
             The output tensors of layer normalization.
-            If :attr:`return_layernorm_output=False`, then this woulb be None.
+            If :attr:`return_layernorm_output=False`, then this would be None.
         """
         ln_output = None
...
@@ -36,11 +36,11 @@ class ShardingResource:
     Parameters
     ----------
     dp_resource : str, default = None
-        axis name in Mesh used to shard batch along.
-        if it is None, then disabling data parallelism.
+        The axis name in Mesh used to shard batches along.
+        If it is None, then data parallelism is disabled.
     tp_resource : str, default = None
-        axis name in Mesh used to split model tensor along.
-        if it is None, then disabling tensor parallelism.
+        The axis name in Mesh used to split the hidden dimensions along.
+        If it is None, then tensor parallelism is disabled.
     """
     dp_resource: str = None
     tp_resource: str = None
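
A sketch of how ShardingResource is expected to be wired to a JAX device mesh; the axis names, device layout, and the surrounding mesh context are assumptions for illustration, not part of this commit:

.. code-block:: python

    import numpy as np
    import jax
    from jax.sharding import Mesh   # older JAX exposes Mesh under jax.experimental.maps
    import transformer_engine.jax as te

    devices = np.asarray(jax.devices()).reshape(2, 4)   # assumes 8 devices: 2-way DP x 4-way TP
    with Mesh(devices, ('data', 'model')):
        resource = te.ShardingResource(dp_resource='data', tp_resource='model')
        with te.fp8_autocast(enabled=True, sharding_resource=resource):
            pass  # build and run TE/JAX modules here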
@@ -71,12 +71,19 @@ def global_shard_resource() -> ShardingResource:
 
 class MajorShardingType(Enum):
-    """
+    r"""
     The major sharding type to indicate sharding pattern.
-    `SINGLE` means single process training.
-    `DP` means data parallel traiing.
-    `TP` means tensor parallel traiing.
-    `DPTP` means data and tensor parallel traiing.
+
+    Values
+    ----------
+    SINGLE:
+        Single process training.
+    DP:
+        Data parallel training.
+    TP:
+        Standard tensor parallel training.
+    DPTP:
+        Data and Standard tensor parallel training.
     """
     SINGLE = 0
     DP = 1
@@ -87,12 +94,21 @@ class MajorShardingType(Enum):
 class ShardingType(Enum):
     """
     The sharding type to indicate sharding pattern.
-    `SINGLE` means no sharding.
-    `DP` means sharding along data parallelism.
-    `TP_COL` means sharding along column-split tensor parallelism.
-    `TP_ROW` means sharding along row-split tensor parallelism.
-    `DP_TP_COL` means sharding along data and column-split tensor parallelism.
-    `DP_TP_ROW` means sharding along data and row-split tensor parallelism.
+
+    Values
+    ----------
+    SINGLE:
+        No sharding.
+    DP:
+        Sharding along data parallelism.
+    TP_COL:
+        Sharding along column-split tensor parallelism.
+    TP_ROW:
+        Sharding along row-split tensor parallelism.
+    DP_TP_COL:
+        Sharding along data and column-split tensor parallelism.
+    DP_TP_ROW:
+        Sharding along data and row-split tensor parallelism.
     """
     SINGLE = (MajorShardingType.SINGLE, "single")
     DP = (MajorShardingType.DP, "dp")
...
@@ -41,7 +41,7 @@ def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[
 def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
     """
-    Extend the given Flax logical axis rules with the pre-defined TransformerLayer's
+    Extend the given Flax logical axis rules with the predefined TransformerLayer's
     logical axis rules.
 
     .. note::
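
A sketch of how extend_logical_axis_rules would typically be combined with Flax's logical axis rules; the base rules and the flax.linen.logical_axis_rules context are illustrative assumptions, not part of this commit:

.. code-block:: python

    import flax.linen as nn            # logical_axis_rules lives in flax.linen.spmd on older Flax
    import transformer_engine.jax as te

    base_rules = (('batch', 'data'), ('embed', None), ('mlp', 'model'))
    rules = te.extend_logical_axis_rules(base_rules)

    with nn.logical_axis_rules(rules):
        pass  # init/apply TE modules here so their sharding constraints resolve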
@@ -185,53 +185,55 @@ class MultiHeadAttention(nn.Module):
     Parameters
     ----------
     head_dim : int
-        the hidden dimension of each attention heads.
+        The hidden dimension of each attention head.
     num_heads : int
-        the number of attention heads
+        The number of attention heads
     dropout_rate : float, default = 0.0
-        dropout probability for the dropout op during multi-head attention.
+        Dropout probability for the dropout op during multi-head attention.
     dropout_rng_name: str, default = 'dropout'
-        the key in given RNGs via flax.linen.Module.apply that
-        for generate Dropout masks in the core attention.
+        The key in given RNGs via flax.linen.Module.apply that is used
+        to generate Dropout masks in the core attention.
     layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
-        indicate the type of layer normalization.
+        Indicate the type of layer normalization.
     layernorm_epsilon: float, default = 1e-6
-        a value added to the denominator of layer normalization for numerical stability.
+        A value added to the denominator of layer normalization for numerical stability.
     kernel_init: Initializer, default =
         flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
-        used for initializing weights of QKV and Output projection weights.
+        Used for initializing the QKV and Output projection weights.
+        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
     use_bias: bool, default = False
-        indicate whether to enable bias shifting for QKVO projections.
-        if set to False, the layer will not learn additive biases.
+        Indicate whether or not to enable bias shifting for QKVO projections.
+        If set to False, the layer will not learn additive biases.
     bias_init: Initializer, default = flax.linen.initializers.zeros
-        used for initializing bias of QKVO projections, only works when :attr:`use_bias=True`.
+        Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`.
+        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
     apply_residual_connection_post_layernorm : bool, default = False
-        indicate if apply residual connection with the output of layer normalization.
+        Indicate if apply residual connection with the output of layer normalization.
     output_layernorm : bool, default = False
-        indicate if apply a layer normalization in the end of MHA.
+        Indicate if apply a layer normalization at the end of MHA.
     attn_type: AttentionType, defult = AttentionType.PADDING
-        indicate the format of the attentino mask in the core attention.
+        Indicate the format of the attention mask in the core attention.
 
     Optimization parameters
     -----------------------
     dtype :jax.numpy.dtype, default = jax.numpy.float32
-        the data type used to allocate the initial parameters.
+        The data type used to allocate the initial parameters.
     fuse_qkv: bool, default = True
-        if set to True, this module exposes a single fused
+        If set to True, this module exposes a single fused
         parameter for query-key-value for self-attention and key-value for
         cross-attention.
     transpose_batch_sequence : bool, default = True
-        indicate whether the input tensors were switched axis of batch
+        Indicate whether the input tensors were switched axis of batch
         and sequence length dimension. if set to True, the input tensors
         should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
     scale_attn_logits: bool, default = False
-        indicate whether to scale attention logits.
-        if set to True, :math:`\frac{Q}{\sqrt{head_dim}*K}`,
+        Indicate whether to scale attention logits.
+        If set to True, :math:`\frac{Q}{\sqrt{head_dim}*K}`,
         else :math:`Q*K`
     scaled_query_init: bool, default = `True`
-        whether to scale WQ on initilization by :math:`\sqrt{head_dim}`
+        Whether to scale WQ on initialization by :math:`\sqrt{head_dim}`
     float32_logits : bool, default = False
-        whether to compute attention logits in float32.
+        Whether to compute attention logits in float32.
     """
     head_dim: int
@@ -267,7 +269,31 @@ class MultiHeadAttention(nn.Module):
                  *,
                  decode: bool = False,
                  deterministic: bool = False) -> Array:
-        """Applies multi-head dot product attention on the input data."""
+        """
+        MultiHeadAttention Layer:
+        [Query, Key, Value projection] -> Dot Product Attention -> Output projection.
+
+        Parameters
+        ----------
+        inputs_q : jax.numpy.ndarray
+            Input tensor for query projection.
+        inputs_kv : jax.numpy.ndarray
+            Input tensor for key/value projection.
+        mask : jax.numpy.ndarray, default = None
+            Boolean tensor used to mask out self-attention softmax input.
+        bias : jax.numpy.ndarray, default = None
+            A tensor used to shift self-attention softmax input.
+        *
+        decode : bool, default = False
+            Indicate whether to prepare and use an autoregressive cache.
+        deterministic : bool, default = False
+            Disable dropout layers if set to True.
+
+        Returns
+        -------
+        outputs : jax.numpy.ndarray
+            Output tensors.
+        """
         depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
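
A minimal usage sketch for the MultiHeadAttention call documented above; the shapes assume the default transpose_batch_sequence=True layout, the construction keywords come from the class docstring, and the init/apply pattern is the standard flax.linen one (an assumption, not part of this commit):

.. code-block:: python

    import jax
    import jax.numpy as jnp
    import transformer_engine.jax as te

    mha = te.MultiHeadAttention(head_dim=64, num_heads=8, dropout_rate=0.0)

    seqlen, batch, hidden = 128, 4, 512
    x = jnp.zeros((seqlen, batch, hidden), dtype=jnp.float32)

    # Self-attention: the same tensor feeds the query and key/value projections;
    # mask and bias are optional and left as None here.
    variables = mha.init({'params': jax.random.PRNGKey(0)},
                         x, x, None, None, deterministic=True)
    out = mha.apply(variables, x, x, None, None, deterministic=True)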
@@ -512,21 +538,21 @@ class RelativePositionBiases(nn.Module):
     Parameters
     ----------
     num_buckets : int
-        the number of buckets to bucket distances between key and query positions into.
+        The number of buckets to bucket distances between key and query positions into.
     max_distance : int
-        the maximum distance before everything is lumped into the last
+        The maximum distance before everything is lumped into the last
         distance bucket.
     num_attention_heads : int
-        number of attention heads in the transformer layer.
+        Number of attention heads in the transformer layer.
     embedding_init : Initializer, default = flax.linen.linear.default_embed_init
-        used for initializing relative embedding tables.
+        Used for initializing relative embedding tables.
     embedding_axes : Tuple[str, ...], default = ('heads', 'relpos_buckets')
-        the name of axes used to shard embedding attention bias with a corresponding mesh.
+        The name of axes used to shard embedding attention bias with a corresponding mesh.
 
     Optimization parameters
     -----------------------
     dtype : jax.numpy.dtype, default = jax.numpy.float32
-        the data type used to allocate the initial parameters.
+        The data type used to allocate the initial parameters.
     """
     num_buckets: int
     max_distance: int
@@ -543,11 +569,11 @@ class RelativePositionBiases(nn.Module):
         Parameters
         ----------
         q_seqlen : int
-            the sequence length of query.
+            The sequence length of query.
         k_seqlen : int
-            the sequence length of key.
+            The sequence length of key.
         bidirectional : bool, default = True
-            indicate whether to allow positive memory-query relative position
+            Indicate whether to allow positive memory-query relative position
             embeddings.
 
         Returns
@@ -598,7 +624,16 @@ class RelativePositionBiases(nn.Module):
 
 class TransformerLayerType(Enum):
-    """TransformerLayerType."""
+    r"""
+    TransformerLayerType is an Enum class to specify a type of TransformerLayer
+
+    Values
+    ----------
+    ENCODER:
+        Encoder type of TransformerLayer.
+    DECODER:
+        Decoder type of TransformerLayer.
+    """
     ENCODER = "encoder"
     DECODER = "decoder"
@@ -612,56 +647,59 @@ class TransformerLayer(nn.Module):
     Parameters
     ----------
     hidden_size: int, default = 512
-        the hidden size of each input sample.
+        The hidden size of each input sample.
     mlp_hidden_size: int, default = 2048
-        intermediate size to which input samples are projected.
+        Intermediate size to which input samples are projected.
     num_attention_heads: int, default = 8
-        number of attention heads in the transformer layer.
+        Number of attention heads in the transformer layer.
     layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
-        indicate the type of layer normalization.
+        Indicate the type of layer normalization.
     layernorm_epsilon: float, default = 1e-6
-        a value added to the denominator of layer normalization for numerical stability.
+        A value added to the denominator of layer normalization for numerical stability.
     hidden_dropout: float, default = 0.1
-        dropout probability for the dropout op after FC2 layer.
+        Dropout probability for the dropout op after FC2 layer.
     hidden_dropout_dims: Sequence[int], default = ()
-        dimensions that will share the same dropout mask for hidden
+        Dimensions that will share the same dropout mask for hidden
     attention_dropout: float, default = 0.1
-        dropout probability for the dropout op during multi-head attention.
+        Dropout probability for the dropout op during multi-head attention.
     dropout_rng_name: str, default = 'dropout'
-        the key in given RNGs via flax.linen.Module.apply that for
-        generate Dropout masks in the Multi-Head Attention.
+        The key in given RNGs via flax.linen.Module.apply that for
+        generating Dropout masks in the Multi-Head Attention.
     mha_kernel_init: Initializer, default =
         flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
-        used for initializing weights of QKV and Output projection weights.
+        Used for initializing weights of QKV and Output projection weights.
+        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
     mlp_kernel_init: Initializer, default =
         flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
-        used for initializing weights of FC1 and FC2 layers.
+        Used for initializing weights of FC1 and FC2 layers.
+        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
     mlp_activations: Sequence[str], default = ('relu', )
-        the sequence of activation functions to apply after the first linear transformation.
+        The sequence of activation functions to apply after the first linear transformation.
         Each activation has its own transformation layer.
     use_bias: bool, default = False
-        indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
-        if set to False, the layer will not learn additive biases.
+        Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
+        If set to False, the layer will not learn additive biases.
     bias_init: Initializer, default = flax.linen.initializers.zeros
-        used for initializing bias of QKVO projections,
-        FC1 and FC2, only works when :attr:`use_bias=True`.
+        Used for initializing bias of QKVO projections,
+        FC1 and FC2. It is only used when :attr:`use_bias=True`.
+        It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
     apply_residual_connection_post_layernorm: bool, default = False
-        if set to True, residual connections are taken from the output
+        If set to True, residual connections are taken from the output
         of layer norm (default is taken from input of layer norm)
     output_layernorm: bool, default = False
-        if set to True, layer normalization is applied on the output side,
+        If set to True, layer normalization is applied on the output side,
         after the final dropout-add. default behavior is to apply layer
         normalization on the input side, before the QKV transformation.
     float32_attention_logits: bool, default = False
-        if set to True, attention logits are executed in jax.numpy.float32.
+        If set to True, attention logits are executed in jax.numpy.float32.
     layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER
-        if set to TransformerLayerType.DECODER, an additional cross-attention block
+        If set to TransformerLayerType.DECODER, an additional cross-attention block
         is added after self-attention.this can be used for structures like `T5`
         Transformer in conjunction with the TransformerLayerType.ENCODER option.
     enable_relative_embedding: bool, default = True
-        whether to enable relative embedding as shifting of attention logits.
+        Whether to enable relative embedding as shifting of attention logits.
     relative_embedding: flax.linen.Module, default = None
-        the module for relative embedding execution, only works when
+        The module for relative embedding execution, only used when
         :attr:`enable_relative_embedding=True`. Default is None, which will create
         an instance of RelativePositionBiases if :attr:`enable_relative_embedding=True`.
         Default: RelativePositionBiases( num_buckets=32, max_distance=128,
@@ -672,24 +710,24 @@ class TransformerLayer(nn.Module):
     Optimization parameters
     -----------------------
     dtype :jax.numpy.dtype, default = jax.numpy.float32
-        the data type used to allocate the initial parameters.
+        The data type used to allocate the initial parameters.
     drop_path: float, default = 0.0
-        when > 0.0, applies stochastic depth per sample in the main
+        When > 0.0, applies stochastic depth per sample in the main
         path of the residual block.
     fuse_qkv_params: bool, default = True
-        if set to True, `TransformerLayer` module exposes a single fused
+        If set to True, `TransformerLayer` module exposes a single fused
        parameter for query-key-value for self-attention and key-value for
         cross-attention.
     transpose_batch_sequence : bool, default = True
-        indicate whether the input tensors were switched axis of batch
+        Indicate whether the input tensors were switched axis of batch
         and sequence length dimension. if set to True, the input tensors
         should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
     scale_attn_logits: bool, default = False
-        indicate whether to scale attention logits.
+        Indicate whether to scale attention logits.
         if set to True, :math:`\frac{Q}{\sqrt{head_dim}*K}`,
         else :math:`Q*K`
     scaled_query_init: bool, default = `True`
-        whether to scale WQ on initilization by :math:`\sqrt{head_dim}`
+        Whether to scale WQ on initialization by :math:`\sqrt{head_dim}`
     """
     hidden_size: int = 512
@@ -752,7 +790,7 @@ class TransformerLayer(nn.Module):
             Boolean tensor used to mask out cross-attention softmax input when
             :attr:`layer_type=TransformerLayerType.DECODER`.
         deterministic: bool, default = False
-            Disables dropout layers if set to True.
+            Disable dropout layers if set to True.
         decode: bool, default = False
             Indicate whether to prepare and use an autoregressive cache
             in Multi-head attention (MHA).
@@ -764,7 +802,7 @@ class TransformerLayer(nn.Module):
         Returns
         -------
         outputs : jax.numpy.ndarray
-            Output tensors of this transformer block.
+            Output tensors.
         """
         assert self.layer_type in TransformerLayerType, \
             "layer_type should be one of TransformerLayerType" \
...
@@ -869,7 +869,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
       .def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv)
       .def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history);
 
-  py::enum_<transformer_engine::DType>(m, "DType")
+  py::enum_<transformer_engine::DType>(m, "DType", py::module_local())
       .value("kByte", transformer_engine::DType::kByte)
       .value("kInt32", transformer_engine::DType::kInt32)
      .value("kFloat32", transformer_engine::DType::kFloat32)
...