Commit 0d874a4e authored by wenjh's avatar wenjh
Browse files

Merge branch 'nv_main' of v2.12

parents a68e5f87 dfdd3820
..
Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
Debug features
==========
==============
.. autoapiclass:: transformer_engine.debug.features.log_tensor_stats.LogTensorStats
.. autoapiclass:: transformer_engine.debug.features.log_fp8_tensor_stats.LogFp8TensorStats
......
..
Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
......
..
Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
Distributed training
===================
====================
Nvidia-Pytorch-Inspect with Transformer Engine supports multi-GPU training. This guide describes how to run it and how the supported features work in the distributed setting.
......@@ -14,7 +14,8 @@ To use precision debug tools in multi-GPU training, one needs to:
2. If one wants to log stats, one may want to invoke ``debug_api.set_tensor_reduction_group`` with a proper reduction group.
Behavior of the features
-----------------------
------------------------
In a distributed setting, **DisableFP8GEMM** and **DisableFP8Layer** function similarly to the single-GPU case, with no notable differences.
......@@ -28,7 +29,8 @@ In a distributed setting, **DisableFP8GEMM** and **DisableFP8Layer** function si
Logging-related features are more complex and will be discussed further in the next sections.
Reduction groups
--------------
----------------
In setups with tensor, data, or pipeline parallelism, some tensors are distributed across multiple GPUs, requiring a reduction operation to compute statistics for these tensors.
......@@ -65,7 +67,8 @@ Below, we illustrate configurations for a 4-node setup with tensor parallelism s
Microbatching
-----------
-------------
Let's dive into how statistics collection works with microbatching. By microbatching, we mean invoking multiple ``forward()`` calls for each ``debug_api.step()``. The behavior is as follows:
......@@ -73,7 +76,7 @@ Let's dive into how statistics collection works with microbatching. By microbatc
- For other tensors, the stats are accumulated.
Logging to files and TensorBoard
------------------------------
--------------------------------
In a single-node setup with ``default_logging_enabled=True``, all logs are saved by default to ``log_dir/nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log``. In multi-GPU training, each node writes its reduced statistics to its unique file, named ``log_dir/nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-i.log`` for rank i. Because these logs contain reduced statistics, the logged values are identical for all nodes within a reduction group.
......
..
Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
API
============
===
.. toctree::
:caption: Precision debug tools API
......
..
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
.. _environment_variables:
Environment Variables
=====================
This document describes the environment variables used by Transformer Engine. They provide an alternate method to alter Transformer Engine's behavior during build and runtime, but are less rigorously maintained compared to the API and may be subject to change.
Build-Time Environment Variables
---------------------------------
These environment variables control the build and compilation process of Transformer Engine.
Build Configuration
^^^^^^^^^^^^^^^^^^^
.. envvar:: NVTE_BUILD_DEBUG
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable debug build mode. When set to ``1``, the build includes debug symbols (``-g``) and disables optimizations.
.. envvar:: NVTE_BUILD_MAX_JOBS
:Type: ``int``
:Default: Maximum available
:Description: Number of parallel jobs to use during the build process. If not set, the system will use the maximum available parallel jobs. Also respects the standard ``MAX_JOBS`` environment variable.
.. envvar:: NVTE_BUILD_THREADS_PER_JOB
:Type: ``int``
:Default: ``1``
:Description: Number of threads to use per parallel build job. This is passed to the CUDA compiler via the ``--threads`` flag.
.. envvar:: NVTE_FRAMEWORK
:Type: ``str``
:Default: Auto-detected
:Description: Comma-separated list of frameworks to build support for (``pytorch``, ``jax``, ``all``, or ``none``). If not specified, automatically detects installed frameworks.
.. envvar:: NVTE_USE_CCACHE
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable ccache for faster recompilation. When set to ``1``, uses ccache as a compiler launcher for both C++ and CUDA compilation.
.. envvar:: NVTE_CCACHE_BIN
:Type: ``str``
:Default: ``ccache``
:Description: Path to the ccache binary. Only used when :envvar:`NVTE_USE_CCACHE` is set to ``1``.
.. envvar:: NVTE_CMAKE_BUILD_DIR
:Type: ``str``
:Default: None
:Description: Path to the CMake build directory for incremental builds. If set, CMake will use this directory for build artifacts.
.. envvar:: NVTE_RELEASE_BUILD
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable release build mode. When set to ``1``, prepares the build for distribution (e.g., PyPI wheel). This affects library installation paths and build tool management.
.. envvar:: NVTE_PROJECT_BUILDING
:Type: ``int`` (0 or 1)
:Default: Not set
:Description: Internal flag set to ``1`` during the build process to indicate that the project is being built. Not intended for external use.
Optional Dependencies
^^^^^^^^^^^^^^^^^^^^^
.. envvar:: NVTE_UB_WITH_MPI
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable MPI support for userbuffers. When set to ``1``, requires ``MPI_HOME`` to be set to the MPI installation directory.
.. envvar:: NVTE_ENABLE_NVSHMEM
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable NVSHMEM support. When set to ``1``, requires ``NVSHMEM_HOME`` to be set to the NVSHMEM installation directory.
.. envvar:: NVTE_BUILD_ACTIVATION_WITH_FAST_MATH
:Type: CMake option
:Default: ``OFF``
:Description: Compile activation kernels (GELU, ReLU, SwiGLU) with the ``--use_fast_math`` CUDA compiler flag for improved performance at the cost of some precision.
CUDA Configuration
^^^^^^^^^^^^^^^^^^
.. envvar:: NVTE_CUDA_ARCHS
:Type: ``str``
:Default: Auto-detected based on CUDA version
:Description: Semicolon-separated list of CUDA compute architectures to compile for (e.g., ``"80;90"`` for A100 and H100, or ``"75;80;89;90"``). If not set, automatically determined based on the installed CUDA Toolkit version. CUDA 13.0+ defaults to ``"75;80;89;90;100;120"``, CUDA 12.8+ defaults to ``"70;80;89;90;100;120"``, and earlier versions default to ``"70;80;89;90"``. Setting this can significantly reduce build time and binary size by targeting only the GPU architectures you need.
.. envvar:: NVTE_CUDA_INCLUDE_DIR
:Type: ``str``
:Default: Auto-detected
:Description: Path to CUDA include directory containing ``cuda_runtime.h``. If not set, Transformer Engine searches in common locations (``CUDA_HOME``, ``CUDA_DIR``, ``/usr/local/cuda``). This is used for NVRTC kernel compilation.
Runtime Environment Variables
------------------------------
These environment variables control the behavior of Transformer Engine during execution.
Attention Backend Selection
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. envvar:: NVTE_FLASH_ATTN
:Type: ``int`` (0 or 1)
:Default: ``1``
:Description: Enable or disable FlashAttention backend for DotProductAttention. When set to ``0``, FlashAttention will not be used.
.. envvar:: NVTE_FUSED_ATTN
:Type: ``int`` (0 or 1)
:Default: ``1``
:Description: Enable or disable FusedAttention backend (cuDNN-based) for DotProductAttention. When set to ``0``, FusedAttention will not be used.
.. envvar:: NVTE_UNFUSED_ATTN
:Type: ``int`` (0 or 1)
:Default: ``1``
:Description: Enable or disable UnfusedDotProductAttention backend (native PyTorch). When set to ``0``, UnfusedDotProductAttention will not be used.
.. envvar:: NVTE_FUSED_ATTN_BACKEND
:Type: ``int`` (0, 1, or 2)
:Default: Auto-selected
:Description: Force a specific FusedAttention backend. ``0`` = F16_max512_seqlen (cuDNN, ≤512 seq len), ``1`` = F16_arbitrary_seqlen (cuDNN, any seq len), ``2`` = FP8 backend. If not set, the backend is automatically selected based on the input configuration.
.. envvar:: NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT
:Type: ``int`` (0 or 1)
:Default: Auto-determined
:Description: Control workspace-related optimizations in FusedAttention. ``0`` disables optimizations, ``1`` enables them. These optimizations trade memory for performance. When unset, Transformer Engine determines the code path based on internal logic. For deterministic behavior with cuDNN ≥8.9.5 and <9.0.0, this is automatically set to ``1``.
.. envvar:: NVTE_FUSED_ATTN_USE_FAv2_BWD
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: When using FusedAttention, use FlashAttention-2 implementation for the backward pass instead of the cuDNN implementation. This can be useful due to performance differences between various versions of flash-attn and FusedAttention.
.. envvar:: NVTE_ALLOW_NONDETERMINISTIC_ALGO
:Type: ``int`` (0 or 1)
:Default: ``1``
:Description: Allow non-deterministic algorithms for Transformer Engine execution. When set to ``0``, only deterministic algorithms are allowed. This is relevant for both PyTorch and JAX attention implementations.
.. envvar:: NVTE_FUSED_RING_ATTENTION_USE_SCAN
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: **(JAX only)** Use scan loop for ring attention implementation. When set to ``1``, the fused ring attention will use a scan-based iteration approach.
.. envvar:: NVTE_APPLY_QK_LAYER_SCALING
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Apply QK layer scaling in UnfusedDotProductAttention. This is an FP16 training trick required for certain GPT-like models. When set to ``1`` and a layer number is provided, the softmax scale is divided by the layer number, and the layer number is used as the softmax scale during the softmax operation. Only effective when using FP16 dtype and when the layer number is specified.
Context Parallelism
^^^^^^^^^^^^^^^^^^^
.. envvar:: NVTE_BATCH_MHA_P2P_COMM
:Type: ``int`` (0 or 1)
:Default: ``0`` (or auto-enabled for pre-Blackwell GPUs with CP size 2)
:Description: Use batched P2P communication (``batch_isend_irecv``) for KV exchange in context parallel MultiheadAttention. When enabled, send and receive operations are batched together, which can improve communication efficiency. This is automatically enabled for devices with compute capability < 10.0 (pre-Blackwell GPUs) when context parallel size is 2. Setting this to ``1`` forces batched P2P communication regardless of device architecture.
FP8 Configuration
^^^^^^^^^^^^^^^^^
.. envvar:: NVTE_UNFUSED_FP8_UPDATE
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Use unfused kernel for FP8 amax and scale updates. When set to ``1``, amax and scale updates are computed using separate unfused kernels instead of fused operations.
.. envvar:: NVTE_FP8_DPA_BWD
:Type: ``int`` (0 or 1)
:Default: ``1``
:Description: Enable FP8 in the backward pass of DotProductAttention. ``1`` = FP8 forward and backward, ``0`` = FP8 forward and FP16/BF16 backward.
.. envvar:: NVTE_DPA_FP8CS_O_in_F16
:Type: ``int`` (0 or 1)
:Default: ``1``
:Description: For Float8CurrentScaling in DotProductAttention, use FP16/BF16 for the output tensor in the backward pass. ``1`` = use F16/BF16 output in backward, ``0`` = use FP8 output in backward.
.. envvar:: NVTE_DPA_FP8_RECIPE
:Type: ``str``
:Default: Empty (use same as linear layers)
:Description: Override FP8 recipe for DotProductAttention layers. Valid values: ``"F16"`` (disable FP8), ``"DelayedScaling"``, or ``"Float8CurrentScaling"``. This allows using different FP8 recipes for attention vs. linear layers.
.. envvar:: NVTE_DPA_FP8_RECIPE_DPA
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable FP8 in DotProductAttention when using :envvar:`NVTE_DPA_FP8_RECIPE`. When set to ``1``, the DotProductAttention layer will use the FP8 recipe specified by :envvar:`NVTE_DPA_FP8_RECIPE`. This provides fine-grained control over which attention components use FP8.
.. envvar:: NVTE_DPA_FP8_RECIPE_MHA
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable FP8 in MultiheadAttention (MHA) when using :envvar:`NVTE_DPA_FP8_RECIPE`. When set to ``1``, the MultiheadAttention QKV and output projection layers will use the FP8 recipe specified by :envvar:`NVTE_DPA_FP8_RECIPE`. This provides fine-grained control over which attention components use FP8.
.. envvar:: NVTE_DPA_FP8_FORMAT
:Type: ``str``
:Default: ``"HYBRID"``
:Description: FP8 format for DotProductAttention when switching recipes. Valid values: ``"HYBRID"``, ``"E4M3"``, ``"E5M2"``. Only used when :envvar:`NVTE_DPA_FP8_RECIPE` is set.
.. envvar:: NVTE_DPA_FP8DS_AMAX_ALGO
:Type: ``str``
:Default: ``"most_recent"``
:Description: Amax computation algorithm for DelayedScaling recipe in DotProductAttention. Valid values: ``"most_recent"``, ``"max"``. Only used when :envvar:`NVTE_DPA_FP8_RECIPE` is set to ``"DelayedScaling"``.
.. envvar:: NVTE_DPA_FP8DS_AMAX_HISTLEN
:Type: ``int``
:Default: ``1``
:Description: Amax history length for DelayedScaling recipe in DotProductAttention. Only used when :envvar:`NVTE_DPA_FP8_RECIPE` is set to ``"DelayedScaling"``.
.. envvar:: NVTE_DPA_FP8DS_REDUCE_AMAX
:Type: ``int`` (0 or 1)
:Default: ``1``
:Description: Reduce amax across distributed ranks for DelayedScaling recipe in DotProductAttention. Only used when :envvar:`NVTE_DPA_FP8_RECIPE` is set to ``"DelayedScaling"``.
.. envvar:: NVTE_UnfusedDPA_Emulate_FP8
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Allow FP8 emulation in UnfusedDotProductAttention. When set to ``1``, UnfusedDotProductAttention can emulate FP8 operations using FP16/BF16 computation.
Kernel Configuration
^^^^^^^^^^^^^^^^^^^^
.. envvar:: NVTE_USE_FAST_MATH
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable fast math optimizations in runtime-compiled (NVRTC) kernels. This trades numerical accuracy for performance. These optimizations are experimental and inconsistently implemented.
.. envvar:: NVTE_DISABLE_NVRTC
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Disable NVRTC (CUDA Runtime Compilation) support. When set to ``1``, runtime kernel compilation is disabled. This can be useful in environments where NVRTC is not available or not desired.
.. envvar:: NVTE_USE_CUTLASS_GROUPED_GEMM
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Use CUTLASS implementation for grouped GEMM operations instead of cuBLAS. When set to ``1``, enables CUTLASS grouped GEMM kernels, which may provide better performance for certain workloads on Hopper (SM90) GPUs.
.. envvar:: NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Emit a warning when falling back from CUTLASS to cuBLAS for grouped GEMM operations.
Torch Compilation and Fusion
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. envvar:: NVTE_TORCH_COMPILE
:Type: ``int`` (0 or 1)
:Default: ``1``
:Description: Enable PyTorch 2.x ``torch.compile`` support for compatible Transformer Engine operations. When set to ``0``, disables compilation support and uses regular PyTorch eager mode.
.. envvar:: NVTE_BIAS_GELU_NVFUSION
:Type: ``int`` (0 or 1)
:Default: ``1``
:Description: Enable GELU fusion with bias using NVFusion in PyTorch. When set to ``0``, uses separate bias and GELU operations.
.. envvar:: NVTE_BIAS_DROPOUT_FUSION
:Type: ``int`` (0 or 1)
:Default: ``1``
:Description: Enable fusion of bias and dropout operations. When set to ``0``, bias and dropout are computed separately.
LayerNorm/RMSNorm SM Margins
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. envvar:: NVTE_FWD_LAYERNORM_SM_MARGIN
:Type: ``int``
:Default: ``0``
:Description: Number of SMs (Streaming Multiprocessors) to reserve (not use) during forward LayerNorm/RMSNorm operations. This can be used to control resource allocation and overlap computation with communication.
.. envvar:: NVTE_BWD_LAYERNORM_SM_MARGIN
:Type: ``int``
:Default: ``0``
:Description: Number of SMs to reserve during backward LayerNorm/RMSNorm operations.
.. envvar:: NVTE_INF_LAYERNORM_SM_MARGIN
:Type: ``int``
:Default: ``0``
:Description: Number of SMs to reserve during inference LayerNorm/RMSNorm operations.
GEMM Configuration
^^^^^^^^^^^^^^^^^^
.. envvar:: NVTE_EXT_MARGIN_SM
:Type: ``int``
:Default: Total SM count
:Description: External SM margin for GEMM operations. Specifies the number of SMs to use for GEMM operations. The actual number of SMs used is ``sm_count - NVTE_EXT_MARGIN_SM``.
.. envvar:: NVTE_AG_P2P_MULTI_ATOMIC
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable multi-atomic mode for AllGather with atomic GEMM using P2P communication. When set to ``1``, uses ``userbuffers_sendrecv_multiatomic`` for communication during atomic GEMM overlap with AllGather operations. This disables copy engine (CE) usage and enables push mode for userbuffers. This is an advanced optimization for tensor-parallel communication-computation overlap.
CPU Offloading
^^^^^^^^^^^^^^
.. envvar:: NVTE_CPU_OFFLOAD_V1
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable legacy version of CPU offloading implementation.
Debugging and Profiling
^^^^^^^^^^^^^^^^^^^^^^^^
.. envvar:: NVTE_DEBUG
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable debug mode. When set to ``1``, enables verbose debug output and additional checks in attention operations.
.. envvar:: NVTE_DEBUG_LEVEL
:Type: ``int`` (0, 1, or 2)
:Default: ``0``
:Description: Debug verbosity level. Higher values enable more verbose debug output. Only effective when :envvar:`NVTE_DEBUG` is set to ``1``.
.. envvar:: NVTE_PRINT_LAYER_NUMBER
:Type: ``int``
:Default: ``1``
:Description: Layer number to print debug information for during attention operations.
.. envvar:: NVTE_PRINT_RANK
:Type: ``int``
:Default: ``0``
:Description: Distributed rank to print debug information for during attention operations.
.. envvar:: NVTE_NVTX_ENABLED
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable NVTX (NVIDIA Tools Extension) range profiling for Transformer Engine operations. When set to ``1``, NVTX markers are added to operations for profiling with NVIDIA Nsight Systems.
.. envvar:: NVTE_DEBUG_NUMERICS
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: **(JAX only)** Enable verbose printing of tensor numerics for debugging purposes.
Testing
^^^^^^^
.. envvar:: NVTE_TEST_NVINSPECT_ENABLED
:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable NVInspect integration for testing. When set to ``1``, enables the NVInspect debugging API for numerical analysis during tests.
.. envvar:: NVTE_TEST_NVINSPECT_CONFIG_FILE
:Type: ``str``
:Default: None
:Description: Path to NVInspect configuration file. Required when :envvar:`NVTE_TEST_NVINSPECT_ENABLED` is set to ``1``.
.. envvar:: NVTE_TEST_NVINSPECT_FEATURE_DIRS
:Type: ``str``
:Default: None
:Description: Comma-separated list of directories containing NVInspect features. Required when :envvar:`NVTE_TEST_NVINSPECT_ENABLED` is set to ``1``.
.. envvar:: NVTE_TEST_ARTIFACTS_DIR
:Type: ``str``
:Default: System temp directory
:Description: Directory for storing test artifacts (e.g., generated ONNX models).
ONNX Export
^^^^^^^^^^^
.. envvar:: NVTE_ONNX_KVCACHE_MAX_SEQ_LEN
:Type: ``int``
:Default: ``128``
:Description: Maximum sequence length for KV cache during ONNX export. This is used for attention masking in exported ONNX models.
JAX-Specific Variables
^^^^^^^^^^^^^^^^^^^^^^
.. envvar:: NVTE_JAX_CUSTOM_CALLS
:Type: ``str``
:Default: None
:Description: Control which JAX custom call primitives are enabled or disabled. Format: ``"true"`` (enable all), ``"false"`` (disable all), or comma-separated key-value pairs like ``"GemmPrimitive=false,DBiasQuantizePrimitive=true"``. This provides fine-grained control over which operations use custom CUDA kernels vs. JAX native implementations.
.. envvar:: NVTE_JAX_CUSTOM_CALLS_RE
:Type: ``str``
:Default: None
:Description: **Deprecated** (use :envvar:`NVTE_JAX_CUSTOM_CALLS` instead). Regex pattern to match primitive names for enabling/disabling. Example: ``"DBiasQuantizePrimitive"`` or ``"^(?!DBiasQuantizePrimitive$).+$"``.
.. envvar:: NVTE_JAX_UNITTEST_LEVEL
:Type: ``str``
:Default: None
:Description: Test level for JAX unit tests (``"L0"``, ``"L1"``, ``"L2"``). Used internally by the test suite.
Examples
--------
Building with Debug Symbols
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: bash
export NVTE_BUILD_DEBUG=1
export NVTE_USE_CCACHE=1
pip install -e .
Using Specific Attention Backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: bash
# Use only FlashAttention, disable FusedAttention
export NVTE_FLASH_ATTN=1
export NVTE_FUSED_ATTN=0
python train.py
Configuring FP8 for Attention
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: bash
# Use DelayedScaling for attention, CurrentScaling for linear layers
export NVTE_DPA_FP8_RECIPE="DelayedScaling"
export NVTE_DPA_FP8_FORMAT="HYBRID"
export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent"
export NVTE_DPA_FP8DS_AMAX_HISTLEN=1024
python train.py
Enable Profiling
^^^^^^^^^^^^^^^^
.. code-block:: bash
# Enable NVTX markers for profiling
export NVTE_NVTX_ENABLED=1
nsys profile --trace=nvtx,cuda python train.py
JAX Custom Calls Control
^^^^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: bash
# Disable all custom calls
export NVTE_JAX_CUSTOM_CALLS="false"
python train_jax.py
# Disable specific primitives
export NVTE_JAX_CUSTOM_CALLS="GemmPrimitive=false,DBiasQuantizePrimitive=false"
python train_jax.py
......@@ -13,7 +13,7 @@
"id": "6dcbf25a",
"metadata": {},
"source": [
"This guide is a follow-up to the discussion in the [quickstart guide](quickstart.ipynb). We will focus on techniques to achieve maximum performance when training a basic GPT encoder layer. For convenience, we use some helper functions defined in [quickstart_utils.py](quickstart_utils.py). "
"This guide is a follow-up to the discussion in the [Getting Started guide](../getting_started/index.rst). We will focus on techniques to achieve maximum performance when training a basic GPT encoder layer. For convenience, we use some helper functions defined in [quickstart_utils.py](quickstart_utils.py). "
]
},
{
......@@ -100,7 +100,7 @@
"\n",
"</div>\n",
"\n",
"A variety of parallelism strategies can be used to enable multi-GPU training of Transformer models, often based on different approaches to distribute their $\\text{sequence_length} \\times \\text{batch_size} \\times \\text{hidden_size}$ activation tensors. The most common approach is data parallelism, which distributes along the $\\text{batch_size}$ dimension. By storing duplicate copies of the model on each GPU, the forward and backward passes of the training step can be done independently, followed by a gradient synchronization. A more advanced strategy is tensor parallelism, a type of model parallelism that distributes along the $\\text{hidden_size}$ dimension. This allows us to scale past the limits of data parallelism (typically $\\text{hidden_size} > \\text{batch_size}$) and to reduce the per-GPU memory usage (since model parameters are also distributed), but it also incurs the overhead of communicating activation tensors between GPUs at every step. For a more detailed explanation, please see the [Megatron-LM paper](https://arxiv.org/pdf/1909.08053.pdf). Finally, sequence parallelism distributes along the $\\text{sequence_length}$ dimension. This can be used when tensor parallelism is enabled in order to parallelize operations that run outside the tensor-parallel region (e.g. layer norm). For more details, please see [this paper](https://arxiv.org/pdf/2205.05198.pdf).\n",
"A variety of parallelism strategies can be used to enable multi-GPU training of Transformer models, often based on different approaches to distribute their $\\text{sequence_length} \\cdot \\text{batch_size} \\cdot \\text{hidden_size}$ activation tensors. The most common approach is data parallelism, which distributes along the $\\text{batch_size}$ dimension. By storing duplicate copies of the model on each GPU, the forward and backward passes of the training step can be done independently, followed by a gradient synchronization. A more advanced strategy is tensor parallelism, a type of model parallelism that distributes along the $\\text{hidden_size}$ dimension. This allows us to scale past the limits of data parallelism (typically $\\text{hidden_size} > \\text{batch_size}$) and to reduce the per-GPU memory usage (since model parameters are also distributed), but it also incurs the overhead of communicating activation tensors between GPUs at every step. For a more detailed explanation, please see the [Megatron-LM paper](https://arxiv.org/pdf/1909.08053.pdf). Finally, sequence parallelism distributes along the $\\text{sequence_length}$ dimension. This can be used when tensor parallelism is enabled in order to parallelize operations that run outside the tensor-parallel region (e.g. layer norm). For more details, please see [this paper](https://arxiv.org/pdf/2205.05198.pdf).\n",
"\n",
"To show this in action, let's first initialize NCCL with a trivial process group:"
]
......@@ -131,7 +131,7 @@
"id": "1f2b80d0",
"metadata": {},
"source": [
"We only initialize with one GPU to keep this example simple. Please consult the documentation [torch.distributed](https://pytorch.org/docs/stable/distributed.html) for guidance on running with multiple GPUs. Note that we require that each distributed process corresponds to exactly one GPU, so we treat them interchangeably. In practice, there are multiple factors that can affect the optimal parallel layout: the system hardware, the network topology, usage of other parallelism schemes like pipeline parallelism. A rough rule-of-thumb is to interpret the GPUs as a 2D grid with dimensions of $\\text{num_nodes} \\times \\text{gpus_per_node}$. The rows are tensor-parallel groups and the columns are data-parallel groups.\n",
"We only initialize with one GPU to keep this example simple. Please consult the documentation [torch.distributed](https://pytorch.org/docs/stable/distributed.html) for guidance on running with multiple GPUs. Note that we require that each distributed process corresponds to exactly one GPU, so we treat them interchangeably. In practice, there are multiple factors that can affect the optimal parallel layout: the system hardware, the network topology, usage of other parallelism schemes like pipeline parallelism. A rough rule-of-thumb is to interpret the GPUs as a 2D grid with dimensions of $\\text{num_nodes} \\cdot \\text{gpus_per_node}$. The rows are tensor-parallel groups and the columns are data-parallel groups.\n",
"\n",
"Enabling data parallelism with Transformer Engine is similar to enabling data parallelism with standard PyTorch models: simply wrap the modules with [torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html). Transformer Engine modules also have native support for tensor and sequence parallelism. If the user provides a process group for tensor parallelism, the modules will distribute the data and perform communication internally. If sequence parallelism is enabled, it will be applied for operations that are not amenable to tensor parallelism and it will use the tensor-parallel process group.\n",
"\n",
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
......@@ -174,7 +174,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "50852cb5",
"metadata": {},
"outputs": [
......@@ -266,7 +266,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": null,
"id": "906b8cf1",
"metadata": {},
"outputs": [
......@@ -299,7 +299,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": null,
"id": "d3637094",
"metadata": {},
"outputs": [
......@@ -509,10 +509,10 @@
"\n",
"* PyTorch: When both options are provided by the user, `cu_seqlens` is preferred as there is no extra conversion needed.\n",
" - `cu_seqlens`: Users can provide cumulative sequence length tensors `cu_seqlens_q` and `cu_seqlens_kv` for `q` and `k`/`v` to the flash-attention or cuDNN attention backend. An example of `cu_seqlens` is `[0, 2, 6, 7]` for a batch of 3 `[aa000, bbbb0, c0000]`.\n",
" - `attention_mask`: Users can also provide `attention_mask` as an alternative, which will then be converted to `cu_seqlens`. For self-attention, `attention_mask` should be one single tensor in shape `[batch_size, 1, 1, seqlen_q]`, and for cross-attention, `attention_mask` should be a list of two tensors in shapes `[batch_size, 1, 1, seqlen_q]` and `[batch_size, 1, 1, seqlen_kv]`, respectively.\n",
" - `attention_mask`: Users can also provide `attention_mask` as an alternative, which will then be converted to `cu_seqlens`. For self-attention, `attention_mask` should be one single tensor of shape `[batch_size, 1, 1, seqlen_q]`, and for cross-attention, `attention_mask` should be a list of two tensors of shapes `[batch_size, 1, 1, seqlen_q]` and `[batch_size, 1, 1, seqlen_kv]`, respectively.\n",
"\n",
"\n",
"* JAX: Users should provide the `attention_mask` tensor in shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n",
"* JAX: Users should provide the `attention_mask` tensor of shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n",
"\n",
"**qkv_format=thd:** Transformer Engine extracts the max sequence length information from `q`, `k`, `v` if `max_seqlen_q` and `max_seqlen_kv` are not provided. This requires GPU-CPU copy and synchronization operations. For performance reasons, please set `max_seqlen_q` and `max_seqlen_kv` to their appropriate values for `thd` QKV format.\n",
"\n",
......@@ -521,7 +521,7 @@
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": null,
"id": "a1f25a9b",
"metadata": {},
"outputs": [
......
{
"cells": [
{
"cell_type": "markdown",
"id": "14efeb1e",
"metadata": {},
"source": [
"## Deep Dive into CP + THD + AG + Striped>1 + SWA support for Transformer Engine JAX\n",
"This feature was merged as part of [PR 2379](https://github.com/NVIDIA/TransformerEngine/pull/2379/) and was made available in Transformer Engine v2.11. This document addresses 3 fundamental questions about the design considerations and the implementation logic for this feature."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16f738c7",
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "2f31119f",
"metadata": {},
"source": [
"### Question 1: Why choose Striped>1 ?\n",
"\n",
"Prior to the addition of this feature, Transformer Engine JAX attention already supported load balancing via a striping pattern, i.e., `stripe_size=1` for `CP + THD + P2P(Ring) + Striped + SWA`. However, this reordering technique does not lend itself well to an all-gathered (post-AG) pattern. The following example illustrates this distinction. For this example, `cp_size=4`, `num_segments=4`, `window_size=(8,0)`, and the pattern is for a single rank after striped reordering has been performed: \n",
"\n",
"#### I. Striped (`stripe_size=1`)\n",
"- Such a staggered pattern is not supported by cuDNN\n",
"- One possible way to express this with cuDNN support is by treating each `q` token as a segment, thereby producing 16 segments with varying `kv` token counts. However, this is very inefficient and does not scale well as max_seqlens increases\n",
"\n",
"```\n",
"1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 1 1 1 1 1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - 1 1 1 1 1 1 1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - 2 2 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 4 4 4 - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 4 4 4 4 4 4 4 - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 4 4 4 4 4 4 4 4 4 - - -\n",
"```\n",
"<figure align=\"left\">\n",
"<figcaption> Figure 1: Post load balancing using stripe_size=1 and post AG attention pattern for a single cp rank </figcaption>\n",
"</figure>\n",
"\n",
"\n",
"#### II. Striped > 1 (`stripe_size > 1`)\n",
"- This pattern is supported by cuDNN, with a suggested `stripe_size=128`\n",
"- The mask type supported by `CP + THD + AG + Striped>1 + SWA` is `PADDING_CAUSAL_MASK`; however, to express the pattern below, each rank executes THD + SWA using `PADDING_BOTTOM_RIGHT_CAUSAL_MASK`\n",
"- `max_num_segments_for_rank` needs to be estimated. The estimation formula used is: `max_seqlens // (stripe_size * cp_size) + max_num_segments`\n",
"\n",
"```\n",
"1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - 2 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 4 - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 4 4 - - - - - - - - - - - -\n",
"```\n",
"<figure align=\"left\">\n",
"<figcaption> Figure 2: Post load balancing using stripe_size=4 and post AG attention pattern for a single cp rank </figcaption>\n",
"</figure>"
]
},
{
"cell_type": "markdown",
"id": "6eddfa7a",
"metadata": {},
"source": [
"### Question 2: Why is there a need for separate helper functions for calculating seqlens and offsets ?\n",
"\n",
"The seqlens and offsets are calculated by the fused attn JAX primitives (both, CP and non-CP) so that they can be passed down to `fused_attn_arbitrary_seqlen_fwd_impl()` / `fused_attn_arbitrary_seqlen_bwd_impl()`, where it is translated before passing down to the cuDNN FE layer. The current (Transformer Engine v2.10) calculation of seqlens and offsets entails the CP primitive passing the sharded segment_ids, segment_pos, seq_lens, seq_offsets stuffed in a SequenceDescriptor object (a convenience class provided for packing these 4 tensors) to the `FusedAttnPrimitive`, which in turn calls `get_seqlens_and_offsets()` on the SequenceDescriptor object. \n",
"\n",
"If `get_seqlens_and_offsets()` receives a SequenceDescriptor object with seq_lens and seq_offsets populated and, segment_ids, segment_pos with size=0, it returns the seq_lens and seq_ofsets as it is (for e.g. `CP + BSHD + AG`). However, if `get_seqlens_and_offsets()` receives a SequenceDescriptor object with segment_ids and segment_pos populated and, seq_lens, seq_offsets with size=0, it first constructs a mask using the segment_ids and segment_pos and then extracts the seq_lens and seq_offsets from it and then returns it (for e.g. `CP + THD + P2P`).\n",
"\n",
"The problem with the current approach of calculating a mask followed by extracting the seq_lens and seq_offsets is that it is unable to express the patterns seen in `CP + THD + AG`. Below is one such example: \n",
"\n",
"```\n",
"1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - -\n",
"```\n",
"<figure align=\"left\">\n",
"<figcaption> Figure 3: Example 1 for problem using mask path in get_seqlens_and_offsets() for attention pattern (post striping and AG) . </figcaption>\n",
"</figure>\n",
"\n",
"Here, ideally, the two sections of the segment 3 should be split into two different segments (segment 3_1 formed using rows 9-12 and segment 3_2 formed using rows 13-16) as cuDNN does not support segment 3's entire staggered shape (as discussed earlier) , however, the mask route is unable to make this distinction, and it ends up treating it as one large segment thereby performing unnecessary computations of the padded regions in segment 3(rows 9-12 )\n",
"\n",
"In the below example, the mask route takes the `kv_seqlens` for segment 1 to be 6 and masks it using Bottom Right Causal Mask rather than taking `kv_seqlens` of 4 and masks it using Bottom Right Causal Mask, resulting in incorrect results\n",
"\n",
"```\n",
"1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"```\n",
"<figure align=\"left\">\n",
"<figcaption> Figure 4: Example 2 for problem using mask path in get_seqlens_and_offsets() for attention pattern (post striping and AG) </figcaption>\n",
"</figure>\n",
"\n",
"The second case can be resolved in the mask path, but that would require adding CP specific details to the non-CP FusedAttn primitive which would contaminate it. Besides, resolving the first case would be even trickier with this approach. Due to it being incompatible with the design of FusedAttn primitive and inadequate to express the pattern needed for `CP + THD + AG` fully, separate helper functions were created which calculate the seqlens and seqoffsets, without creating a mask, hence also being O(N) space."
]
},
{
"cell_type": "markdown",
"id": "3cc4a12c",
"metadata": {},
"source": [
"### Question 3: What is the implementation logic for the separate helper functions ?\n",
"\n",
"This section discusses the implementation logic for two of these four helper functions which serve as a reference, as the other two are using similar principles. Consider the test example in the code block, for which, `cp_size=4`, `stripe_size=4`, `max_seqlens=64`, `num_segments=2` and no SWA for simplicity. seg_1 has 8 valid tokens + 13 padded tokens and seg_2 has 31 valid tokens + 1 padded token. The 0 is used to explicitly show the padded region of seg_1 which is reordered, but for computation purposes it is equivalent to any of the `-` marked elements.\n",
"\n",
"```\n",
"segment_ids_q_0_reordered = segment_ids_kv_0_reordered = jnp.array([[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]])\n",
"\n",
"segment_pos_q_0_reordered = segment_pos_kv_0_reordered = jnp.array([[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]])\n",
"\n",
"segment_ids_kv_0_seed12_ag_inv_reordered = jnp.array([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])\n",
"\n",
"segment_pos_kv_0_seed12_ag_inv_reordered= jnp.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])\n",
"```\n",
"\n",
"```\n",
"1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - -\n",
"```\n",
"<figure align=\"left\">\n",
"<figcaption> Figure 5: An example of post striped reordering and AG attention pattern on a single rank.</figcaption>\n",
"</figure>\n",
"\n",
"#### I. Implementation logic for q_seqlens_for_striped_for_rank()\n",
"**What is the objective/logic ?**\n",
"- Create a new set of segment ids for this rank such that:\n",
" - It gets rid of padding information as it does not contribute to the seqlens calculation\n",
" - It has the ability to identify ”new segments” being created from the same original segment\n",
"- Use this new set of segment ids to calculate the seqlens\n",
"\n",
"**Example walkthrough**\n",
"1. Calculate the non-zero indices (where seg ids !=0)\n",
"2. Calculate the valid seg ids and valid seg pos (i.e. index into seg ids and seg pos using the non-zero indices)\n",
" - `valid_segment_ids=[[1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0]]`\n",
" - `valid_segment_pos=[[0, 1, 2, 3, 11, 12, 13, 14, 27, 28, 29, 30, 0, 0, 0, 0]]`\n",
" - Ignore the 0s at the end of the two arrays as they are just for padding to a static length\n",
"3. Find locations where a q segment change/break happens. A segment change happens when: \n",
" - there is a change in valid_segment_ids OR \n",
" - `valid_segment_pos[i+1] != valid_segment_pos[i]`\n",
" - `segment_changes=[[True, False, False, False, True, False, False, False, True, False, False, False, True, True, True, True]]`\n",
"4. Perform a cumulative sum on the segment changes: \n",
" - `new_segment_ids=[[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 5, 6, 7]]`\n",
"5. Filter out the valid indices only and pad at the end with 0s upto static length (these are our “new” segment indices without padding)\n",
" - `new_segment_ids_filtered=[[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 0, 0, 0, 0]]`\n",
" - Notice here that the large chunk of 8 q token rows (rows 9-16 in Fig 5) gets broken down into 2 \"new\" segments of 4 q token rows each,\n",
" which is a pattern that cuDNN supports and it ensures that wasted computation for padded regions of rows 9-12 is not performed, which was the\n",
" case in Fig 3\n",
"6. Perform a bin count and pad with -1s upto `max_num_segments_per_seq_for_rank`\n",
" - `seqlens_with_neg1_padding[[ 4, 4, 4, -1, -1, -1, -1]]`\n",
"\n",
"\n",
"#### II. Implementation logic for kv_seqoffsets_for_striped_for_rank()\n",
"**What is the objective/logic ?**\n",
"- Get the original segment ids for those locations where segment changes happen (arr1)\n",
" - Each segment has a known kv offset, hence if we know which original segment id a \"new\" segment is associated with we can find it's kv offset\n",
" - So, for e.g., in Fig 5, all valid tokens of seg_3 have the same kv offset, so even if this gets split into a 2 \"new\" segments, we can procure the offset for both using a mapping of original seg-ids to kv offset \n",
"- Get the segment ids for those locations where segment changes happen in the AG tensor (arr2)\n",
" - This is used to create a kind of mapping between original seg-ids to kv offset\n",
"- Pick values from arr2 mapping for the \"new\" segment ids collected in arr1\n",
"\n",
"**Example walkthrough**\n",
"1. Find locations where a kv segment pos change/break happens and mask out zero seg ids. A segment change happens when: \n",
" - `kv_segment_pos[i+1] != kv_segment_pos[i]`\n",
" - `segment_changes_masked=[[ True, False, False, False, False, False, False, False, True, False, False, False, True, False, False, False]]`\n",
"2. Get the indices where the segment changes happen and the segment ids associated with them:\n",
" - `segment_changes_indices=[[0, 8, 12, -1, -1, -1, -1, -1, -1]]`\n",
" - `[[1, 2, 2, -1, -1, -1, -1, -1, -1]]`\n",
"3. Find the segment pos changes/break for the AG seg pos and mask out zero seg ids\n",
" - `segment_changes_masked_ag=[[True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]]`\n",
"4. Get indices where the segment changes happen for the AG seg pos (this works as a mapping between segment ids and kv offsets)\n",
" - `segment_changes_ag_indices=[[0, 21, -1, -1, -1, -1, -1, -1, -1]]`\n",
"5. Get the seq offsets by indexing into segment_changes_ag_indices using segment_changes_indices :\n",
" - `kv_seq_offsets[[0, 21, 21, -1, -1, -1, -1, -1, -1]]`\n",
"\n",
"The implementation details for `q_seqoffsets_for_striped_for_rank()` and `kv_seqlens_for_striped_for_rank()` can be found in [PR 2379](https://github.com/NVIDIA/TransformerEngine/pull/2379/)"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
{
"cells": [
{
"cell_type": "markdown",
"id": "da9fd6a8",
"metadata": {},
"source": [
"# Getting Started\n",
"\n",
"## Overview\n",
"\n",
"Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, providing better performance with lower memory utilization in both training and inference. It provides support for 8-bit floating point (FP8) precision on Hopper GPUs, implements a collection of highly optimized building blocks for popular Transformer architectures, and exposes an automatic-mixed-precision-like API that can be used seamlessly with your PyTorch code. It also includes a framework-agnostic C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.\n",
"\n",
"## Let's build a Transformer layer!\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"We build a basic Transformer layer using regular PyTorch modules. This will be our baseline for later comparisons with Transformer Engine.\n",
"\n",
"</div>\n",
"\n",
"Let's start with creating a GPT encoder layer using plain PyTorch. Figure 1 shows the overall structure.\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"transformer_layer.png\" width=\"20%\">\n",
"<figcaption> Figure 1: Structure of a GPT encoder layer.</figcaption>\n",
"</figure>\n",
"\n",
"We construct the components as follows:\n",
"\n",
"- `LayerNorm`: `torch.nn.LayerNorm`\n",
"- `QKV Projection`: `torch.nn.Linear` (conceptually three `Linear` layers for Q, K, and V separately, but we fuse into a single `Linear` layer that is three times larger)\n",
"- `DotProductAttention`: `DotProductAttention` from [quickstart_utils.py](quickstart_utils.py)\n",
"- `Projection`: `torch.nn.Linear`\n",
"- `Dropout`: `torch.nn.Dropout`\n",
"- `MLP`: `BasicMLP` from [quickstart_utils.py](quickstart_utils.py)\n",
"\n",
"Over the course of this tutorial we will use a few modules and helper functions defined in [quickstart_utils.py](quickstart_utils.py). Putting it all together:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2be43d64",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import quickstart_utils as utils\n",
"\n",
"class BasicTransformerLayer(torch.nn.Module):\n",
" def __init__(\n",
" self,\n",
" hidden_size: int,\n",
" ffn_hidden_size: int,\n",
" num_attention_heads: int,\n",
" layernorm_eps: int = 1e-5,\n",
" attention_dropout: float = 0.1,\n",
" hidden_dropout: float = 0.1,\n",
" ):\n",
" super().__init__()\n",
" self.num_attention_heads = num_attention_heads\n",
" self.kv_channels = hidden_size // num_attention_heads\n",
" self.ln1 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps)\n",
" self.qkv_projection = torch.nn.Linear(hidden_size, 3 * hidden_size, bias=True)\n",
" self.attention = utils.DotProductAttention(\n",
" num_attention_heads=num_attention_heads,\n",
" kv_channels=self.kv_channels,\n",
" attention_dropout=attention_dropout,\n",
" )\n",
" self.projection = torch.nn.Linear(hidden_size, hidden_size, bias=True)\n",
" self.dropout = torch.nn.Dropout(hidden_dropout)\n",
" self.ln2 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps)\n",
" self.mlp = utils.BasicMLP(\n",
" hidden_size=hidden_size,\n",
" ffn_hidden_size=ffn_hidden_size,\n",
" ) \n",
" \n",
" def forward(\n",
" self, \n",
" x: torch.Tensor, \n",
" attention_mask: torch.Tensor\n",
" ) -> torch.Tensor:\n",
" res = x\n",
" x = self.ln1(x)\n",
" \n",
" # Fused QKV projection\n",
" qkv = self.qkv_projection(x)\n",
" qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)\n",
" q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)\n",
" \n",
" x = self.attention(q, k, v, attention_mask)\n",
" x = self.projection(x)\n",
" x = self.dropout(x)\n",
" x = res + x\n",
" res = x\n",
" x = self.ln2(x)\n",
" x = self.mlp(x)\n",
" \n",
" return x + res"
]
},
{
"cell_type": "markdown",
"id": "40724d1d",
"metadata": {},
"source": [
"That's it! We now have a simple Transformer layer. We can test it:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a786f0ea",
"metadata": {},
"outputs": [],
"source": [
"# Layer configuration\n",
"hidden_size = 4096\n",
"sequence_length = 2048\n",
"batch_size = 4\n",
"ffn_hidden_size = 16384\n",
"num_attention_heads = 32\n",
"dtype = torch.float16\n",
"\n",
"# Synthetic data\n",
"x = torch.rand(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype)\n",
"dy = torch.rand(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ffdbfb7a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"BasicTransformerLayer(\n",
" (ln1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
" (qkv_projection): Linear(in_features=4096, out_features=12288, bias=True)\n",
" (attention): DotProductAttention(\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (projection): Linear(in_features=4096, out_features=4096, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (ln2): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
" (mlp): BasicMLP(\n",
" (linear1): Linear(in_features=4096, out_features=16384, bias=True)\n",
" (linear2): Linear(in_features=16384, out_features=4096, bias=True)\n",
" )\n",
")"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"basic_transformer = BasicTransformerLayer(\n",
" hidden_size,\n",
" ffn_hidden_size,\n",
" num_attention_heads,\n",
")\n",
"basic_transformer.to(dtype=dtype).cuda()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0162ad40",
"metadata": {},
"outputs": [],
"source": [
"torch.manual_seed(1234)\n",
"y = basic_transformer(x, attention_mask=None)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "65ae6dd6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 43.0663916015625 ms\n"
]
}
],
"source": [
"utils.speedometer(\n",
" basic_transformer,\n",
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
")"
]
},
{
"cell_type": "markdown",
"id": "43717e36",
"metadata": {},
"source": [
"## Meet Transformer Engine\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"We modify the example Transformer layer to include the simplest TE modules: `Linear` and `LayerNorm`.\n",
"\n",
"</div>\n",
"\n",
"Now that we have a basic Transformer layer, let's use Transformer Engine to speed up the training. "
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "004d3c92",
"metadata": {},
"outputs": [],
"source": [
"import transformer_engine.pytorch as te"
]
},
{
"cell_type": "markdown",
"id": "1931f911",
"metadata": {},
"source": [
"TE provides a set of PyTorch modules that can be used to build Transformer layers. The simplest of the provided modules are the `Linear` and `LayerNorm` layers, which we can use instead of `torch.nn.Linear` and `torch.nn.LayerNorm`. Let's modify `BasicTransformerLayer`:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "1f44db50",
"metadata": {},
"outputs": [],
"source": [
"class BasicTEMLP(torch.nn.Module):\n",
" def __init__(self,\n",
" hidden_size: int,\n",
" ffn_hidden_size: int) -> None:\n",
" super().__init__()\n",
" self.linear1 = te.Linear(hidden_size, ffn_hidden_size, bias=True)\n",
" self.linear2 = te.Linear(ffn_hidden_size, hidden_size, bias=True)\n",
"\n",
" def forward(self, x):\n",
" x = self.linear1(x)\n",
" x = torch.nn.functional.gelu(x, approximate='tanh')\n",
" x = self.linear2(x)\n",
" return x \n",
" \n",
"class BasicTETransformerLayer(torch.nn.Module):\n",
" def __init__(self,\n",
" hidden_size: int,\n",
" ffn_hidden_size: int,\n",
" num_attention_heads: int,\n",
" layernorm_eps: int = 1e-5,\n",
" attention_dropout: float = 0.1,\n",
" hidden_dropout: float = 0.1):\n",
" super().__init__()\n",
" self.num_attention_heads = num_attention_heads\n",
" self.kv_channels = hidden_size // num_attention_heads\n",
" self.ln1 = te.LayerNorm(hidden_size, eps=layernorm_eps)\n",
" self.qkv_projection = te.Linear(hidden_size, 3 * hidden_size, bias=True)\n",
" self.attention = utils.DotProductAttention(\n",
" num_attention_heads=num_attention_heads,\n",
" kv_channels=self.kv_channels,\n",
" attention_dropout=attention_dropout,\n",
" )\n",
" self.projection = te.Linear(hidden_size, hidden_size, bias=True)\n",
" self.dropout = torch.nn.Dropout(hidden_dropout)\n",
" self.ln2 = te.LayerNorm(hidden_size, eps=layernorm_eps)\n",
" self.mlp = BasicTEMLP(\n",
" hidden_size=hidden_size,\n",
" ffn_hidden_size=ffn_hidden_size,\n",
" )\n",
" \n",
" def forward(self, \n",
" x: torch.Tensor, \n",
" attention_mask: torch.Tensor):\n",
" res = x\n",
" x = self.ln1(x)\n",
" \n",
" # Fused QKV projection\n",
" qkv = self.qkv_projection(x)\n",
" qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)\n",
" q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)\n",
" \n",
" x = self.attention(q, k, v, attention_mask)\n",
" x = self.projection(x)\n",
" x = self.dropout(x)\n",
" x = res + x\n",
" res = x\n",
" x = self.ln2(x)\n",
" x = self.mlp(x)\n",
" \n",
" return x + res"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "916531e8",
"metadata": {},
"outputs": [],
"source": [
"basic_te_transformer = BasicTETransformerLayer(\n",
" hidden_size, \n",
" ffn_hidden_size, \n",
" num_attention_heads,\n",
")\n",
"basic_te_transformer.to(dtype=dtype).cuda()\n",
"utils.share_parameters_with_basic_te_model(basic_te_transformer, basic_transformer)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "3643fa54",
"metadata": {},
"outputs": [],
"source": [
"torch.manual_seed(1234)\n",
"y = basic_te_transformer(x, attention_mask=None)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "10b92894",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 43.1413232421875 ms\n"
]
}
],
"source": [
"utils.speedometer(\n",
" basic_te_transformer,\n",
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
")"
]
},
{
"cell_type": "markdown",
"id": "3f990226",
"metadata": {},
"source": [
"## Fused TE Modules\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"We optimize the example Transformer layer with TE modules for fused operations.\n",
"\n",
"</div>\n",
"\n",
"The `Linear` layer is enough to build any Transformer model and it enables usage of Transformer Engine even for very custom Transformers. However, having more knowledge about the model allows for additional optimizations like kernel fusion, increasing the achievable speedup.\n",
"\n",
"Transformer Engine therefore provides coarser modules that span multiple layers:\n",
"\n",
"* `LayerNormLinear`\n",
"* `LayerNormMLP`\n",
"* `TransformerLayer`\n",
"\n",
"Building a third iteration of our Transformer layer with `LayerNormLinear` and `LayerNormMLP`:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "c55eae1f",
"metadata": {},
"outputs": [],
"source": [
"class FusedTETransformerLayer(torch.nn.Module):\n",
" def __init__(self,\n",
" hidden_size: int,\n",
" ffn_hidden_size: int,\n",
" num_attention_heads: int,\n",
" layernorm_eps: int = 1e-5,\n",
" attention_dropout: float = 0.1,\n",
" hidden_dropout: float = 0.1):\n",
" super().__init__()\n",
" self.num_attention_heads = num_attention_heads\n",
" self.kv_channels = hidden_size // num_attention_heads\n",
" self.ln_qkv = te.LayerNormLinear(hidden_size, 3 * hidden_size, eps=layernorm_eps, bias=True)\n",
" self.attention = utils.DotProductAttention(\n",
" num_attention_heads=num_attention_heads,\n",
" kv_channels=self.kv_channels,\n",
" attention_dropout=attention_dropout,\n",
" )\n",
" self.projection = te.Linear(hidden_size, hidden_size, bias=True)\n",
" self.dropout = torch.nn.Dropout(hidden_dropout)\n",
" self.ln_mlp = te.LayerNormMLP(hidden_size, ffn_hidden_size, eps=layernorm_eps, bias=True)\n",
" \n",
" \n",
" def forward(self, \n",
" x: torch.Tensor, \n",
" attention_mask: torch.Tensor):\n",
" res = x\n",
" qkv = self.ln_qkv(x)\n",
" \n",
" # Split qkv into query, key and value\n",
" qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)\n",
" q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)\n",
" \n",
" x = self.attention(q, k, v, attention_mask)\n",
" x = self.projection(x)\n",
" x = self.dropout(x)\n",
" x = res + x\n",
" res = x\n",
" x = self.ln_mlp(x)\n",
" \n",
" return x + res"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "85949421",
"metadata": {},
"outputs": [],
"source": [
"fused_te_transformer = FusedTETransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads)\n",
"fused_te_transformer.to(dtype=dtype).cuda()\n",
"utils.share_parameters_with_fused_te_model(fused_te_transformer, basic_transformer)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "2c263e71",
"metadata": {},
"outputs": [],
"source": [
"torch.manual_seed(1234)\n",
"y = fused_te_transformer(x, attention_mask=None)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "24e101bc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 43.1981201171875 ms\n"
]
}
],
"source": [
"utils.speedometer(\n",
" fused_te_transformer,\n",
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
")"
]
},
{
"cell_type": "markdown",
"id": "33f13c26",
"metadata": {},
"source": [
"Finally, the `TransformerLayer` module is convenient for creating standard Transformer architectures and it provides the highest degree of performance optimization:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "ec8c3685",
"metadata": {},
"outputs": [],
"source": [
"te_transformer = te.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads)\n",
"te_transformer.to(dtype=dtype).cuda()\n",
"utils.share_parameters_with_transformerlayer_te_model(te_transformer, basic_transformer)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "e48cd590",
"metadata": {},
"outputs": [],
"source": [
"torch.manual_seed(1234)\n",
"y = te_transformer(x, attention_mask=None)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "3ec3707d-e63f-4899-8308-b11c55b5caa4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 39.99169921875 ms\n"
]
}
],
"source": [
"utils.speedometer(\n",
" te_transformer,\n",
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
")"
]
},
{
"cell_type": "markdown",
"id": "4034c3eb-8958-49f2-85f6-30c94977d884",
"metadata": {},
"source": [
"## Enabling FP8\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"We configure a TE module to perform compute in FP8.\n",
"\n",
"</div>\n",
"\n",
"Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [autocast](../api/pytorch.rst#transformer_engine.pytorch.autocast) context manager. Note that autocast should only be used to wrap the forward pass and must exit before starting a backward pass. See the [FP8 tutorial](fp8_primer.ipynb) for a detailed explanation of FP8 recipes and the supported options."
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "31256aa7-3d5e-425c-91ab-502b1326a748",
"metadata": {},
"outputs": [],
"source": [
"from transformer_engine.common.recipe import Format, DelayedScaling\n",
"\n",
"te_transformer = te.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads)\n",
"te_transformer.to(dtype=dtype).cuda()\n",
"utils.share_parameters_with_transformerlayer_te_model(te_transformer, basic_transformer)\n",
"\n",
"fp8_format = Format.HYBRID\n",
"fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")\n",
"torch.manual_seed(1234)\n",
"with te.autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
" y = te_transformer(x, attention_mask=None)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "793ebd2d-b84b-47bc-811a-7991df8500aa",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 28.61394775390625 ms\n"
]
}
],
"source": [
"utils.speedometer(\n",
" te_transformer,\n",
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe },\n",
")"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import jax
import jax.numpy as jnp
import time
from typing import Callable, Any, Dict, Optional, Tuple
import transformer_engine.jax as te
def speedometer(
model_apply_fn: Callable,
variables: Any,
input: jnp.ndarray,
output_grad: jnp.ndarray,
model_init_fn: Callable = None,
forward_kwargs: dict = {},
autocast_kwargs: Optional[dict] = None,
timing_iters: int = 50,
warmup_iters: int = 50,
rngs: Dict[str, jax.random.PRNGKey] = None,
) -> None:
"""Measure average runtime for a JAX module
Perform forward and backward passes .
"""
if autocast_kwargs is None:
autocast_kwargs = {"enabled": False}
model_init_fn = None
if rngs is None:
rngs = {}
train_step_fn = create_train_step_fn(model_apply_fn, autocast_kwargs, forward_kwargs)
# Warm up runs
for _ in range(warmup_iters):
rngs, step_rngs = _split_step_rngs(rngs)
loss, (param_grads, other_grads) = train_step_fn(variables, input, output_grad, step_rngs)
# Timing runs
start = time.time()
for _ in range(timing_iters):
rngs, step_rngs = _split_step_rngs(rngs)
loss, (param_grads, other_grads) = train_step_fn(variables, input, output_grad, step_rngs)
end = time.time()
print(f"Mean time: {(end - start) * 1000 / timing_iters} ms")
def create_train_step_fn(
model_apply_fn: Callable,
autocast_kwargs: Dict[str, Any],
forward_kwargs: Dict[str, Any] = None,
) -> Callable:
"""
Creates a JIT-compiled function that performs one forward/backward pass.
"""
if forward_kwargs is None:
forward_kwargs = {}
def loss_fn(
variables: Any,
inp: jnp.ndarray,
grad_target: jnp.ndarray,
rngs: Dict[str, jax.random.PRNGKey],
):
with te.autocast(**autocast_kwargs):
# Forward Pass: Apply the model using current parameters and variables
call_kwargs = {**forward_kwargs, "rngs": rngs}
out = model_apply_fn(variables, inp, **call_kwargs)
# grad_target = derivative of L (loss fn) over y (output) = signma(L)/sigma(y)
# where grad_w(L) = gradient of loss over params = sigma(L)/sigma(y) * sigma(y)/sigma(w) --> chain rule
# sigma(y)/sigma(w) = J_model(w)
return jnp.vdot(out, grad_target)
def fwd_bwd_fn(*args, **kwargs):
return jax.value_and_grad(loss_fn, argnums=(0, 1))(*args, **kwargs)
# Use jax.value_and_grad to get the loss value and gradients simultaneously. (forward + backward pass)
# ∇_params[output^T · grad_target] = grad_target^T · J_output(params) = VJP
# fwd_bwd_fn = jax.value_and_grad(loss_fn, argnums=(0, 1))
# JIT-compile the fwd_bwd_fn
return jax.jit(fwd_bwd_fn)
def _split_step_rngs(
rngs: Dict[str, jax.random.PRNGKey],
) -> Tuple[Dict[str, jax.random.PRNGKey], Dict[str, jax.random.PRNGKey]]:
"""Splits each RNG in the rngs dictionary for a new step."""
step_rngs = {}
new_rngs = {}
for name, key in rngs.items():
new_key, step_key = jax.random.split(key)
new_rngs[name] = new_key
step_rngs[name] = step_key
return new_rngs, step_rngs
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
......@@ -38,7 +38,7 @@
"\n",
"For those seeking a deeper understanding of text generation mechanisms in Transformers, it is recommended to check out the [HuggingFace generation tutorial](https://huggingface.co/docs/transformers/llm_tutorial).\n",
"\n",
"In a previous tutorial on [Llama](../te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb), it was demonstrated how finetuning of an open-source Llama model can be accelerated using Transformer Engine's `TransformerLayer`. Building on that foundation, this tutorial showcases how to accelerate the token generation from the open-source Hugging Face Gemma 7B model.\n",
"In a previous tutorial on [Llama](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb), it was demonstrated how finetuning of an open-source Llama model can be accelerated using Transformer Engine's `TransformerLayer`. Building on that foundation, this tutorial showcases how to accelerate the token generation from the open-source Hugging Face Gemma 7B model.\n",
"\n",
"This tutorial introduces several features of the Transformer Engine library that contribute towards this goal. A brief explanation is as follows:\n",
"\n",
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
{
"cells": [
{
"cell_type": "markdown",
"id": "962d87bb",
"metadata": {},
"source": [
"\n",
"\n",
"# JAX: Integrating TE into an existing framework\n",
"\n",
"This tutorial will cover how to integrate TransformerEngine into an existing JAX model framework, such as [MaxText's TE integration](https://github.com/AI-Hypercomputer/maxtext/blob/ed517cf80d9aa81f76e236c5516dacebfe39e96d/src/MaxText/layers/quantizations.py#L753) or your own model framework. \n"
]
},
{
"cell_type": "markdown",
"id": "b36876bb",
"metadata": {},
"source": [
"Let's start with a standard JAX+Flax Transformer layer"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "d5284a38",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"from flax import linen as nn\n",
"import quickstart_jax_utils as utils\n",
"from typing import Optional"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "a4d1cfdc",
"metadata": {},
"outputs": [],
"source": [
"class FlaxMLP(nn.Module):\n",
" \"\"\"Feed-forward network in Transformer layer\n",
" Built with plain Flax modules.\n",
" \"\"\"\n",
" hidden_size: int\n",
" ffn_hidden_size: int\n",
" dot_general_cls: callable = lambda: None\n",
"\n",
" @nn.compact\n",
" def __call__(self, x: jnp.ndarray) -> jnp.ndarray:\n",
" x = nn.Dense(features=self.ffn_hidden_size, use_bias=True, dot_general=self.dot_general_cls())(x)\n",
" x = nn.gelu(x, approximate=True) # equivalent to tanh approximation\n",
" x = nn.Dense(features=self.hidden_size, use_bias=True, dot_general=self.dot_general_cls())(x)\n",
" return x\n",
"\n",
"class FlaxTransformerLayer(nn.Module):\n",
" \"\"\"Basic Transformer layer using plain Flax modules\"\"\"\n",
" hidden_size: int\n",
" ffn_hidden_size: int\n",
" num_attention_heads: int\n",
" layernorm_eps: float = 1e-5\n",
" attention_dropout: float = 0.1\n",
" dot_general_cls: callable = lambda: None\n",
" \n",
" def setup(self):\n",
" self.kv_channels = self.hidden_size // self.num_attention_heads\n",
"\n",
" @nn.compact\n",
" def __call__(\n",
" self, \n",
" x: jnp.ndarray, \n",
" attention_mask: Optional[jnp.ndarray] = None,\n",
" deterministic: bool = False\n",
" ) -> jnp.ndarray:\n",
" # Create causal mask if not provided\n",
" if attention_mask is None:\n",
" attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)\n",
" \n",
" res = x\n",
" x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
" \n",
" # Fused QKV projection\n",
" qkv = nn.Dense(features=3 * self.hidden_size, use_bias=True, dot_general=self.dot_general_cls())(x)\n",
" qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n",
" q, k, v = jnp.split(qkv, 3, axis=3)\n",
" \n",
" # q, k, v now have shape [batch, seq_len, num_heads, kv_channels]\n",
" # which is the correct format for dot_product_attention\n",
" \n",
" # Apply dot product attention\n",
" # Note: dot_product_attention expects mask to be broadcastable to \n",
" # [batch, num_heads, q_length, kv_length], but attention_mask from \n",
" # nn.make_causal_mask has shape [batch, 1, seq_len, seq_len]\n",
" \n",
" # Generate dropout RNG key when needed (not deterministic and dropout_rate > 0)\n",
" dropout_rng = None\n",
" if not deterministic and self.attention_dropout > 0:\n",
" dropout_rng = self.make_rng('dropout')\n",
" \n",
" # See quickstart_jax.ipynb for details on using TE's faster fused attention\n",
" x = nn.dot_product_attention(\n",
" query=q,\n",
" key=k,\n",
" value=v,\n",
" mask=attention_mask,\n",
" dropout_rng=dropout_rng,\n",
" dropout_rate=self.attention_dropout,\n",
" deterministic=deterministic,\n",
" broadcast_dropout=True,\n",
" )\n",
" \n",
" # Reshape output from [batch, seq_len, num_heads, kv_channels] to [batch, seq_len, hidden_size]\n",
" x = x.reshape(x.shape[0], x.shape[1], self.hidden_size)\n",
"\n",
" # Output projection\n",
" x = nn.Dense(features=self.hidden_size, use_bias=True, dot_general=self.dot_general_cls())(x)\n",
" \n",
" x = res + x\n",
" \n",
" # Second residual connection\n",
" res = x\n",
" x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
" \n",
" # MLP\n",
" mlp = FlaxMLP(\n",
" hidden_size=self.hidden_size,\n",
" ffn_hidden_size=self.ffn_hidden_size,\n",
" dot_general_cls=self.dot_general_cls,\n",
" )\n",
" x = mlp(x)\n",
" \n",
" return x + res\n"
]
},
{
"cell_type": "markdown",
"id": "db16bf70",
"metadata": {},
"source": [
"We've exposed `dot_general_cls` here so we can test out different GEMM implementations later. By default, Flax's `nn.Dense` will use JAX's GEMM `jax.lax.dot_general` when `dot_general` is `None`."
]
},
{
"cell_type": "markdown",
"id": "fbc3510b",
"metadata": {},
"source": [
"## Testing Performance\n",
"\n",
"Now let's test the performance of our FlaxTransformerLayer:\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8b44649d",
"metadata": {},
"outputs": [],
"source": [
"# Layer configuration\n",
"hidden_size = 4096\n",
"sequence_length = 2048\n",
"batch_size = 4\n",
"ffn_hidden_size = 16384\n",
"num_attention_heads = 32\n",
"dtype = jnp.bfloat16\n",
"\n",
"# Synthetic data\n",
"key, dropout_key = jax.random.split(jax.random.PRNGKey(42))\n",
"x = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)\n",
"dy = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e44ed26d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pure Flax FlaxTransformerLayer initialized successfully!\n",
"Parameter shapes: {'params': {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'Dense_1': {'bias': (4096,), 'kernel': (4096, 4096)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}}\n"
]
}
],
"source": [
"# Initialize the FlaxTransformerLayer\n",
"flax_transformer = FlaxTransformerLayer(\n",
" hidden_size=hidden_size,\n",
" ffn_hidden_size=ffn_hidden_size,\n",
" num_attention_heads=num_attention_heads,\n",
")\n",
"\n",
"# Initialize parameters\n",
"params = flax_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
"\n",
"print(\"Pure Flax FlaxTransformerLayer initialized successfully!\")\n",
"print(f\"Parameter shapes: {jax.tree_util.tree_map(lambda x: x.shape, params)}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "de91af7a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input shape: (4, 2048, 4096)\n",
"Output shape: (4, 2048, 4096)\n",
"Output dtype: float32\n",
"Forward pass completed successfully!\n"
]
}
],
"source": [
"# Example usage of forward pass\n",
"y = flax_transformer.apply(params, x, attention_mask=None, deterministic=True)\n",
"print(f\"Input shape: {x.shape}\")\n",
"print(f\"Output shape: {y.shape}\")\n",
"print(f\"Output dtype: {y.dtype}\")\n",
"print(\"Forward pass completed successfully!\")\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "037bc8d9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 18.83516788482666 ms\n"
]
}
],
"source": [
"import importlib\n",
"import quickstart_jax_utils\n",
"importlib.reload(quickstart_jax_utils)\n",
"\n",
"utils.speedometer(\n",
" model_apply_fn=flax_transformer.apply,\n",
" variables=params,\n",
" input=x,\n",
" output_grad=dy,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" rngs={\"dropout\": dropout_key},\n",
")"
]
},
{
"cell_type": "markdown",
"id": "5e9310c9",
"metadata": {},
"source": [
"## Transformer Engine"
]
},
{
"cell_type": "markdown",
"id": "1f8e213e",
"metadata": {},
"source": [
"TransformerEngine/JAX is currently using Flax Linen. However, it is easily compatible with Flax NNX or Haiku.\n",
"* [Use Flax NNX and Linen together](https://flax.readthedocs.io/en/latest/guides/bridge_guide.html)\n",
"* [Haiku and Flax interop](https://dm-haiku.readthedocs.io/en/latest/notebooks/flax.html)\n",
"\n",
"Additionally, with the tutorial below, no model parameters need to be managed by TransformerEngine. You can keep all your existing model parameters, initialization, and sharding the same. The only change required is to call TE's dot_general_cls instead of the default Dense dot_general implementation. TE's dot_general_cls is a small module that performs a quantized dense VJP and stores some small recipe-specific state."
]
},
{
"cell_type": "markdown",
"id": "4477d4e9",
"metadata": {},
"source": [
"Now we'll select a recipe. `DelayedScaling` and `CurrentScaling` use per-tensor scaling and are supported on Hopper and Blackwell. `MXFP8BlockScaling` and `NVFP4BlockScaling` use block scaling or a combination of both per-tensor and block scaling and are supported on Blackwell.\n",
"\n",
"If you would like to customize the recipe further, various options can be changed by passing args to the recipe's constructor."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "5ddf41e7",
"metadata": {},
"outputs": [],
"source": [
"from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, MXFP8BlockScaling, NVFP4BlockScaling\n",
"from transformer_engine.jax import flax as te_flax \n",
"\n",
"# Choose a quantization recipe. This can be modified to any of the recipes imported above.\n",
"quantization_recipe = DelayedScaling()\n",
"\n",
"te_dot_general_cls = te_flax.make_dot_general_cls(quantization_recipe)\n",
"\n",
"rngs = {'dropout': dropout_key}\n",
"if isinstance(quantization_recipe, NVFP4BlockScaling):\n",
" # The NVFP4 recipe requires a Flax RNG for stochastic rounding\n",
" rngs['sr_rng'] = jax.random.PRNGKey(0)\n"
]
},
{
"cell_type": "markdown",
"id": "c8769655",
"metadata": {},
"source": [
"Now using this quantized dense in our model is as simple as passing in `dot_general_fn=te_dot_general`. Let's try it out!\n",
"\n",
"<div class=\"alert alert-warning\">\n",
"\n",
"<b>Important: Remat Policy</b>\n",
"\n",
"TE's quantization uses specialized TE quantized GEMM primitives. If you are using any built-in JAX checkpoint policies that look for JAX GEMMs (dots), such as `jax.checkpoint_policies.checkpoint_dots`, please replace the policy with `transformer_engine.jax.checkpoint_policies.checkpoint_dots_and_te_gemms` or similar policies to ensure TE's quantized GEMM primitives are checkpointed correctly.\n",
"\n",
"If this is not performed, TE GEMMs will be rematerialized introducing an incorrect performance comparison.\n",
"\n",
"</div>"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "8407d2ea",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pure Flax FlaxTransformerLayer initialized successfully!\n",
"Parameter shapes: {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'Dense_1': {'bias': (4096,), 'kernel': (4096, 4096)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}\n",
"Additional state: {'_overwrite_with_gradient': {'FlaxMLP_0': {'TEWrapper_dot_general_0': {'grad_amax_history': (1024,), 'grad_scale': (1,), 'kernel_amax_history': (1024,), 'kernel_scale': (1,), 'x_amax_history': (1024,), 'x_scale': (1,)}, 'TEWrapper_dot_general_1': {'grad_amax_history': (1024,), 'grad_scale': (1,), 'kernel_amax_history': (1024,), 'kernel_scale': (1,), 'x_amax_history': (1024,), 'x_scale': (1,)}}, 'TEWrapper_dot_general_0': {'grad_amax_history': (1024,), 'grad_scale': (1,), 'kernel_amax_history': (1024,), 'kernel_scale': (1,), 'x_amax_history': (1024,), 'x_scale': (1,)}, 'TEWrapper_dot_general_1': {'grad_amax_history': (1024,), 'grad_scale': (1,), 'kernel_amax_history': (1024,), 'kernel_scale': (1,), 'x_amax_history': (1024,), 'x_scale': (1,)}}}\n"
]
}
],
"source": [
"# Initialize the FlaxTransformerLayer\n",
"flax_transformer = FlaxTransformerLayer(\n",
" hidden_size=hidden_size,\n",
" ffn_hidden_size=ffn_hidden_size,\n",
" num_attention_heads=num_attention_heads,\n",
" dot_general_cls=te_dot_general_cls,\n",
")\n",
"\n",
"# Initialize parameters\n",
"var_collect = flax_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
"\n",
"print(\"Pure Flax FlaxTransformerLayer initialized successfully!\")\n",
"print(f\"Parameter shapes: {jax.tree_util.tree_map(lambda x: x.shape, var_collect['params'])}\")\n",
"print(f\"Additional state: {jax.tree_util.tree_map(lambda x: x.shape, {k: v for k, v in var_collect.items() if k != 'params'})}\")"
]
},
{
"cell_type": "markdown",
"id": "abe27237",
"metadata": {},
"source": [
"If using a recipe that stores additional state, such as `DelayedScaling`, you'll see this additional state stored as Flax variables. It is important to maintain and pass the whole state of Flax variables `var_collect` across training steps, not just the model params, for proper usage of stateful recipes like `DelayedScaling`.\n",
"\n",
"For example, above inside `Additional state: ` you'll see the `amax_history` of each quantization which is used to compute the per-tensor scale in the `DelayedScaling` recipe."
]
},
{
"cell_type": "markdown",
"id": "5ab72935",
"metadata": {},
"source": [
"The reason we need `te_dot_general_cls` as a Flax module instead of a module-less function like `jax.lax.dot_general` is for some quantization recipes to track internal state separate from model parameters.\n",
"\n",
"Flax modules can manage 3 things:\n",
"1. Model parameters/weights, e.g. your Dense \"kernel\", \"bias\", etc.\n",
"2. RNGs for dropout, stochastic rounding, etc.\n",
"3. Flax variables. These are additional state variables that are used across training steps but are distinct from model params in that you don't take gradients or optimize them. Currently, we only use this for DelayedScaling's amax_history state\n",
"\n",
"With the simplest quantization integration shown in this tutorial, we want users to keep their existing model param setup so they don't need to worry about preserving the sharding, init distribution, etc.. So we don't need point 1 since we don't do model param creation in this codepath with dot_general_cls, but we still do need `te_dot_general_cls()` to produce a Flax module since we potentially need to do points 2 or 3 which need to be in a Flax module."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "3b6b344b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input shape: (4, 2048, 4096)\n",
"Output shape: (4, 2048, 4096)\n",
"Output dtype: float32\n",
"Forward pass completed successfully!\n"
]
}
],
"source": [
"# Example usage of forward pass\n",
"y = flax_transformer.apply(var_collect, x, attention_mask=None, deterministic=True, rngs=rngs)\n",
"print(f\"Input shape: {x.shape}\")\n",
"print(f\"Output shape: {y.shape}\")\n",
"print(f\"Output dtype: {y.dtype}\")\n",
"print(\"Forward pass completed successfully!\")\n"
]
},
{
"cell_type": "markdown",
"id": "d178f247",
"metadata": {},
"source": [
"Now let's measure the performance!"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "5cc6c2a7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 10.553865432739258 ms\n"
]
}
],
"source": [
"import importlib\n",
"import quickstart_jax_utils\n",
"importlib.reload(quickstart_jax_utils)\n",
"\n",
"utils.speedometer(\n",
" model_apply_fn=flax_transformer.apply,\n",
" variables=var_collect,\n",
" input=x,\n",
" output_grad=dy,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" rngs=rngs,\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment