Unverified Commit 05366e5f authored by Ming-Xu Huang, committed by GitHub

Adding JAX to README.rst (#98)



* Adding JAX to README.rst
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Refine README.rst per suggestions from review.
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

* Refine the API doc of extend_logical_axis_rules.
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>

---------
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent cfa666ac
@@ -14,11 +14,11 @@ Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA

using 8-bit floating point (FP8) precision on Hopper GPUs, to provide better performance with lower
memory utilization in both training and inference. TE provides a collection of highly optimized
building blocks for popular Transformer architectures and an automatic mixed precision-like API that
can be used seamlessly with your own framework-specific code. TE also includes a framework agnostic
C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.

As the number of parameters in Transformer models continues to grow, training and inference for
architectures such as BERT, GPT and T5 become very memory and compute intensive. Most deep learning
frameworks train with FP32 by default. This is not essential, however, to achieve full accuracy for
many deep learning models. Using mixed-precision training, which combines single-precision (FP32)
with lower precision (e.g. FP16) format when training a model, results in significant speedups with
@@ -28,13 +28,17 @@ degradation in accuracy. Although all major deep learning frameworks support FP16

not available today.

TE addresses the problem of FP8 support by providing APIs that integrate with popular Large Language
Model (LLM) libraries. It provides a Python layer consisting of modules to easily build Transformer
layers, as well as a framework-agnostic library in C++ including the structs and kernels needed for FP8 support.
Modules provided by TE internally maintain the scaling factors and other values needed for FP8 training,
greatly simplifying FP8 usage for users.

Examples
--------

PyTorch
^^^^^^^

.. code-block:: python
@@ -51,23 +55,67 @@ Transformer Engine in action:

  model = te.Linear(in_features, out_features, bias=True)
  inp = torch.randn(hidden_size, in_features, device="cuda")

  # Create an FP8 recipe. Note: All input args are optional.
  fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)

  # Enable autocasting for the forward pass
  with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
      out = model(inp)

  loss = out.sum()
  loss.backward()
JAX
^^^

.. code-block:: python

  import jax
  import jax.numpy as jnp
  import transformer_engine.jax as te
  from transformer_engine.common import recipe

  BATCH = 32
  SEQLEN = 128
  HIDDEN = 1024

  # Initialize RNG and inputs.
  rng = jax.random.PRNGKey(0)
  init_rng, data_rng = jax.random.split(rng)
  inp = jax.random.normal(data_rng, [BATCH, SEQLEN, HIDDEN], jnp.float32)

  # Create an FP8 recipe. Note: All input args are optional.
  fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.HYBRID)

  # Enable autocasting for the forward pass
  with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
      model = te.DenseGeneral(features=HIDDEN)

      def loss_fn(params, other_vars, inp):
          out = model.apply({'params': params, **other_vars}, inp)
          return jnp.mean(out)

      # Initialize models.
      variables = model.init(init_rng, inp)
      other_variables, params = variables.pop('params')

      # Construct the forward and backward function
      fwd_bwd_fn = jax.value_and_grad(loss_fn, argnums=(0, 1))

      for _ in range(10):
          loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp)
          # Update FP8 metas
          other_variables = te.update_fp8_metas(other_grads)
Highlights
----------

* Easy-to-use modules for building Transformer layers with FP8 support on H100 GPUs.
* Optimizations (e.g. fused kernels) for Transformer models across all precisions and NVIDIA GPU
  architectures.

.. overview-end-marker-do-not-remove
@@ -87,7 +135,9 @@ Clone the repository and inside it type:

.. code-block:: bash

  NVTE_FRAMEWORK=all pip install .      # Building with all frameworks.
  NVTE_FRAMEWORK=pytorch pip install .  # Building with PyTorch only.
  NVTE_FRAMEWORK=jax pip install .      # Building with JAX only.
User Guide
----------
@@ -102,6 +152,9 @@ While the more granular modules in Transformer Engine allow building any Transformer

the `TransformerLayer` API of Transformer Engine is flexible enough to build multiple major
variations of Transformers.

NOTE: For simplicity, we only show PyTorch examples below. For the usage of `TransformerLayer`
in all supported frameworks, refer to `examples <https://github.com/NVIDIA/TransformerEngine/tree/main/examples>`_.
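As a rough illustration of that flexibility, here is a minimal, hypothetical PyTorch sketch of a single `TransformerLayer` (not the repository's own GPT example); the sizes, input layout and constructor defaults are assumptions to be checked against the API docs.

.. code-block:: python

  import torch
  import transformer_engine.pytorch as te
  from transformer_engine.common import recipe

  # Illustrative sizes only (assumed, not taken from the README).
  hidden_size, ffn_hidden_size, num_heads = 1024, 4096, 16

  # One Transformer block; further constructor arguments control layer type,
  # attention masking, dropout, etc.
  layer = te.TransformerLayer(hidden_size, ffn_hidden_size, num_heads).cuda()

  # Inputs are assumed to be [sequence, batch, hidden].
  x = torch.randn(128, 8, hidden_size, device="cuda")

  # Run the forward pass under FP8 autocasting, as in the examples above.
  fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)
  with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
      y = layer(x)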
GPT
^^^
......
@@ -45,8 +45,10 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:

    logical axis rules.

    .. note::
        We currently only support logical axis rules for single GPU training, data parallel
        training and 1D-sharding tensor parallel training.
        Refer to Figure 3 in `Megatron-LM tensor parallelism <https://arxiv.org/pdf/1909.08053.pdf>`_
        for a description of 1D-sharding tensor parallelism.

    .. warning::
        Please make sure ShardingResource is set via fp8_autocast before calling this function.
......
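For orientation, a hedged usage sketch of `extend_logical_axis_rules` follows; the exact import locations, the `sharding_resource` argument of `fp8_autocast`, and the `ShardingResource` axis names are assumptions for illustration, not confirmed by this diff.

.. code-block:: python

  # Hypothetical sketch: import paths and argument names are assumptions.
  from flax.linen import partitioning as nn_partitioning
  import transformer_engine.jax as te

  # User-defined Flax logical axis rules: (logical_axis_name, mesh_axis_name).
  my_rules = (('batch', 'data'), ('hidden', 'model'))

  # Per the warning above, ShardingResource must be set via fp8_autocast
  # before extending the rules.
  with te.fp8_autocast(enabled=True,
                       sharding_resource=te.ShardingResource('data', 'model')):
      # Returns `my_rules` extended with Transformer Engine's own logical axes,
      # covering data parallel and 1D-sharding tensor parallel layouts.
      rules = te.extend_logical_axis_rules(my_rules)
      with nn_partitioning.axis_rules(rules):
          pass  # initialize and shard the model under these rules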