Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
Introduction
===================================
Transformer Engine accelerates deep learning on NVIDIA GPUs in several ways,
with low precision training being one of the most important.
This chapter introduces mixed precision training and FP8 support.
Training in BF16/FP16
---------------------
Deep learning traditionally uses 32-bit floating-point (FP32) numbers.
NVIDIA GPUs support lower precision formats—FP16 since Pascal, BF16 since Ampere—which offer higher throughput and lower memory usage.
Let's compare these formats.
.. raw:: html
:file: img/fp_formats_comparison.svg
*Figure 1: Comparison of FP32, BF16, and FP16 floating-point formats showing bit allocation for sign, exponent, and mantissa.*
The key differences between these formats are:
* **FP32** (32 bits total): 1 sign bit + 8 exponent bits + 23 mantissa bits – standard single-precision format
* **BF16** (16 bits total): 1 sign bit + 8 exponent bits + 7 mantissa bits – maintains FP32's exponent range but has reduced precision
* **FP16** (16 bits total): 1 sign bit + 5 exponent bits + 10 mantissa bits – reduced range but higher precision than BF16
BF16's advantage is that it shares the same exponent range as FP32,
making it easier to convert between the two formats without overflow/underflow issues.
FP16 offers better precision for small values but has a limited dynamic range,
so loss scaling is needed to avoid overflow/underflow (see `this paper on loss scaling <https://arxiv.org/pdf/1710.03740>`__ for more details).
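To make loss scaling concrete, here is a minimal sketch using plain PyTorch AMP rather than Transformer Engine; the model, data, and learning rate are placeholders chosen only for illustration.

.. code-block:: python

    import torch

    # Minimal FP16 loss-scaling loop with PyTorch AMP (requires a CUDA GPU).
    model = torch.nn.Linear(1024, 1024).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler()

    for _ in range(10):
        inputs = torch.randn(32, 1024, device="cuda")
        optimizer.zero_grad()
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            loss = model(inputs).square().mean()
        scaler.scale(loss).backward()  # scale the loss so small FP16 gradients do not underflow
        scaler.step(optimizer)         # unscales gradients; skips the step if inf/NaN is found
        scaler.update()                # adjusts the scale factor for the next iteration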
**Mixed precision**
To preserve accuracy, not all operations should be run in reduced precision.
Modern deep learning frameworks use *mixed precision training*,
where different operations use different precisions based on their numerical properties:
* Matrix multiplications are compute-heavy and remain numerically stable at lower precision, making them ideal candidates for acceleration.
* Operations like layer normalization and softmax can work with low precision inputs and outputs, but may use high precision internally or for their weights.
* Operations like loss computation and exponentiation need high precision throughout.
**Master weights**
Another consideration in mixed precision training is how to store the model weights.
Lower precision formats like FP16 and BF16 have limited representational granularity,
which becomes problematic during gradient updates.
When a small gradient update is added to a much larger weight stored in low precision,
the result may round back to the original value if the update falls below the format's precision threshold.
Moreover, some elements of the gradient itself can be too small to be represented in low precision,
especially after the accumulation from multiple GPUs in the data parallel training setting.
The solution is to maintain *master weights* in FP32.
During training, weights are cast to lower precision for forward and backward passes,
but the gradient updates are applied to the full-precision master copy.
This ensures that even small gradients accumulate correctly over time.
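To make the mechanism concrete, here is a minimal hand-rolled sketch (plain PyTorch, not Transformer Engine) that keeps an FP32 master copy, runs the forward and backward pass on a BF16 working copy, and applies the update in FP32; the sizes and learning rate are illustrative.

.. code-block:: python

    import torch

    lr = 1e-3
    master_w = torch.randn(1024, 1024)  # FP32 master weights

    for _ in range(10):
        # Cast to a low-precision working copy for the forward/backward pass.
        w_bf16 = master_w.to(torch.bfloat16).requires_grad_()
        x = torch.randn(32, 1024, dtype=torch.bfloat16)
        loss = (x @ w_bf16).square().mean()
        loss.backward()
        # Apply the update to the FP32 master copy so tiny updates are not rounded away.
        master_w -= lr * w_bf16.grad.float()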
There are two common software approaches to storing master weights:
* *In the optimizer*:
The model holds low-precision weights,
while the optimizer maintains FP32 copies alongside momentum and other state.
During each step,
the optimizer updates its FP32 copy and casts the result back to the model's low-precision weights.
This approach makes it easier to shard master weights together with other optimizer state, for example in the ZeRO optimizer.
Since the casting happens only during the optimizer step, this approach is also faster when the optimizer runs less frequently than the model, e.g. when performing gradient accumulation or pipeline-parallel training.
* *In the model*:
The model stores weights directly in FP32,
and they are cast to lower precision on-the-fly during forward and backward passes.
This approach works seamlessly with any standard optimizer, requiring no special support.
.. raw:: html
:file: img/master_weights_approaches.svg
*Figure 2: Three approaches to weight storage—low precision only (no master weights), master weights stored in the model, and master weights stored in the optimizer.*
.. tabs::
.. tab:: PyTorch
The PyTorch API of Transformer Engine provides several mechanisms to control precision:
* **Weight precision**: Use the ``params_dtype`` argument in any TE layer constructor.
* **Computation precision**: Use the ``torch.autocast`` context manager. When enabled, inputs are cast to the autocast dtype before computation.
* **Input dtype**: When ``torch.autocast`` is not used, the input tensor's dtype determines the computation precision. In this case, inputs and parameters must have matching dtypes.
.. literalinclude:: bf16_fp16_training_pytorch.py
:language: python
:start-after: # START_BF16_FP16_TRAINING
:end-before: # END_BF16_FP16_TRAINING
.. tab:: JAX
The JAX API of Transformer Engine provides two mechanisms to control precision:
* **Weight precision**: Use the ``dtype`` argument in any TE layer constructor.
* **Computation precision**: Determined by the dtype of the input tensor.
For training with master weights in FP32 and computation in BF16,
cast the input tensor to BF16 before passing it to the layer.
.. literalinclude:: bf16_fp16_training_jax.py
:language: python
:start-after: # START_BF16_FP16_TRAINING
:end-before: # END_BF16_FP16_TRAINING
Lower precisions
----------------
Transformer Engine's primary feature is support for precisions even lower than BF16/FP16, such as FP8, MXFP8, and NVFP4.
These formats are more involved than BF16/FP16 – they require scaling factors to
properly represent the full range of values in a tensor. Sometimes there is one scaling factor per tensor,
sometimes one scaling factor per block of values. A precision format combined with the logic used to train with it
is called a **recipe**.
This section presents the logic common to all recipes. Each recipe is described in more detail in its own section later.
Let's now see how we can train in lower precisions in the supported frameworks.
.. tabs::
.. tab:: PyTorch
The PyTorch API of Transformer Engine provides an ``autocast`` context manager to control precision.
It's similar to the ``torch.autocast`` context manager, but tailored for low precision training.
The most important argument is ``recipe``, which accepts objects inheriting from ``transformer_engine.common.recipe.Recipe``.
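As a hedged sketch of how this fits together (assuming the long-standing ``fp8_autocast`` entry point with its ``fp8_recipe`` argument and the ``DelayedScaling`` recipe; consult the API reference of your version for the exact ``autocast`` signature):

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling, Format

    # Sketch only: layer sizes and recipe settings are illustrative.
    layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
    recipe = DelayedScaling(fp8_format=Format.HYBRID)

    x = torch.randn(32, 1024, dtype=torch.bfloat16, device="cuda")
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        out = layer(x)
    out.sum().backward()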
.. tab:: JAX
Python context managers like ``autocast`` may interact unexpectedly with JAX's JIT compilation.
For finer-grained control, consider passing the recipe directly to TE modules instead.
See the `TE JAX Integration notebook <https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/te_jax_integration.ipynb>`_
for details.
**Mixed precision with 8- or 4-bit precisions**
From now on, we will refer to FP8/MXFP8/NVFP4 etc. as *low precision*
and to FP32/BF16/FP16 as *high precision*. This terminology will be
used throughout the rest of the documentation.
Not all operations run in low precision:
- **Linear operations**: run in low precision.
- **Attention computations**: run in high precision by default (some recipes allow low precision as an option).
- **Other operations** (layer normalization, softmax, etc.): run in high precision.
Within high-precision operations, there are two categories:
- **Configurable precision**: most operations run in parameter precision (FP32/BF16/FP16) or the precision specified by ``torch.autocast``.
- **Fixed FP32 precision**: some operations, or parts of operations—such as the division in layernorm—always run in FP32, regardless of other settings.
.. raw:: html
:file: img/mixed_precision_operations.svg
*Figure 3: Default precision of operations in a TransformerLayer forward pass. Only linear operations are in lower precision. Dot product attention is shown as three separate operations (QK^T, Softmax, Scores * V), though in practice these may be fused into a single kernel.*
**Linear layer data flow**
Let's see how the data flow of a linear layer works by default on a single H100 GPU with FP8 precision.
The H100 (Hopper) architecture natively supports FP8 matrix multiplication only in the **TN** layout (Transpose-NoTranspose),
so a GEMM with tensors ``A`` and ``B`` returns ``B * A^T``.
*Forward pass*
* Input is quantized to FP8 – both ``input`` and ``input^T`` quantized versions are created.
* Weights are stored in high precision and quantized to low precision before the GEMM – both ``weight`` and ``weight^T`` quantized versions are created.
* FP8 GEMM with layout **TN** is run with the ``weight`` and ``input`` tensors.
* Outputs – ``input * weight^T`` tensor – are returned in high precision.
*Backward pass*
* Output gradients are quantized to FP8 – both ``output_grad`` and ``output_grad^T`` quantized versions are created.
* FP8 GEMM with layout **TN** is performed with ``weight^T`` and ``output_grad`` tensors to compute input gradients.
* FP8 GEMM with layout **TN** is performed with ``input^T`` and ``output_grad^T`` tensors to compute weight gradients.
* Input gradients – ``output_grad * weight`` tensor – are returned in high precision.
* Weight gradients – ``output_grad^T * input`` tensor – are returned in high precision.
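The flow above can be summarized with a simplified sketch. ``quantize_fp8`` and ``gemm_tn`` below are illustrative stand-ins rather than Transformer Engine APIs, scaling factors are omitted, and a recent PyTorch with ``torch.float8_e4m3fn`` is assumed.

.. code-block:: python

    import torch

    def quantize_fp8(t):
        # Illustrative stand-in for FP8 quantization (real kernels also produce scaling factors).
        return t.to(torch.float8_e4m3fn)

    def gemm_tn(a, b):
        # TN-layout GEMM as described above: returns B @ A^T, accumulated in high precision.
        return b.to(torch.bfloat16) @ a.to(torch.bfloat16).T

    def linear_forward(inp, weight):
        inp_q,    inp_t_q    = quantize_fp8(inp),    quantize_fp8(inp.T)
        weight_q, weight_t_q = quantize_fp8(weight), quantize_fp8(weight.T)
        out = gemm_tn(weight_q, inp_q)            # out = input @ weight^T (high precision)
        return out, (inp_t_q, weight_t_q)

    def linear_backward(grad_out, saved):
        inp_t_q, weight_t_q = saved
        g_q, g_t_q = quantize_fp8(grad_out), quantize_fp8(grad_out.T)
        grad_input  = gemm_tn(weight_t_q, g_q)    # grad_input  = grad_out @ weight
        grad_weight = gemm_tn(inp_t_q, g_t_q)     # grad_weight = grad_out^T @ input
        return grad_input, grad_weight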
.. raw:: html
:file: img/fp8_linear_flow.svg
*Figure 4: Forward pass of a Linear layer with low precision data flow.*
Unlike Hopper FP8 GEMMs, MXFP8 GEMMs support multiple layouts (TN, NT, NN), so the data
does not require explicit transposition. However, rowwise and columnwise quantizations are different:
- *Rowwise* - 1 scaling factor per 32 consecutive elements along a row (1×32 blocks).
- *Columnwise* - 1 scaling factor per 32 consecutive elements along a column (32×1 blocks).
Since the scaling factor blocks have different orientations, rowwise and columnwise MXFP8 tensors
are numerically different — one cannot derive one from the other. Both must be quantized
independently from the full-precision data.
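A simplified sketch of where the two independent sets of scaling factors come from; real MXFP8 derives power-of-two E8M0 scales from each block's amax, while this sketch only computes the per-block amaxes.

.. code-block:: python

    import torch

    def mxfp8_block_amaxes(x, block=32):
        """Per-block amaxes of a 2D tensor x of shape [m, k] (m and k divisible by 32)."""
        m, k = x.shape
        # Rowwise: one value per 1x32 block along each row -> shape [m, k // 32].
        rowwise = x.abs().reshape(m, k // block, block).amax(dim=-1)
        # Columnwise: one value per 32x1 block along each column -> shape [k, m // 32].
        colwise = x.abs().T.contiguous().reshape(k, m // block, block).amax(dim=-1)
        return rowwise, colwise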
.. raw:: html
:file: img/mxfp8_row_col.svg
*Figure 2. MXFP8 rowwise vs columnwise quantization layout.*
Distributed training
--------------------
**Scale synchronization**
The blockwise scaled tensor does not need any scale synchronization among the nodes.
This is because each scaling factor is local to its 32-element block,
unlike :doc:`FP8 Current <../fp8_current_scaling/fp8_current_scaling>`/:doc:`Delayed Scaling <../fp8_delayed_scaling/fp8_delayed_scaling>` where a single global scale applies to the entire tensor, even when sharded.
**Quantized all-gather**
MXFP8 all-gather is supported.
Examples
--------
Here's how to use the MXFP8 recipe in PyTorch and JAX:
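A minimal PyTorch-only sketch, assuming the ``MXFP8BlockScaling`` recipe class and the ``fp8_autocast`` entry point (MXFP8 also requires hardware support, e.g. Blackwell GPUs):

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import MXFP8BlockScaling

    layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
    recipe = MXFP8BlockScaling()

    x = torch.randn(32, 1024, dtype=torch.bfloat16, device="cuda")
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        out = layer(x)
    out.sum().backward()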
Implementation details
----------------------
This section contains implementation details that may be useful for developers
but are not required for using MXFP8 in practice.
Swizzling scaling factors
^^^^^^^^^^^^^^^^^^^^^^^^^
Like :doc:`FP8 Blockwise Scaling <../fp8_blockwise_scaling/fp8_blockwise_scaling>`, MXFP8 uses different data layouts for communication and computation.
MXFP8 GEMMs require scaling factors in a specific hardware layout
(see `cuBLAS documentation <https://docs.nvidia.com/cuda/cublas/index.html#block-scaling-factors-layout>`__).
The conversion to this GEMM-ready layout is called *swizzling*. When no communication is needed,
swizzling can be fused with quantization. When communication is required, swizzled scaling factors
cannot be communicated across devices, so Transformer Engine performs swizzling after communication,
just before each GEMM operation.
.. raw:: html
:file: img/mxfp8_swizzle_both_tensors.svg
*Figure 3. MXFP8 swizzling process: standard scaling factors are rearranged into the hardware-required layout.*
Blackwell Tensor Cores compute matrix multiplications using ``128x128`` tiles.
Scaling factors are stored in row-major order, but to process a tile, we need a ``128x4`` vertical
slice of scaling factors. In row-major storage, these vertical slices are scattered in memory
with gaps between each row. The hardware requires them to be stored contiguously.
.. raw:: html
:file: img/mxfp8_tensor_scaling_layout.svg
*Figure 4. FP8 tensor (left) is divided into 128x128 tiles. Each tile requires a 128x4 block of scaling factors (right). These vertical blocks are not contiguous in memory.*
Swizzling transforms the layout to meet hardware requirements by:
1. **Linearizing** the ``128x4`` blocks so they are stored contiguously one after another.
2. **Permuting** the 4-byte elements within each block.
Specifically, if we index the 128 4-byte elements in a scaling factor block as :math:`0, 1, \dots, 127`, the hardware expects them in the interleaved order illustrated in Figure 5 below.
*Figure 5. Linearization and swizzling of scaling factors. The 2D grid of scaling factors is first flattened into a contiguous sequence of blocks (top), then the rows within each block are interleaved to match the hardware access pattern (bottom).*
For columnwise scaling factors, the process is analogous but with ``4x128`` horizontal blocks instead of ``128x4`` vertical blocks.
All-gather of columnwise tensors
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
All-gather of columnwise tensors is supported and necessary because:
- columnwise quantized tensors cannot be computed from rowwise quantized ones,
- gathering high-precision tensors is avoided in most cases for performance reasons.
Stochastic rounding
-------------------
Standard (deterministic) quantization always
rounds to the nearest representable value, while stochastic rounding probabilistically rounds to either neighboring value.
*If* ``x`` *is 40% of the way from* ``v1`` *to* ``v2``, *there is a 60% chance of rounding to* ``v1``
*and a 40% chance of rounding to* ``v2``.
Stochastic rounding is enabled only for gradients. It can be disabled by setting
``disable_stochastic_rounding=True`` in the recipe configuration.
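A toy sketch of the rounding rule described above (not the actual quantization kernel, which operates on NVFP4 values rather than arbitrary floats):

.. code-block:: python

    import torch

    def stochastic_round(x, v1, v2):
        """Round each element of x (with v1 <= x <= v2) to v1 or v2, proportionally to proximity."""
        p_up = (x - v1) / (v2 - v1)  # fraction of the way from v1 to v2
        up = torch.rand_like(x) < p_up
        return torch.where(up, torch.full_like(x, v2), torch.full_like(x, v1))

    # x is 40% of the way from 1.0 to 2.0, so roughly 40% of samples round up to 2.0.
    x = torch.full((10000,), 1.4)
    print(stochastic_round(x, 1.0, 2.0).eq(2.0).float().mean())  # ~0.4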
Random Hadamard Transform
--------------------------
Random Hadamard Transform (RHT) applies an orthogonal rotation to the tensor **before quantization**,
smoothing outliers in the tensor distributions and making them easier to represent accurately in NVFP4.
RHT is applied to columnwise quantization of inputs and gradients, which are operands
for the **wgrad GEMM**. This GEMM is particularly sensitive
to quantization errors, hence the additional outlier smoothing.
RHT is supported only for BF16 inputs/gradients.
The transform is defined as:
.. math::
x' = x H
where :math:`H` is the RHT matrix defined below. The quantization scale factor is computed
from the rotated tensor :math:`x'`.
**Hadamard matrix**
The :math:`d \times d` Hadamard matrix has elements :math:`\pm 1` and satisfies :math:`H_d H_d^T = d I`.
When normalized by :math:`1/\sqrt{d}`, the matrix becomes orthogonal and can be applied
to both operands of a matrix multiplication:
.. math::
C = (AH)(H^T B) = AB
where the transforms cancel within the dot-product since :math:`H H^T = I`.
**Sign matrix**
In the RHT implementation, a :math:`d`-dimensional diagonal sign matrix :math:`S_d` is applied
together with the Hadamard matrix:
.. math::
H = \frac{1}{\sqrt{d}} S_d H_d
where diagonal entries of :math:`S_d` are :math:`\{-1, 1\}` and flip the signs of different rows of :math:`H_d`.
As described in the paper, a single random sign vector is shared across all linear layers throughout training.
In the implementation, this vector is fixed and the RHT matrix is computed once at initialization and cached.
**Tiled implementation**
The Hadamard transform is performed in a tiled approach along the last dimension of the tensor.
For an :math:`m \times k` tensor, the data is reshaped to :math:`(mk/d) \times d`
and multiplied by the :math:`d \times d` matrix :math:`H`. In this implementation, :math:`d = 16`.
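A simplified sketch of the construction and its tiled application (the fused kernels and the exact sign vector used by Transformer Engine differ; the seed below is illustrative):

.. code-block:: python

    import torch

    def hadamard(d):
        # Sylvester construction: H_{2n} = [[H_n, H_n], [H_n, -H_n]], entries are +/-1.
        h = torch.ones(1, 1)
        while h.shape[0] < d:
            h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
        return h

    def rht_matrix(d=16, seed=0):
        # H = (1 / sqrt(d)) * S_d * H_d with a fixed random sign vector on the diagonal.
        g = torch.Generator().manual_seed(seed)
        signs = (torch.randint(0, 2, (d,), generator=g) * 2 - 1).float()
        return (torch.diag(signs) @ hadamard(d)) / d ** 0.5

    def apply_rht(x, H):
        # Tiled transform along the last dimension: reshape to (m*k/d, d), then x' = x H.
        d = H.shape[0]
        return (x.reshape(-1, d) @ H).reshape(x.shape)

    H = rht_matrix()
    print(torch.allclose(H @ H.T, torch.eye(16), atol=1e-6))  # True: H is orthogonal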
.. raw:: html
:file: img/rht.svg
*Figure 4. WGRAD GEMM pipeline comparison: without RHT (left) and with RHT applied (right).*
Handling transposes
-------------------
Like :doc:`MXFP8 <../mxfp8/mxfp8>`, NVFP4 requires both rowwise and columnwise quantized tensors
for different GEMM operands. Unlike MXFP8 which supports multiple layouts (TN, NT, NN),
**NVFP4 GEMM only supports the TN layout**.
NVFP4 stores columnwise data and scaling factors in a **transposed layout**:
- **Rowwise**: data ``[A, B]`` with 1×16 horizontal blocks, ``scales`` shape ``[A, B/16]``
- **Columnwise**: data ``[B, A]`` (transposed) with 1×16 horizontal blocks, ``scales`` shape ``[B, A/16]``
Scale tensors are padded for hardware alignment: first dimension to a multiple of 128,
second dimension to a multiple of 4 (e.g. rowwise: ``[roundup(A, 128), roundup(B/16, 4)]``).
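A small helper illustrating the shapes described above (the function name is ours; only the arithmetic is taken from the text):

.. code-block:: python

    def roundup(x, multiple):
        return ((x + multiple - 1) // multiple) * multiple

    def nvfp4_scale_shapes(A, B, block=16):
        # Rowwise: data [A, B], scales [A, B/16], padded to [roundup(A, 128), roundup(B/16, 4)].
        rowwise = (roundup(A, 128), roundup(B // block, 4))
        # Columnwise: data stored transposed as [B, A], scales [B, A/16], padded the same way.
        colwise = (roundup(B, 128), roundup(A // block, 4))
        return rowwise, colwise

    print(nvfp4_scale_shapes(1024, 4096))  # ((1024, 256), (4096, 64))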
.. raw:: html
:file: img/nvfp4_row_col.svg
*Figure 5. NVFP4 rowwise vs columnwise quantization layout. Unlike MXFP8, columnwise scales are stored transposed.*
Distributed training
--------------------
**Amax reduction**
Block scaling factors (``s_block``) do not require synchronization between nodes,
as each scaling factor is local to its block of 16 elements.
However, the global scaling factor (``s_global``) requires amax synchronization for gathered tensors.
For tensors that are gathered (e.g., input and gradient in sequence parallelism),
amax reduction is performed before quantization.
If before synchronization there was ``amax_1`` on node 1,
``amax_2`` on node 2, etc., after synchronization there will be ``max(amax_1, amax_2, ...)`` on all nodes.
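In terms of collectives, this synchronization amounts to an all-reduce with the MAX operation. A minimal sketch, assuming an already-initialized ``torch.distributed`` process group:

.. code-block:: python

    import torch
    import torch.distributed as dist

    def sync_global_amax(local_tensor):
        # Each rank computes its local amax; an all-reduce with MAX leaves the same
        # global amax on every rank, which then seeds s_global before quantization.
        amax = local_tensor.abs().amax()
        dist.all_reduce(amax, op=dist.ReduceOp.MAX)
        return amax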
**Quantized all-gather**
NVFP4 all-gather is supported.
.. raw:: html
:file: img/nvfp4_all_gather.svg
*Figure 6. Quantization and all-gather flow for NVFP4 showing amax synchronization and hierarchical scaling.*
Examples
--------
Here's how to use the NVFP4 recipe in PyTorch and JAX. The examples show how to configure features like 2D weight quantization and Random Hadamard Transform (RHT):
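As a rough PyTorch-only sketch, assuming an ``NVFP4BlockScaling`` recipe class analogous to the MXFP8 one (consult the recipe API reference for the exact class name and for the flags controlling 2D weight quantization and RHT):

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import NVFP4BlockScaling  # assumed name

    # Sketch only: requires hardware with NVFP4 support; options for 2D weight
    # quantization and RHT are set on the recipe object (names vary by version).
    layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
    recipe = NVFP4BlockScaling()

    x = torch.randn(32, 1024, dtype=torch.bfloat16, device="cuda")
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        out = layer(x)
    out.sum().backward()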