Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
Distributed training
===================
====================
Nvidia-Pytorch-Inspect with Transformer Engine supports multi-GPU training. This guide describes how to run it and how the supported features work in the distributed setting.
...
...
@@ -14,7 +14,8 @@ To use precision debug tools in multi-GPU training, one needs to:
2. If one wants to log stats, one may want to invoke ``debug_api.set_tensor_reduction_group`` with a proper reduction group.
Behavior of the features
-----------------------
------------------------
In a distributed setting, **DisableFP8GEMM** and **DisableFP8Layer** function similarly to the single-GPU case, with no notable differences.
...
...
@@ -28,7 +29,8 @@ In a distributed setting, **DisableFP8GEMM** and **DisableFP8Layer** function si
Logging-related features are more complex and will be discussed further in the next sections.
Reduction groups
--------------
----------------
In setups with tensor, data, or pipeline parallelism, some tensors are distributed across multiple GPUs, requiring a reduction operation to compute statistics for these tensors.
...
...
@@ -65,7 +67,8 @@ Below, we illustrate configurations for a 4-node setup with tensor parallelism s
Microbatching
-----------
-------------
Let's dive into how statistics collection works with microbatching. By microbatching, we mean invoking multiple ``forward()`` calls for each ``debug_api.step()``. The behavior is as follows:
...
...
@@ -73,7 +76,7 @@ Let's dive into how statistics collection works with microbatching. By microbatc
- For other tensors, the stats are accumulated.
Logging to files and TensorBoard
------------------------------
--------------------------------
In a single-node setup with ``default_logging_enabled=True``, all logs are saved by default to ``log_dir/nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log``. In multi-GPU training, each node writes its reduced statistics to its unique file, named ``log_dir/nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-i.log`` for rank i. Because these logs contain reduced statistics, the logged values are identical for all nodes within a reduction group.
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
.. _environment_variables:
Environment Variables
=====================
This document describes the environment variables used by Transformer Engine. They provide an alternate method to alter Transformer Engine's behavior during build and runtime, but are less rigorously maintained compared to the API and may be subject to change.
Build-Time Environment Variables
---------------------------------
These environment variables control the build and compilation process of Transformer Engine.
Build Configuration
^^^^^^^^^^^^^^^^^^^
.. envvar:: NVTE_BUILD_DEBUG
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable debug build mode. When set to ``1``, the build includes debug symbols (``-g``) and disables optimizations.
.. envvar:: NVTE_BUILD_MAX_JOBS
:Type: ``int``
:Default: Maximum available
:Description: Number of parallel jobs to use during the build process. If not set, the system will use the maximum available parallel jobs. Also respects the standard ``MAX_JOBS`` environment variable.
.. envvar:: NVTE_BUILD_THREADS_PER_JOB
:Type: ``int``
:Default: ``1``
:Description: Number of threads to use per parallel build job. This is passed to the CUDA compiler via the ``--threads`` flag.
.. envvar:: NVTE_FRAMEWORK
:Type: ``str``
:Default: Auto-detected
:Description: Comma-separated list of frameworks to build support for (``pytorch``, ``jax``, ``all``, or ``none``). If not specified, automatically detects installed frameworks.
.. envvar:: NVTE_USE_CCACHE
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable ccache for faster recompilation. When set to ``1``, uses ccache as a compiler launcher for both C++ and CUDA compilation.
.. envvar:: NVTE_CCACHE_BIN
:Type: ``str``
:Default: ``ccache``
:Description: Path to the ccache binary. Only used when :envvar:`NVTE_USE_CCACHE` is set to ``1``.
.. envvar:: NVTE_CMAKE_BUILD_DIR
:Type: ``str``
:Default: None
:Description: Path to the CMake build directory for incremental builds. If set, CMake will use this directory for build artifacts.
.. envvar:: NVTE_RELEASE_BUILD
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable release build mode. When set to ``1``, prepares the build for distribution (e.g., PyPI wheel). This affects library installation paths and build tool management.
.. envvar:: NVTE_PROJECT_BUILDING
:Type: ``int`` (0 or 1)
:Default: Not set
:Description: Internal flag set to ``1`` during the build process to indicate that the project is being built. Not intended for external use.
Optional Dependencies
^^^^^^^^^^^^^^^^^^^^^
.. envvar:: NVTE_UB_WITH_MPI
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable MPI support for userbuffers. When set to ``1``, requires ``MPI_HOME`` to be set to the MPI installation directory.
.. envvar:: NVTE_ENABLE_NVSHMEM
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable NVSHMEM support. When set to ``1``, requires ``NVSHMEM_HOME`` to be set to the NVSHMEM installation directory.
.. envvar:: NVTE_BUILD_ACTIVATION_WITH_FAST_MATH
:Type: CMake option
:Default: ``OFF``
:Description: Compile activation kernels (GELU, ReLU, SwiGLU) with the ``--use_fast_math`` CUDA compiler flag for improved performance at the cost of some precision.
CUDA Configuration
^^^^^^^^^^^^^^^^^^
.. envvar:: NVTE_CUDA_ARCHS
:Type: ``str``
:Default: Auto-detected based on CUDA version
:Description: Semicolon-separated list of CUDA compute architectures to compile for (e.g., ``"80;90"`` for A100 and H100, or ``"75;80;89;90"``). If not set, automatically determined based on the installed CUDA Toolkit version. CUDA 13.0+ defaults to ``"75;80;89;90;100;120"``, CUDA 12.8+ defaults to ``"70;80;89;90;100;120"``, and earlier versions default to ``"70;80;89;90"``. Setting this can significantly reduce build time and binary size by targeting only the GPU architectures you need.
.. envvar:: NVTE_CUDA_INCLUDE_DIR
:Type: ``str``
:Default: Auto-detected
:Description: Path to CUDA include directory containing ``cuda_runtime.h``. If not set, Transformer Engine searches in common locations (``CUDA_HOME``, ``CUDA_DIR``, ``/usr/local/cuda``). This is used for NVRTC kernel compilation.
Runtime Environment Variables
------------------------------
These environment variables control the behavior of Transformer Engine during execution.
Attention Backend Selection
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. envvar:: NVTE_FLASH_ATTN
:Type: ``int`` (0 or 1)
:Default: ``1``
:Description: Enable or disable FlashAttention backend for DotProductAttention. When set to ``0``, FlashAttention will not be used.
.. envvar:: NVTE_FUSED_ATTN
:Type: ``int`` (0 or 1)
:Default: ``1``
:Description: Enable or disable FusedAttention backend (cuDNN-based) for DotProductAttention. When set to ``0``, FusedAttention will not be used.
.. envvar:: NVTE_UNFUSED_ATTN
:Type: ``int`` (0 or 1)
:Default: ``1``
:Description: Enable or disable UnfusedDotProductAttention backend (native PyTorch). When set to ``0``, UnfusedDotProductAttention will not be used.
.. envvar:: NVTE_FUSED_ATTN_BACKEND
:Type: ``int`` (0, 1, or 2)
:Default: Auto-selected
:Description: Force a specific FusedAttention backend. ``0`` = F16_max512_seqlen (cuDNN, ≤512 seq len), ``1`` = F16_arbitrary_seqlen (cuDNN, any seq len), ``2`` = FP8 backend. If not set, the backend is automatically selected based on the input configuration.
.. envvar:: NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT
:Type: ``int`` (0 or 1)
:Default: Auto-determined
:Description: Control workspace-related optimizations in FusedAttention. ``0`` disables optimizations, ``1`` enables them. These optimizations trade memory for performance. When unset, Transformer Engine determines the code path based on internal logic. For deterministic behavior with cuDNN ≥8.9.5 and <9.0.0, this is automatically set to ``1``.
.. envvar:: NVTE_FUSED_ATTN_USE_FAv2_BWD
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: When using FusedAttention, use FlashAttention-2 implementation for the backward pass instead of the cuDNN implementation. This can be useful due to performance differences between various versions of flash-attn and FusedAttention.
.. envvar:: NVTE_ALLOW_NONDETERMINISTIC_ALGO
:Type: ``int`` (0 or 1)
:Default: ``1``
:Description: Allow non-deterministic algorithms for Transformer Engine execution. When set to ``0``, only deterministic algorithms are allowed. This is relevant for both PyTorch and JAX attention implementations.
.. envvar:: NVTE_FUSED_RING_ATTENTION_USE_SCAN
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: **(JAX only)** Use scan loop for ring attention implementation. When set to ``1``, the fused ring attention will use a scan-based iteration approach.
.. envvar:: NVTE_APPLY_QK_LAYER_SCALING
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Apply QK layer scaling in UnfusedDotProductAttention. This is an FP16 training trick required for certain GPT-like models. When set to ``1`` and a layer number is provided, the softmax scale is divided by the layer number, and the layer number is used as the softmax scale during the softmax operation. Only effective when using FP16 dtype and when the layer number is specified.
Context Parallelism
^^^^^^^^^^^^^^^^^^^
.. envvar:: NVTE_BATCH_MHA_P2P_COMM
:Type: ``int`` (0 or 1)
:Default: ``0`` (or auto-enabled for pre-Blackwell GPUs with CP size 2)
:Description: Use batched P2P communication (``batch_isend_irecv``) for KV exchange in context parallel MultiheadAttention. When enabled, send and receive operations are batched together, which can improve communication efficiency. This is automatically enabled for devices with compute capability < 10.0 (pre-Blackwell GPUs) when context parallel size is 2. Setting this to ``1`` forces batched P2P communication regardless of device architecture.
FP8 Configuration
^^^^^^^^^^^^^^^^^
.. envvar:: NVTE_UNFUSED_FP8_UPDATE
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Use unfused kernel for FP8 amax and scale updates. When set to ``1``, amax and scale updates are computed using separate unfused kernels instead of fused operations.
.. envvar:: NVTE_FP8_DPA_BWD
:Type: ``int`` (0 or 1)
:Default: ``1``
:Description: Enable FP8 in the backward pass of DotProductAttention. ``1`` = FP8 forward and backward, ``0`` = FP8 forward and FP16/BF16 backward.
.. envvar:: NVTE_DPA_FP8CS_O_in_F16
:Type: ``int`` (0 or 1)
:Default: ``1``
:Description: For Float8CurrentScaling in DotProductAttention, use FP16/BF16 for the output tensor in the backward pass. ``1`` = use F16/BF16 output in backward, ``0`` = use FP8 output in backward.
.. envvar:: NVTE_DPA_FP8_RECIPE
:Type: ``str``
:Default: Empty (use same as linear layers)
:Description: Override FP8 recipe for DotProductAttention layers. Valid values: ``"F16"`` (disable FP8), ``"DelayedScaling"``, or ``"Float8CurrentScaling"``. This allows using different FP8 recipes for attention vs. linear layers.
.. envvar:: NVTE_DPA_FP8_RECIPE_DPA
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable FP8 in DotProductAttention when using :envvar:`NVTE_DPA_FP8_RECIPE`. When set to ``1``, the DotProductAttention layer will use the FP8 recipe specified by :envvar:`NVTE_DPA_FP8_RECIPE`. This provides fine-grained control over which attention components use FP8.
.. envvar:: NVTE_DPA_FP8_RECIPE_MHA
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable FP8 in MultiheadAttention (MHA) when using :envvar:`NVTE_DPA_FP8_RECIPE`. When set to ``1``, the MultiheadAttention QKV and output projection layers will use the FP8 recipe specified by :envvar:`NVTE_DPA_FP8_RECIPE`. This provides fine-grained control over which attention components use FP8.
.. envvar:: NVTE_DPA_FP8_FORMAT
:Type: ``str``
:Default: ``"HYBRID"``
:Description: FP8 format for DotProductAttention when switching recipes. Valid values: ``"HYBRID"``, ``"E4M3"``, ``"E5M2"``. Only used when :envvar:`NVTE_DPA_FP8_RECIPE` is set.
.. envvar:: NVTE_DPA_FP8DS_AMAX_ALGO
:Type: ``str``
:Default: ``"most_recent"``
:Description: Amax computation algorithm for DelayedScaling recipe in DotProductAttention. Valid values: ``"most_recent"``, ``"max"``. Only used when :envvar:`NVTE_DPA_FP8_RECIPE` is set to ``"DelayedScaling"``.
.. envvar:: NVTE_DPA_FP8DS_AMAX_HISTLEN
:Type: ``int``
:Default: ``1``
:Description: Amax history length for DelayedScaling recipe in DotProductAttention. Only used when :envvar:`NVTE_DPA_FP8_RECIPE` is set to ``"DelayedScaling"``.
.. envvar:: NVTE_DPA_FP8DS_REDUCE_AMAX
:Type: ``int`` (0 or 1)
:Default: ``1``
:Description: Reduce amax across distributed ranks for DelayedScaling recipe in DotProductAttention. Only used when :envvar:`NVTE_DPA_FP8_RECIPE` is set to ``"DelayedScaling"``.
.. envvar:: NVTE_UnfusedDPA_Emulate_FP8
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Allow FP8 emulation in UnfusedDotProductAttention. When set to ``1``, UnfusedDotProductAttention can emulate FP8 operations using FP16/BF16 computation.
Kernel Configuration
^^^^^^^^^^^^^^^^^^^^
.. envvar:: NVTE_USE_FAST_MATH
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable fast math optimizations in runtime-compiled (NVRTC) kernels. This trades numerical accuracy for performance. These optimizations are experimental and inconsistently implemented.
.. envvar:: NVTE_DISABLE_NVRTC
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Disable NVRTC (CUDA Runtime Compilation) support. When set to ``1``, runtime kernel compilation is disabled. This can be useful in environments where NVRTC is not available or not desired.
.. envvar:: NVTE_USE_CUTLASS_GROUPED_GEMM
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Use CUTLASS implementation for grouped GEMM operations instead of cuBLAS. When set to ``1``, enables CUTLASS grouped GEMM kernels, which may provide better performance for certain workloads on Hopper (SM90) GPUs.
:Description: Emit a warning when falling back from CUTLASS to cuBLAS for grouped GEMM operations.
Torch Compilation and Fusion
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. envvar:: NVTE_TORCH_COMPILE
:Type: ``int`` (0 or 1)
:Default: ``1``
:Description: Enable PyTorch 2.x ``torch.compile`` support for compatible Transformer Engine operations. When set to ``0``, disables compilation support and uses regular PyTorch eager mode.
.. envvar:: NVTE_BIAS_GELU_NVFUSION
:Type: ``int`` (0 or 1)
:Default: ``1``
:Description: Enable GELU fusion with bias using NVFusion in PyTorch. When set to ``0``, uses separate bias and GELU operations.
.. envvar:: NVTE_BIAS_DROPOUT_FUSION
:Type: ``int`` (0 or 1)
:Default: ``1``
:Description: Enable fusion of bias and dropout operations. When set to ``0``, bias and dropout are computed separately.
LayerNorm/RMSNorm SM Margins
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. envvar:: NVTE_FWD_LAYERNORM_SM_MARGIN
:Type: ``int``
:Default: ``0``
:Description: Number of SMs (Streaming Multiprocessors) to reserve (not use) during forward LayerNorm/RMSNorm operations. This can be used to control resource allocation and overlap computation with communication.
.. envvar:: NVTE_BWD_LAYERNORM_SM_MARGIN
:Type: ``int``
:Default: ``0``
:Description: Number of SMs to reserve during backward LayerNorm/RMSNorm operations.
.. envvar:: NVTE_INF_LAYERNORM_SM_MARGIN
:Type: ``int``
:Default: ``0``
:Description: Number of SMs to reserve during inference LayerNorm/RMSNorm operations.
GEMM Configuration
^^^^^^^^^^^^^^^^^^
.. envvar:: NVTE_EXT_MARGIN_SM
:Type: ``int``
:Default: Total SM count
:Description: External SM margin for GEMM operations. Specifies the number of SMs to use for GEMM operations. The actual number of SMs used is ``sm_count - NVTE_EXT_MARGIN_SM``.
.. envvar:: NVTE_AG_P2P_MULTI_ATOMIC
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable multi-atomic mode for AllGather with atomic GEMM using P2P communication. When set to ``1``, uses ``userbuffers_sendrecv_multiatomic`` for communication during atomic GEMM overlap with AllGather operations. This disables copy engine (CE) usage and enables push mode for userbuffers. This is an advanced optimization for tensor-parallel communication-computation overlap.
CPU Offloading
^^^^^^^^^^^^^^
.. envvar:: NVTE_CPU_OFFLOAD_V1
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable legacy version of CPU offloading implementation.
Debugging and Profiling
^^^^^^^^^^^^^^^^^^^^^^^^
.. envvar:: NVTE_DEBUG
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable debug mode. When set to ``1``, enables verbose debug output and additional checks in attention operations.
.. envvar:: NVTE_DEBUG_LEVEL
:Type: ``int`` (0, 1, or 2)
:Default: ``0``
:Description: Debug verbosity level. Higher values enable more verbose debug output. Only effective when :envvar:`NVTE_DEBUG` is set to ``1``.
.. envvar:: NVTE_PRINT_LAYER_NUMBER
:Type: ``int``
:Default: ``1``
:Description: Layer number to print debug information for during attention operations.
.. envvar:: NVTE_PRINT_RANK
:Type: ``int``
:Default: ``0``
:Description: Distributed rank to print debug information for during attention operations.
.. envvar:: NVTE_NVTX_ENABLED
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable NVTX (NVIDIA Tools Extension) range profiling for Transformer Engine operations. When set to ``1``, NVTX markers are added to operations for profiling with NVIDIA Nsight Systems.
.. envvar:: NVTE_DEBUG_NUMERICS
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: **(JAX only)** Enable verbose printing of tensor numerics for debugging purposes.
Testing
^^^^^^^
.. envvar:: NVTE_TEST_NVINSPECT_ENABLED
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable NVInspect integration for testing. When set to ``1``, enables the NVInspect debugging API for numerical analysis during tests.
.. envvar:: NVTE_TEST_NVINSPECT_CONFIG_FILE
:Type: ``str``
:Default: None
:Description: Path to NVInspect configuration file. Required when :envvar:`NVTE_TEST_NVINSPECT_ENABLED` is set to ``1``.
.. envvar:: NVTE_TEST_NVINSPECT_FEATURE_DIRS
:Type: ``str``
:Default: None
:Description: Comma-separated list of directories containing NVInspect features. Required when :envvar:`NVTE_TEST_NVINSPECT_ENABLED` is set to ``1``.
.. envvar:: NVTE_TEST_ARTIFACTS_DIR
:Type: ``str``
:Default: System temp directory
:Description: Directory for storing test artifacts (e.g., generated ONNX models).
ONNX Export
^^^^^^^^^^^
.. envvar:: NVTE_ONNX_KVCACHE_MAX_SEQ_LEN
:Type: ``int``
:Default: ``128``
:Description: Maximum sequence length for KV cache during ONNX export. This is used for attention masking in exported ONNX models.
JAX-Specific Variables
^^^^^^^^^^^^^^^^^^^^^^
.. envvar:: NVTE_JAX_CUSTOM_CALLS
:Type: ``str``
:Default: None
:Description: Control which JAX custom call primitives are enabled or disabled. Format: ``"true"`` (enable all), ``"false"`` (disable all), or comma-separated key-value pairs like ``"GemmPrimitive=false,DBiasQuantizePrimitive=true"``. This provides fine-grained control over which operations use custom CUDA kernels vs. JAX native implementations.
.. envvar:: NVTE_JAX_CUSTOM_CALLS_RE
:Type: ``str``
:Default: None
:Description: **Deprecated** (use :envvar:`NVTE_JAX_CUSTOM_CALLS` instead). Regex pattern to match primitive names for enabling/disabling. Example: ``"DBiasQuantizePrimitive"`` or ``"^(?!DBiasQuantizePrimitive$).+$"``.
.. envvar:: NVTE_JAX_UNITTEST_LEVEL
:Type: ``str``
:Default: None
:Description: Test level for JAX unit tests (``"L0"``, ``"L1"``, ``"L2"``). Used internally by the test suite.
Examples
--------
Building with Debug Symbols
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: bash
export NVTE_BUILD_DEBUG=1
export NVTE_USE_CCACHE=1
pip install -e .
Using Specific Attention Backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: bash
# Use only FlashAttention, disable FusedAttention
export NVTE_FLASH_ATTN=1
export NVTE_FUSED_ATTN=0
python train.py
Configuring FP8 for Attention
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: bash
# Use DelayedScaling for attention, CurrentScaling for linear layers
"This guide is a follow-up to the discussion in the [quickstart guide](quickstart.ipynb). We will focus on techniques to achieve maximum performance when training a basic GPT encoder layer. For convenience, we use some helper functions defined in [quickstart_utils.py](quickstart_utils.py). "
"This guide is a follow-up to the discussion in the [Getting Started guide](../getting_started/index.rst). We will focus on techniques to achieve maximum performance when training a basic GPT encoder layer. For convenience, we use some helper functions defined in [quickstart_utils.py](quickstart_utils.py). "
]
},
{
...
...
@@ -100,7 +100,7 @@
"\n",
"</div>\n",
"\n",
"A variety of parallelism strategies can be used to enable multi-GPU training of Transformer models, often based on different approaches to distribute their $\\text{sequence_length} \\times \\text{batch_size} \\times \\text{hidden_size}$ activation tensors. The most common approach is data parallelism, which distributes along the $\\text{batch_size}$ dimension. By storing duplicate copies of the model on each GPU, the forward and backward passes of the training step can be done independently, followed by a gradient synchronization. A more advanced strategy is tensor parallelism, a type of model parallelism that distributes along the $\\text{hidden_size}$ dimension. This allows us to scale past the limits of data parallelism (typically $\\text{hidden_size} > \\text{batch_size}$) and to reduce the per-GPU memory usage (since model parameters are also distributed), but it also incurs the overhead of communicating activation tensors between GPUs at every step. For a more detailed explanation, please see the [Megatron-LM paper](https://arxiv.org/pdf/1909.08053.pdf). Finally, sequence parallelism distributes along the $\\text{sequence_length}$ dimension. This can be used when tensor parallelism is enabled in order to parallelize operations that run outside the tensor-parallel region (e.g. layer norm). For more details, please see [this paper](https://arxiv.org/pdf/2205.05198.pdf).\n",
"A variety of parallelism strategies can be used to enable multi-GPU training of Transformer models, often based on different approaches to distribute their $\\text{sequence_length} \\cdot \\text{batch_size} \\cdot \\text{hidden_size}$ activation tensors. The most common approach is data parallelism, which distributes along the $\\text{batch_size}$ dimension. By storing duplicate copies of the model on each GPU, the forward and backward passes of the training step can be done independently, followed by a gradient synchronization. A more advanced strategy is tensor parallelism, a type of model parallelism that distributes along the $\\text{hidden_size}$ dimension. This allows us to scale past the limits of data parallelism (typically $\\text{hidden_size} > \\text{batch_size}$) and to reduce the per-GPU memory usage (since model parameters are also distributed), but it also incurs the overhead of communicating activation tensors between GPUs at every step. For a more detailed explanation, please see the [Megatron-LM paper](https://arxiv.org/pdf/1909.08053.pdf). Finally, sequence parallelism distributes along the $\\text{sequence_length}$ dimension. This can be used when tensor parallelism is enabled in order to parallelize operations that run outside the tensor-parallel region (e.g. layer norm). For more details, please see [this paper](https://arxiv.org/pdf/2205.05198.pdf).\n",
"\n",
"To show this in action, let's first initialize NCCL with a trivial process group:"
]
...
...
@@ -131,7 +131,7 @@
"id": "1f2b80d0",
"metadata": {},
"source": [
"We only initialize with one GPU to keep this example simple. Please consult the documentation [torch.distributed](https://pytorch.org/docs/stable/distributed.html) for guidance on running with multiple GPUs. Note that we require that each distributed process corresponds to exactly one GPU, so we treat them interchangeably. In practice, there are multiple factors that can affect the optimal parallel layout: the system hardware, the network topology, usage of other parallelism schemes like pipeline parallelism. A rough rule-of-thumb is to interpret the GPUs as a 2D grid with dimensions of $\\text{num_nodes} \\times \\text{gpus_per_node}$. The rows are tensor-parallel groups and the columns are data-parallel groups.\n",
"We only initialize with one GPU to keep this example simple. Please consult the documentation [torch.distributed](https://pytorch.org/docs/stable/distributed.html) for guidance on running with multiple GPUs. Note that we require that each distributed process corresponds to exactly one GPU, so we treat them interchangeably. In practice, there are multiple factors that can affect the optimal parallel layout: the system hardware, the network topology, usage of other parallelism schemes like pipeline parallelism. A rough rule-of-thumb is to interpret the GPUs as a 2D grid with dimensions of $\\text{num_nodes} \\cdot \\text{gpus_per_node}$. The rows are tensor-parallel groups and the columns are data-parallel groups.\n",
"\n",
"Enabling data parallelism with Transformer Engine is similar to enabling data parallelism with standard PyTorch models: simply wrap the modules with [torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html). Transformer Engine modules also have native support for tensor and sequence parallelism. If the user provides a process group for tensor parallelism, the modules will distribute the data and perform communication internally. If sequence parallelism is enabled, it will be applied for operations that are not amenable to tensor parallelism and it will use the tensor-parallel process group.\n",
"* PyTorch: When both options are provided by the user, `cu_seqlens` is preferred as there is no extra conversion needed.\n",
" - `cu_seqlens`: Users can provide cumulative sequence length tensors `cu_seqlens_q` and `cu_seqlens_kv` for `q` and `k`/`v` to the flash-attention or cuDNN attention backend. An example of `cu_seqlens` is `[0, 2, 6, 7]` for a batch of 3 `[aa000, bbbb0, c0000]`.\n",
" - `attention_mask`: Users can also provide `attention_mask` as an alternative, which will then be converted to `cu_seqlens`. For self-attention, `attention_mask` should be one single tensor in shape `[batch_size, 1, 1, seqlen_q]`, and for cross-attention, `attention_mask` should be a list of two tensors in shapes `[batch_size, 1, 1, seqlen_q]` and `[batch_size, 1, 1, seqlen_kv]`, respectively.\n",
" - `attention_mask`: Users can also provide `attention_mask` as an alternative, which will then be converted to `cu_seqlens`. For self-attention, `attention_mask` should be one single tensor of shape `[batch_size, 1, 1, seqlen_q]`, and for cross-attention, `attention_mask` should be a list of two tensors of shapes `[batch_size, 1, 1, seqlen_q]` and `[batch_size, 1, 1, seqlen_kv]`, respectively.\n",
"\n",
"\n",
"* JAX: Users should provide the `attention_mask` tensor in shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n",
"* JAX: Users should provide the `attention_mask` tensor of shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n",
"\n",
"**qkv_format=thd:** Transformer Engine extracts the max sequence length information from `q`, `k`, `v` if `max_seqlen_q` and `max_seqlen_kv` are not provided. This requires GPU-CPU copy and synchronization operations. For performance reasons, please set `max_seqlen_q` and `max_seqlen_kv` to their appropriate values for `thd` QKV format.\n",
"For those seeking a deeper understanding of text generation mechanisms in Transformers, it is recommended to check out the [HuggingFace generation tutorial](https://huggingface.co/docs/transformers/llm_tutorial).\n",
"\n",
"In a previous tutorial on [Llama](../te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb), it was demonstrated how finetuning of an open-source Llama model can be accelerated using Transformer Engine's `TransformerLayer`. Building on that foundation, this tutorial showcases how to accelerate the token generation from the open-source Hugging Face Gemma 7B model.\n",
"In a previous tutorial on [Llama](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb), it was demonstrated how finetuning of an open-source Llama model can be accelerated using Transformer Engine's `TransformerLayer`. Building on that foundation, this tutorial showcases how to accelerate the token generation from the open-source Hugging Face Gemma 7B model.\n",
"\n",
"This tutorial introduces several features of the Transformer Engine library that contribute towards this goal. A brief explanation is as follows:\n",
"# JAX: Integrating TE into an existing framework\n",
"\n",
"This tutorial will cover how to integrate TransformerEngine into an existing JAX model framework, such as [MaxText's TE integration](https://github.com/AI-Hypercomputer/maxtext/blob/ed517cf80d9aa81f76e236c5516dacebfe39e96d/src/MaxText/layers/quantizations.py#L753) or your own model framework. \n"
]
},
{
"cell_type": "markdown",
"id": "b36876bb",
"metadata": {},
"source": [
"Let's start with a standard JAX+Flax Transformer layer"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "d5284a38",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"from flax import linen as nn\n",
"import quickstart_jax_utils as utils\n",
"from typing import Optional"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "a4d1cfdc",
"metadata": {},
"outputs": [],
"source": [
"class FlaxMLP(nn.Module):\n",
" \"\"\"Feed-forward network in Transformer layer\n",
" # q, k, v now have shape [batch, seq_len, num_heads, kv_channels]\n",
" # which is the correct format for dot_product_attention\n",
" \n",
" # Apply dot product attention\n",
" # Note: dot_product_attention expects mask to be broadcastable to \n",
" # [batch, num_heads, q_length, kv_length], but attention_mask from \n",
" # nn.make_causal_mask has shape [batch, 1, seq_len, seq_len]\n",
" \n",
" # Generate dropout RNG key when needed (not deterministic and dropout_rate > 0)\n",
" dropout_rng = None\n",
" if not deterministic and self.attention_dropout > 0:\n",
" dropout_rng = self.make_rng('dropout')\n",
" \n",
" # See quickstart_jax.ipynb for details on using TE's faster fused attention\n",
" x = nn.dot_product_attention(\n",
" query=q,\n",
" key=k,\n",
" value=v,\n",
" mask=attention_mask,\n",
" dropout_rng=dropout_rng,\n",
" dropout_rate=self.attention_dropout,\n",
" deterministic=deterministic,\n",
" broadcast_dropout=True,\n",
" )\n",
" \n",
" # Reshape output from [batch, seq_len, num_heads, kv_channels] to [batch, seq_len, hidden_size]\n",
" x = x.reshape(x.shape[0], x.shape[1], self.hidden_size)\n",
"\n",
" # Output projection\n",
" x = nn.Dense(features=self.hidden_size, use_bias=True, dot_general=self.dot_general_cls())(x)\n",
" \n",
" x = res + x\n",
" \n",
" # Second residual connection\n",
" res = x\n",
" x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
" \n",
" # MLP\n",
" mlp = FlaxMLP(\n",
" hidden_size=self.hidden_size,\n",
" ffn_hidden_size=self.ffn_hidden_size,\n",
" dot_general_cls=self.dot_general_cls,\n",
" )\n",
" x = mlp(x)\n",
" \n",
" return x + res\n"
]
},
{
"cell_type": "markdown",
"id": "db16bf70",
"metadata": {},
"source": [
"We've exposed `dot_general_cls` here so we can test out different GEMM implementations later. By default, Flax's `nn.Dense` will use JAX's GEMM `jax.lax.dot_general` when `dot_general` is `None`."
]
},
{
"cell_type": "markdown",
"id": "fbc3510b",
"metadata": {},
"source": [
"## Testing Performance\n",
"\n",
"Now let's test the performance of our FlaxTransformerLayer:\n"
"TransformerEngine/JAX is currently using Flax Linen. However, it is easily compatible with Flax NNX or Haiku.\n",
"* [Use Flax NNX and Linen together](https://flax.readthedocs.io/en/latest/guides/bridge_guide.html)\n",
"* [Haiku and Flax interop](https://dm-haiku.readthedocs.io/en/latest/notebooks/flax.html)\n",
"\n",
"Additionally, with the tutorial below, no model parameters need to be managed by TransformerEngine. You can keep all your existing model parameters, initialization, and sharding the same. The only change required is to call TE's dot_general_cls instead of the default Dense dot_general implementation. TE's dot_general_cls is a small module that performs a quantized dense VJP and stores some small recipe-specific state."
]
},
{
"cell_type": "markdown",
"id": "4477d4e9",
"metadata": {},
"source": [
"Now we'll select a recipe. `DelayedScaling` and `CurrentScaling` use per-tensor scaling and are supported on Hopper and Blackwell. `MXFP8BlockScaling` and `NVFP4BlockScaling` use block scaling or a combination of both per-tensor and block scaling and are supported on Blackwell.\n",
"\n",
"If you would like to customize the recipe further, various options can be changed by passing args to the recipe's constructor."
" # The NVFP4 recipe requires a Flax RNG for stochastic rounding\n",
" rngs['sr_rng'] = jax.random.PRNGKey(0)\n"
]
},
{
"cell_type": "markdown",
"id": "c8769655",
"metadata": {},
"source": [
"Now using this quantized dense in our model is as simple as passing in `dot_general_fn=te_dot_general`. Let's try it out!\n",
"\n",
"<div class=\"alert alert-warning\">\n",
"\n",
"<b>Important: Remat Policy</b>\n",
"\n",
"TE's quantization uses specialized TE quantized GEMM primitives. If you are using any built-in JAX checkpoint policies that look for JAX GEMMs (dots), such as `jax.checkpoint_policies.checkpoint_dots`, please replace the policy with `transformer_engine.jax.checkpoint_policies.checkpoint_dots_and_te_gemms` or similar policies to ensure TE's quantized GEMM primitives are checkpointed correctly.\n",
"\n",
"If this is not performed, TE GEMMs will be rematerialized introducing an incorrect performance comparison.\n",
"print(f\"Additional state: {jax.tree_util.tree_map(lambda x: x.shape, {k: v for k, v in var_collect.items() if k != 'params'})}\")"
]
},
{
"cell_type": "markdown",
"id": "abe27237",
"metadata": {},
"source": [
"If using a recipe that stores additional state, such as `DelayedScaling`, you'll see this additional state stored as Flax variables. It is important to maintain and pass the whole state of Flax variables `var_collect` across training steps, not just the model params, for proper usage of stateful recipes like `DelayedScaling`.\n",
"\n",
"For example, above inside `Additional state: ` you'll see the `amax_history` of each quantization which is used to compute the per-tensor scale in the `DelayedScaling` recipe."
]
},
{
"cell_type": "markdown",
"id": "5ab72935",
"metadata": {},
"source": [
"The reason we need `te_dot_general_cls` as a Flax module instead of a module-less function like `jax.lax.dot_general` is for some quantization recipes to track internal state separate from model parameters.\n",
"\n",
"Flax modules can manage 3 things:\n",
"1. Model parameters/weights, e.g. your Dense \"kernel\", \"bias\", etc.\n",
"2. RNGs for dropout, stochastic rounding, etc.\n",
"3. Flax variables. These are additional state variables that are used across training steps but are distinct from model params in that you don't take gradients or optimize them. Currently, we only use this for DelayedScaling's amax_history state\n",
"\n",
"With the simplest quantization integration shown in this tutorial, we want users to keep their existing model param setup so they don't need to worry about preserving the sharding, init distribution, etc.. So we don't need point 1 since we don't do model param creation in this codepath with dot_general_cls, but we still do need `te_dot_general_cls()` to produce a Flax module since we potentially need to do points 2 or 3 which need to be in a Flax module."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "3b6b344b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input shape: (4, 2048, 4096)\n",
"Output shape: (4, 2048, 4096)\n",
"Output dtype: float32\n",
"Forward pass completed successfully!\n"
]
}
],
"source": [
"# Example usage of forward pass\n",
"y = flax_transformer.apply(var_collect, x, attention_mask=None, deterministic=True, rngs=rngs)\n",