"A variety of parallelism strategies can be used to enable multi-GPU training of Transformer models, often based on different approaches to distribute their $\\text{sequence_length} \\times \\text{batch_size} \\times \\text{hidden_size}$ activation tensors. The most common approach is data parallelism, which distributes along the $\\text{batch_size}$ dimension. By storing duplicate copies of the model on each GPU, the forward and backward passes of the training step can be done independently, followed by a gradient synchronization. A more advanced strategy is tensor parallelism, a type of model parallelism that distributes along the $\\text{hidden_size}$ dimension. This allows us to scale past the limits of data parallelism (typically $\\text{hidden_size} > \\text{batch_size}$) and to reduce the per-GPU memory usage (since model parameters are also distributed), but it also incurs the overhead of communicating activation tensors between GPUs at every step. For a more detailed explanation, please see the [Megatron-LM paper](https://arxiv.org/pdf/1909.08053.pdf). Finally, sequence parallelism distributes along the $\\text{sequence_length}$ dimension. This can be used when tensor parallelism is enabled in order to parallelize operations that run outside the tensor-parallel region (e.g. layer norm). For more details, please see [this paper](https://arxiv.org/pdf/2205.05198.pdf).\n",
"A variety of parallelism strategies can be used to enable multi-GPU training of Transformer models, often based on different approaches to distribute their $\\text{sequence_length} \\cdot \\text{batch_size} \\cdot \\text{hidden_size}$ activation tensors. The most common approach is data parallelism, which distributes along the $\\text{batch_size}$ dimension. By storing duplicate copies of the model on each GPU, the forward and backward passes of the training step can be done independently, followed by a gradient synchronization. A more advanced strategy is tensor parallelism, a type of model parallelism that distributes along the $\\text{hidden_size}$ dimension. This allows us to scale past the limits of data parallelism (typically $\\text{hidden_size} > \\text{batch_size}$) and to reduce the per-GPU memory usage (since model parameters are also distributed), but it also incurs the overhead of communicating activation tensors between GPUs at every step. For a more detailed explanation, please see the [Megatron-LM paper](https://arxiv.org/pdf/1909.08053.pdf). Finally, sequence parallelism distributes along the $\\text{sequence_length}$ dimension. This can be used when tensor parallelism is enabled in order to parallelize operations that run outside the tensor-parallel region (e.g. layer norm). For more details, please see [this paper](https://arxiv.org/pdf/2205.05198.pdf).\n",
"\n",
"\n",
"To show this in action, let's first initialize NCCL with a trivial process group:"
"To show this in action, let's first initialize NCCL with a trivial process group:"
]
"id": "1f2b80d0",
"id": "1f2b80d0",
"metadata": {},
"metadata": {},
"source": [
"source": [
"We only initialize with one GPU to keep this example simple. Please consult the documentation [torch.distributed](https://pytorch.org/docs/stable/distributed.html) for guidance on running with multiple GPUs. Note that we require that each distributed process corresponds to exactly one GPU, so we treat them interchangeably. In practice, there are multiple factors that can affect the optimal parallel layout: the system hardware, the network topology, usage of other parallelism schemes like pipeline parallelism. A rough rule-of-thumb is to interpret the GPUs as a 2D grid with dimensions of $\\text{num_nodes} \\times \\text{gpus_per_node}$. The rows are tensor-parallel groups and the columns are data-parallel groups.\n",
"We only initialize with one GPU to keep this example simple. Please consult the documentation [torch.distributed](https://pytorch.org/docs/stable/distributed.html) for guidance on running with multiple GPUs. Note that we require that each distributed process corresponds to exactly one GPU, so we treat them interchangeably. In practice, there are multiple factors that can affect the optimal parallel layout: the system hardware, the network topology, usage of other parallelism schemes like pipeline parallelism. A rough rule-of-thumb is to interpret the GPUs as a 2D grid with dimensions of $\\text{num_nodes} \\cdot \\text{gpus_per_node}$. The rows are tensor-parallel groups and the columns are data-parallel groups.\n",
"\n",
"\n",
"Enabling data parallelism with Transformer Engine is similar to enabling data parallelism with standard PyTorch models: simply wrap the modules with [torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html). Transformer Engine modules also have native support for tensor and sequence parallelism. If the user provides a process group for tensor parallelism, the modules will distribute the data and perform communication internally. If sequence parallelism is enabled, it will be applied for operations that are not amenable to tensor parallelism and it will use the tensor-parallel process group.\n",
"Enabling data parallelism with Transformer Engine is similar to enabling data parallelism with standard PyTorch models: simply wrap the modules with [torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html). Transformer Engine modules also have native support for tensor and sequence parallelism. If the user provides a process group for tensor parallelism, the modules will distribute the data and perform communication internally. If sequence parallelism is enabled, it will be applied for operations that are not amenable to tensor parallelism and it will use the tensor-parallel process group.\n",
"* PyTorch: When both options are provided by the user, `cu_seqlens` is preferred as there is no extra conversion needed.\n",
"* PyTorch: When both options are provided by the user, `cu_seqlens` is preferred as there is no extra conversion needed.\n",
" - `cu_seqlens`: Users can provide cumulative sequence length tensors `cu_seqlens_q` and `cu_seqlens_kv` for `q` and `k`/`v` to the flash-attention or cuDNN attention backend. An example of `cu_seqlens` is `[0, 2, 6, 7]` for a batch of 3 `[aa000, bbbb0, c0000]`.\n",
" - `cu_seqlens`: Users can provide cumulative sequence length tensors `cu_seqlens_q` and `cu_seqlens_kv` for `q` and `k`/`v` to the flash-attention or cuDNN attention backend. An example of `cu_seqlens` is `[0, 2, 6, 7]` for a batch of 3 `[aa000, bbbb0, c0000]`.\n",
" - `attention_mask`: Users can also provide `attention_mask` as an alternative, which will then be converted to `cu_seqlens`. For self-attention, `attention_mask` should be one single tensor in shape `[batch_size, 1, 1, seqlen_q]`, and for cross-attention, `attention_mask` should be a list of two tensors in shapes `[batch_size, 1, 1, seqlen_q]` and `[batch_size, 1, 1, seqlen_kv]`, respectively.\n",
" - `attention_mask`: Users can also provide `attention_mask` as an alternative, which will then be converted to `cu_seqlens`. For self-attention, `attention_mask` should be one single tensor of shape `[batch_size, 1, 1, seqlen_q]`, and for cross-attention, `attention_mask` should be a list of two tensors of shapes `[batch_size, 1, 1, seqlen_q]` and `[batch_size, 1, 1, seqlen_kv]`, respectively.\n",
"\n",
"\n",
"\n",
"\n",
"* JAX: Users should provide the `attention_mask` tensor in shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n",
"* JAX: Users should provide the `attention_mask` tensor of shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n",
"\n",
"\n",
"**qkv_format=thd:** Transformer Engine extracts the max sequence length information from `q`, `k`, `v` if `max_seqlen_q` and `max_seqlen_kv` are not provided. This requires GPU-CPU copy and synchronization operations. For performance reasons, please set `max_seqlen_q` and `max_seqlen_kv` to their appropriate values for `thd` QKV format.\n",
"**qkv_format=thd:** Transformer Engine extracts the max sequence length information from `q`, `k`, `v` if `max_seqlen_q` and `max_seqlen_kv` are not provided. This requires GPU-CPU copy and synchronization operations. For performance reasons, please set `max_seqlen_q` and `max_seqlen_kv` to their appropriate values for `thd` QKV format.\n",
"Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, providing better performance with lower memory utilization in both training and inference. It provides support for 8-bit floating point (FP8) precision on Hopper, Ada, as well as 8-bit and 4-bit floating point (NVFP4) precision on Blackwell GPUs, implements a collection of highly optimized building blocks for popular Transformer architectures, and exposes an automatic-mixed-precision-like API that can be used seamlessly with your JAX code. It also includes a framework-agnostic C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.\n",
"\n",
"This guide shows how to start using Transformer Engine with JAX. Similar tutorial for pyTorch is available [here](quickstart.ipynb).\n",
"We recommend you to try understanding the basics of JAX first, using these resources:\n",
"\n",
"- Thinking in JAX: https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html\n",
"<small>_This is based upon the GPT decoder layer with causal masking, which prevents each position from attending to future positions._</small>\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"We build a basic Transformer layer using regular Flax modules. This will be our baseline for later comparisons with Transformer Engine.\n",
"\n",
"</div>\n",
"\n",
"Let's start with creating the transformer layer using plain [FLAX Linen](https://flax.readthedocs.io/en/stable/) . Figure 1 shows the overall structure.\n",
"<figcaption> Figure 1: Structure of a GPT decoder layer.</figcaption>\n",
"</figure>\n",
"\n",
"We construct the components as follows:\n",
"\n",
"- `LayerNorm`: `nn.LayerNorm` (Flax)\n",
"- `QKV Projection`: `nn.Dense` (conceptually there are three seperate `Dense` layers for Q, K, and V separately, but we fuse them together into a single `Dense` layer that is three times larger)\n",
"- `MLP`: `FlaxMLP` implemented using `nn.Dense` and `nn.gelu`\n",
"\n",
"Over the course of this tutorial we will use a few modules and helper functions defined in [quickstart_jax_utils.py](quickstart_jax_utils.py). Putting it all together: \n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "d5284a38",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"from flax import linen as nn\n",
"import quickstart_jax_utils as utils\n",
"from typing import Optional"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a4d1cfdc",
"metadata": {},
"outputs": [],
"source": [
"class FlaxMLP(nn.Module):\n",
" \"\"\"Feed-forward network in Transformer layer\n",
"Now that we have a basic Transformer layer in Flax, let's use Transformer Engine to speed up the training. The following examples show how to use TE modules.\n",
"\n",
"</div>\n",
"\n",
"As a reminder, the FlaxTransformerLayer above used:\n",
"\n",
"- `nn.LayerNorm`: Flax LayerNorm\n",
"- `nn.Dense`: Flax Dense layer for QKV projection \n",
"- `nn.Dense`: Flax Dense layer for projection\n",
"- `nn.Dropout`: Flax Dropout\n",
"- `FlaxMLP`: Custom MLP implemented from `nn.Dense`\n",
"\n",
"Below we show how to use Transformer Engine Flax modules for better performance:\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "bed20d6b",
"metadata": {},
"outputs": [],
"source": [
"import transformer_engine.jax as te\n",
"import transformer_engine.jax.flax as te_flax"
]
},
{
"cell_type": "markdown",
"id": "f28cb444",
"metadata": {},
"source": [
"TE provides a set of Flax Linen modules that can be used to build Transformer layers. The simplest of the provided modules are the `DenseGeneral ` and `LayerNorm` layers, which we can use instead of `flax.linen.Dense` and ` flax.linen.LayerNorm`. Let's modify our `FlaxTransformerLayer`:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "56105579",
"metadata": {},
"outputs": [],
"source": [
"from transformer_engine.jax.flax.transformer import DotProductAttention as TEDotProductAttention\n",
" x = attention(q_reshaped, k_reshaped, v_reshaped, mask=attention_mask, deterministic=deterministic)\n",
"\n",
" x = res + x\n",
"\n",
" # Second residual connection\n",
" res = x\n",
" x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
"\n",
" # MLP\n",
" mlp = TEUnfusedMLP(\n",
" hidden_size=self.hidden_size,\n",
" ffn_hidden_size=self.ffn_hidden_size\n",
" )\n",
"\n",
" x = mlp(x, deterministic=deterministic)\n",
"\n",
" return x + res"
]
},
{
"cell_type": "markdown",
"id": "a76911ac",
"metadata": {},
"source": [
"Testing performance of the model, using `DenseGeneral`, `LayerNorm` and activation from TE, while keeping Flax's `MultiHeadDotProductAttention` the same as the first simple Transformer in JAX implementation. To read more about this implementation from Flax, you can refer to this documentation: https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/attention.html"
"We configure a TE module to perform compute in FP8.\n",
"\n",
"</div>\n",
"\n",
"Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [autocast](../api/jax.rst#transformer_engine.jax.fp8_autocast) context manager. See the [FP8 tutorial](fp8_primer.ipynb) for a detailed explanation of FP8 recipes and the supported options.\n",
"When using FP8, the model **must be initialized within the `autocast` context**. This creates a special collection called `fp8_metas` that contains scaling factors and other metadata required for FP8 computation. If you initialize a model outside of `autocast` and then try to use it with FP8, you will get a `ScopeCollectionNotFound` error because the `fp8_metas` collection was never created.\n",
"\n",
"</div>"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "c2eee376",
"metadata": {},
"outputs": [],
"source": [
"from transformer_engine.common.recipe import Format, DelayedScaling\n",
"We optimize the example Transformer layer with TE modules for fused operations.\n",
"\n",
"</div>\n",
"\n",
"The `DenseGeneral` layer is enough to build any Transformer model and it enables usage of the Transformer Engine even for very custom Transformers. However, having more knowledge about the model allows for additional optimizations such as kernel fusions in mixed-precision recipes, increasing the achievable speedup.\n",
"\n",
"Transformer Engine therefore provides coarser modules that span multiple layers:\n",
"\n",
"* `LayerNormDenseGeneral`\n",
"* `LayerNormMLP`\n",
"* `TransformerLayer`\n",
"\n",
"To see a complete list of all the functions TE Flax support, you can view it here: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/jax.html#modules\n",
"\n",
"Building a third iteration of our Transformer layer with `LayerNormDenseGeneral` and `LayerNormMLP`:"
"Similar to the unnfused model, we also compare the performance of fused model when using Flax's MultiheadDotProductAttention implementation and TE's."
"For those seeking a deeper understanding of text generation mechanisms in Transformers, it is recommended to check out the [HuggingFace generation tutorial](https://huggingface.co/docs/transformers/llm_tutorial).\n",
"For those seeking a deeper understanding of text generation mechanisms in Transformers, it is recommended to check out the [HuggingFace generation tutorial](https://huggingface.co/docs/transformers/llm_tutorial).\n",
"\n",
"\n",
"In a previous tutorial on [Llama](../te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb), it was demonstrated how finetuning of an open-source Llama model can be accelerated using Transformer Engine's `TransformerLayer`. Building on that foundation, this tutorial showcases how to accelerate the token generation from the open-source Hugging Face Gemma 7B model.\n",
"In a previous tutorial on [Llama](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb), it was demonstrated how finetuning of an open-source Llama model can be accelerated using Transformer Engine's `TransformerLayer`. Building on that foundation, this tutorial showcases how to accelerate the token generation from the open-source Hugging Face Gemma 7B model.\n",
"\n",
"\n",
"This tutorial introduces several features of the Transformer Engine library that contribute towards this goal. A brief explanation is as follows:\n",
"This tutorial introduces several features of the Transformer Engine library that contribute towards this goal. A brief explanation is as follows:\n",
pip - from PyPI
---------------
Transformer Engine can be directly installed from `our PyPI <https://pypi.org/project/transformer-engine/>`_, e.g.
By default, this will install the core library compiled for CUDA 12. The CUDA major version can be specified by modifying the extra dependency to `core_cu12` or `core_cu13` (for example, `pip install transformer_engine[core_cu13]` selects the CUDA 13 build).