Unverified Commit 3ceb248e authored by Paweł Gadziński's avatar Paweł Gadziński Committed by GitHub
Browse files

More detailed documentation for recipes (#2343)



* Code drop: Update recipes documentation and remove custom recipes from low precision training
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* Fix SVG css import path for diagrams
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* Refactor low_precision_training docs: remove optimizers, fix imports, add GPU checks

Changes:
- Remove optimizer code from all recipe examples (keep only forward/backward)
- Fix Format imports (use Format.E4M3 instead of string 'E4M3')
- Fix params_dtype for PyTorch examples (add params_dtype=torch.bfloat16)
- Add GPU capability assertions before START blocks for blockwise/mxfp8/nvfp4
- Fix JAX imports (Float8CurrentScaling from common.recipe, NVFP4BlockScaling)
- Add global_shard_guard for TransformerLayer examples in JAX
- Fix fused_layers_jax.py return tuple unpacking
- Update memory_usage JAX examples with dynamic GPU measurement
- Remove memory_usage_3_jax (JAX doesn't support FP8 weight storage)
- Update performance_considerations.rst for JAX differences
- Delete unused .out files and fp8_autocast_jax.py
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix JAX memory usage .out files with correct output
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* responded to comments
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* applied suggestions form greptile
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* year change
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* jax compute capability fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent c3769cb7
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from transformer_engine.jax.quantize import get_device_compute_capability
# Requires Ada (SM89) or newer for FP8 support
assert get_device_compute_capability() >= 89, "This example requires SM89 (Ada) or newer"
# START_DELAYED_SCALING_EXAMPLE
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
from transformer_engine.jax.flax import DenseGeneral
from transformer_engine.common.recipe import DelayedScaling
# Create FP8 Delayed Scaling recipe
recipe = DelayedScaling(
margin=0, # Margin for scaling factor computation (default: 0)
amax_history_len=1024, # Length of amax history window (default: 1024)
amax_compute_algo="max", # How to compute amax from history (default: "max")
)
with te.autocast(enabled=True, recipe=recipe):
# Initialize layer and data
layer = DenseGeneral(features=1024)
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16)
var_collect = layer.init(key, x)
# Forward and backward pass
def loss_fn(var_collect):
output = layer.apply(var_collect, x)
return output.sum()
loss, grads = jax.value_and_grad(loss_fn)(var_collect)
# END_DELAYED_SCALING_EXAMPLE
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# START_AMAX_REDUCTION_EXAMPLE
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling
# Create process group for amax reduction (e.g., all 8 GPUs)
amax_reduction_group = dist.new_group(ranks=[0, 1, 2, 3, 4, 5, 6, 7])
recipe = DelayedScaling(reduce_amax=True)
with te.autocast(recipe=recipe, amax_reduction_group=amax_reduction_group):
output = model(inp)
# END_AMAX_REDUCTION_EXAMPLE
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
# Requires Ada (SM89) or newer for FP8 support
assert torch.cuda.get_device_capability()[0] >= 9 or (
torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9
), "This example requires SM89 (Ada) or newer"
# START_DELAYED_SCALING_EXAMPLE
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling
# Create FP8 Delayed Scaling recipe
recipe = DelayedScaling(
margin=0, # Margin for scaling factor computation (default: 0)
amax_history_len=1024, # Length of amax history window (default: 1024)
amax_compute_algo="max", # How to compute amax from history (default: "max")
)
# Create a linear layer with bfloat16 parameters
layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
# Forward and backward pass
inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda")
with te.autocast(enabled=True, recipe=recipe):
output = layer(inp)
loss = output.sum()
loss.backward()
# END_DELAYED_SCALING_EXAMPLE
..
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
Low precision training
===================================
.. toctree::
introduction/introduction.rst
performance_considerations/performance_considerations.rst
fp8_current_scaling/fp8_current_scaling.rst
fp8_delayed_scaling/fp8_delayed_scaling.rst
fp8_blockwise_scaling/fp8_blockwise_scaling.rst
mxfp8/mxfp8.rst
nvfp4/nvfp4.rst
\ No newline at end of file
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from transformer_engine.jax.quantize import get_device_compute_capability
# Requires Ada (SM89) or newer for FP8 support
assert get_device_compute_capability() >= 89, "This example requires SM89 (Ada) or newer"
# START_AUTOCAST_BASIC
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
from transformer_engine.jax.flax import TransformerLayer
from transformer_engine.common.recipe import DelayedScaling, Format
# Set up recipe
recipe = DelayedScaling()
# Model initialization must happen inside autocast
with te.autocast(enabled=True, recipe=recipe):
layer = TransformerLayer(
hidden_size=1024,
mlp_hidden_size=4096,
num_attention_heads=16,
)
init_key, dropout_key = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(init_key, (32, 128, 1024), dtype=jnp.bfloat16)
var_collect = layer.init({"params": init_key, "dropout": dropout_key}, x)
# Forward and backward pass (both inside autocast for JAX)
def loss_fn(var_collect):
output = layer.apply(var_collect, x, rngs={"dropout": dropout_key})
return output.sum()
loss, grads = jax.value_and_grad(loss_fn)(var_collect)
# END_AUTOCAST_BASIC
# START_AUTOCAST_SEQUENTIAL
encoder_recipe = DelayedScaling(fp8_format=Format.E4M3)
decoder_recipe = DelayedScaling(fp8_format=Format.HYBRID)
with te.autocast(enabled=True, recipe=encoder_recipe):
encoder = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16)
encoder_var_collect = encoder.init({"params": init_key, "dropout": dropout_key}, x)
hidden = encoder.apply(encoder_var_collect, x, rngs={"dropout": dropout_key})
with te.autocast(enabled=True, recipe=decoder_recipe):
decoder = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16)
decoder_var_collect = decoder.init({"params": init_key, "dropout": dropout_key}, hidden)
output = decoder.apply(decoder_var_collect, hidden, rngs={"dropout": dropout_key})
# END_AUTOCAST_SEQUENTIAL
# START_AUTOCAST_NESTED
outer_recipe = DelayedScaling(fp8_format=Format.E4M3)
inner_recipe = DelayedScaling(fp8_format=Format.HYBRID)
with te.autocast(enabled=True, recipe=outer_recipe):
# layer1 uses outer_recipe
layer1 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16)
var_collect1 = layer1.init({"params": init_key, "dropout": dropout_key}, x)
hidden = layer1.apply(var_collect1, x, rngs={"dropout": dropout_key})
with te.autocast(enabled=True, recipe=inner_recipe):
# layer2 uses inner_recipe (overrides outer)
layer2 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16)
var_collect2 = layer2.init({"params": init_key, "dropout": dropout_key}, hidden)
hidden = layer2.apply(var_collect2, hidden, rngs={"dropout": dropout_key})
# layer3 uses outer_recipe again
layer3 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16)
var_collect3 = layer3.init({"params": init_key, "dropout": dropout_key}, hidden)
output = layer3.apply(var_collect3, hidden, rngs={"dropout": dropout_key})
# END_AUTOCAST_NESTED
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
# Requires Ada (SM89) or newer for FP8 support
assert torch.cuda.get_device_capability()[0] >= 9 or (
torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9
), "This example requires SM89 (Ada) or newer"
# START_AUTOCAST_BASIC
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format
recipe = DelayedScaling()
layer = te.Linear(1024, 1024)
inp = torch.randn(32, 1024, dtype=torch.float32, device="cuda")
with te.autocast(enabled=True, recipe=recipe):
output = layer(inp)
# .backward() is called outside of autocast
loss = output.sum()
loss.backward()
# END_AUTOCAST_BASIC
# START_AUTOCAST_SEQUENTIAL
encoder_recipe = DelayedScaling(fp8_format=Format.E4M3)
decoder_recipe = DelayedScaling(fp8_format=Format.HYBRID)
encoder = te.Linear(1024, 1024)
decoder = te.Linear(1024, 1024)
with te.autocast(enabled=True, recipe=encoder_recipe):
hidden = encoder(inp)
with te.autocast(enabled=True, recipe=decoder_recipe):
output = decoder(hidden)
# END_AUTOCAST_SEQUENTIAL
# START_AUTOCAST_NESTED
outer_recipe = DelayedScaling(fp8_format=Format.E4M3)
inner_recipe = DelayedScaling(fp8_format=Format.HYBRID)
layer1 = te.Linear(1024, 1024)
layer2 = te.Linear(1024, 1024)
layer3 = te.Linear(1024, 1024)
with te.autocast(enabled=True, recipe=outer_recipe):
# layer1 uses outer_recipe
x = layer1(inp)
with te.autocast(enabled=True, recipe=inner_recipe):
# layer2 uses inner_recipe (overrides outer)
x = layer2(x)
# layer3 uses outer_recipe again
output = layer3(x)
# END_AUTOCAST_NESTED
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# START_BF16_FP16_TRAINING
import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import TransformerLayer
def run_forward_backward(params_dtype, compute_dtype):
# Create TransformerLayer
layer = TransformerLayer(
hidden_size=1024,
mlp_hidden_size=4096,
num_attention_heads=16,
dtype=params_dtype,
)
# Initialize parameters and optimizer
init_key, dropout_key = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(init_key, (32, 128, 1024), dtype=compute_dtype)
var_collect = layer.init({"params": init_key, "dropout": dropout_key}, x)
# Forward and backward pass
def loss_fn(var_collect):
output = layer.apply(var_collect, x, rngs={"dropout": dropout_key})
assert output.dtype == compute_dtype
return output.sum()
loss, grads = jax.value_and_grad(loss_fn)(var_collect)
run_forward_backward(jnp.float32, jnp.float32) # high precision training
run_forward_backward(jnp.float32, jnp.bfloat16) # bfloat16 training with master weights in FP32
run_forward_backward(jnp.bfloat16, jnp.bfloat16) # bfloat16 training with weights in BF16
# END_BF16_FP16_TRAINING
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# START_BF16_FP16_TRAINING
import torch
import transformer_engine.pytorch as te
from contextlib import nullcontext
def run_forward_backward(params_dtype, autocast_precision, grad_scaler_enabled):
if grad_scaler_enabled:
grad_scaler = torch.amp.GradScaler("cuda")
layer = te.TransformerLayer(
hidden_size=1024,
ffn_hidden_size=4096,
num_attention_heads=16,
params_dtype=params_dtype,
)
x = torch.randn(32, 128, 1024, dtype=params_dtype, device="cuda")
autocast_ctx = (
torch.autocast(device_type="cuda", dtype=autocast_precision)
if autocast_precision is not None
else nullcontext()
)
with autocast_ctx:
output = layer(x)
assert (
output.dtype == autocast_precision if autocast_precision is not None else params_dtype
)
loss = output.sum()
if grad_scaler_enabled:
grad_scaler.scale(loss).backward()
else:
loss.backward()
run_forward_backward(torch.float32, torch.float32, False) # high precision training
run_forward_backward(
torch.float32, torch.bfloat16, False
) # bfloat16 training with master weights in FP32
run_forward_backward(
torch.float32, torch.float16, True
) # fp16 training with master weights in FP32, needs loss scaling
run_forward_backward(
torch.bfloat16, torch.bfloat16, False
) # bfloat16 training with weights in BF16
# END_BF16_FP16_TRAINING
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 850 780" width="850" height="780">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
</style>
<marker id="arrowhead" markerWidth="6" markerHeight="6" refX="5" refY="2" orient="auto">
<polygon points="0 0, 6 2, 0 4" fill="#616161" />
</marker>
</defs>
<!-- Title -->
<text x="425" y="30" class="title">FP8 Linear Layer – Forward and Backward Pass</text>
<!-- Forward Pass Section -->
<text x="425" y="65" class="section-title" style="fill: #1565c0;">Forward Pass</text>
<!-- Forward: Input^T FP8 (top, saved for backward) -->
<rect x="270" y="70" width="80" height="50" class="fp8" rx="6"/>
<text x="310" y="100" class="text">Input<tspan baseline-shift="super" style="font-size: 9px;">T</tspan></text>
<!-- Forward: Input High Precision -->
<rect x="30" y="130" width="100" height="50" class="hp" rx="6"/>
<text x="80" y="160" class="text">Input</text>
<!-- Forward: Arrow -->
<path d="M 130 155 L 155 155" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<!-- Forward: Quantize Input -->
<rect x="155" y="130" width="90" height="50" class="quantize" rx="6"/>
<text x="200" y="160" class="text">Quantize</text>
<!-- Forward: Arrow to Input^T (going up) -->
<path d="M 245 140 L 270 110" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<!-- Forward: Arrow to Input -->
<path d="M 245 155 L 270 155" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<!-- Forward: Input FP8 -->
<rect x="270" y="130" width="80" height="50" class="fp8" rx="6"/>
<text x="310" y="160" class="text">Input</text>
<!-- Forward: Arrow from Input to GEMM -->
<path d="M 350 155 L 400 170" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<text x="370" y="145" class="text" style="font-size: 10px; font-weight: bold;">N</text>
<!-- Forward: Weights High Precision -->
<rect x="30" y="195" width="100" height="50" class="hp" rx="6"/>
<text x="80" y="225" class="text">Weight</text>
<!-- Forward: Arrow -->
<path d="M 130 220 L 155 220" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<!-- Forward: Quantize Weights -->
<rect x="155" y="195" width="90" height="50" class="quantize" rx="6"/>
<text x="200" y="225" class="text">Quantize</text>
<!-- Forward: Arrow to Weight -->
<path d="M 245 220 L 270 220" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<!-- Forward: Arrow to Weight^T (going down) -->
<path d="M 245 235 L 270 270" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<!-- Forward: Weights FP8 -->
<rect x="270" y="195" width="80" height="50" class="fp8" rx="6"/>
<text x="310" y="225" class="text">Weight</text>
<!-- Forward: Weight^T FP8 (bottom, saved for backward) -->
<rect x="270" y="255" width="80" height="50" class="fp8" rx="6"/>
<text x="310" y="285" class="text">Weight<tspan baseline-shift="super" style="font-size: 9px;">T</tspan></text>
<!-- Forward: Arrow from Weight to GEMM -->
<path d="M 350 220 L 400 200" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<text x="370" y="230" class="text" style="font-size: 10px; font-weight: bold;">T</text>
<!-- Forward: GEMM -->
<rect x="400" y="160" width="130" height="50" class="gemm" rx="6"/>
<text x="465" y="180" class="text" style="font-weight: 600;">FP8 GEMM</text>
<text x="465" y="200" class="text" style="font-size: 11px;">(TN)</text>
<!-- Forward: Arrow -->
<path d="M 530 185 L 580 185" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<!-- Forward: Output -->
<rect x="580" y="160" width="110" height="50" class="hp" rx="6"/>
<text x="635" y="190" class="text">Output</text>
<!-- Divider Line -->
<line x1="30" y1="310" x2="820" y2="310" stroke="#ddd" stroke-width="2"/>
<!-- Backward Pass Section -->
<text x="425" y="345" class="section-title" style="fill: #c62828;">Backward Pass</text>
<!-- Backward: Weight^T (from forward, top input to GEMM1) -->
<rect x="495" y="355" width="80" height="50" class="fp8" rx="6"/>
<text x="535" y="385" class="text">Weight<tspan baseline-shift="super" style="font-size: 9px;">T</tspan></text>
<!-- Backward: Output gradient High Precision -->
<rect x="30" y="480" width="130" height="50" class="hp" rx="6"/>
<text x="95" y="510" class="text">Output grad.</text>
<!-- Backward: Arrow -->
<path d="M 160 505 L 180 505" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<!-- Backward: Quantize Output gradient -->
<rect x="180" y="480" width="90" height="50" class="quantize" rx="6"/>
<text x="225" y="510" class="text">Quantize</text>
<!-- Backward: Arrow to Output grad (going up) -->
<path d="M 270 490 L 290 465" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<!-- Backward: Arrow to Output grad^T (going down) -->
<path d="M 270 520 L 290 545" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<!-- Backward: Output gradient FP8 (for input gradient) -->
<rect x="290" y="440" width="110" height="50" class="fp8" rx="6"/>
<text x="345" y="470" class="text">Output grad.</text>
<!-- Backward: Output gradient^T FP8 (for weight gradient) -->
<rect x="290" y="520" width="110" height="50" class="fp8" rx="6"/>
<text x="345" y="550" class="text">Output grad.<tspan baseline-shift="super" style="font-size: 9px;">T</tspan></text>
<!-- Backward: GEMM 1 (for input gradient) -->
<rect x="470" y="440" width="130" height="50" class="gemm" rx="6"/>
<text x="535" y="460" class="text" style="font-weight: 600;">FP8 GEMM</text>
<text x="535" y="480" class="text" style="font-size: 11px;">(TN)</text>
<!-- Backward: Input gradient -->
<rect x="640" y="440" width="130" height="50" class="hp" rx="6"/>
<text x="705" y="470" class="text">Input grad.</text>
<!-- Backward: GEMM 2 (for weight gradient) -->
<rect x="470" y="520" width="130" height="50" class="gemm" rx="6"/>
<text x="535" y="540" class="text" style="font-weight: 600;">FP8 GEMM</text>
<text x="535" y="560" class="text" style="font-size: 11px;">(TN)</text>
<!-- Backward: Weight gradient -->
<rect x="640" y="520" width="130" height="50" class="hp" rx="6"/>
<text x="705" y="550" class="text">Weight grad.</text>
<!-- Backward: Input^T (from forward, bottom input to GEMM2) -->
<rect x="495" y="605" width="80" height="50" class="fp8" rx="6"/>
<text x="535" y="635" class="text">Input<tspan baseline-shift="super" style="font-size: 9px;">T</tspan></text>
<!-- Backward: Arrows -->
<!-- Output gradient FP8 to top GEMM -->
<path d="M 400 465 L 470 465" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<text x="430" y="457" class="text" style="font-size: 10px; font-weight: bold;">N</text>
<!-- Weight^T to top GEMM -->
<path d="M 535 405 L 535 440" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<text x="543" y="427" class="text" style="font-size: 10px; font-weight: bold;">T</text>
<!-- Top GEMM to input gradient -->
<path d="M 600 465 L 640 465" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<!-- Output gradient^T FP8 to bottom GEMM -->
<path d="M 400 545 L 470 545" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<text x="430" y="537" class="text" style="font-size: 10px; font-weight: bold;">N</text>
<!-- Input^T to bottom GEMM -->
<path d="M 535 605 L 535 570" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<text x="543" y="597" class="text" style="font-size: 10px; font-weight: bold;">T</text>
<!-- Bottom GEMM to weight gradient -->
<path d="M 600 545 L 640 545" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<!-- Legend -->
<g transform="translate(30, 680)">
<!-- Higher Precision -->
<rect x="0" y="0" width="80" height="40" rx="5" class="hp"/>
<text x="95" y="23" class="text" style="text-anchor: start;">Higher Precision (FP32/BF16/FP16)</text>
<!-- Lower Precision -->
<rect x="380" y="0" width="80" height="40" rx="5" class="fp8"/>
<text x="475" y="23" class="text" style="text-anchor: start;">Lower Precision (FP8, MXFP8 etc.)</text>
</g>
</svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 900 210">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.sign-bit { fill: #9db4d0; stroke: #333; stroke-width: 1; }
.exponent-bit { fill: #d9a066; stroke: #333; stroke-width: 1; }
.mantissa-bit { fill: #a8d99c; stroke: #333; stroke-width: 1; }
.bit-text { fill: #000; text-anchor: middle; dominant-baseline: middle; font-size: 16px; }
.header-text { fill: #555; font-weight: normal; text-anchor: middle; font-size: 18px; }
.value-text { fill: #333; font-size: 18px; }
.format-label { fill: #333; font-size: 20px; font-weight: bold; text-anchor: middle; dominant-baseline: middle; }
</style>
</defs>
<!-- Header labels - centered -->
<text x="79" y="18" class="header-text">sign</text>
<text x="173" y="18" class="header-text">exponent</text>
<text x="530" y="18" class="header-text">mantissa</text>
<!-- FP32 Format (32 bits: 1 + 8 + 23) -->
<text x="30" y="60" class="format-label">FP32</text>
<!-- Sign bit (1) -->
<rect x="70" y="45" width="18" height="30" class="sign-bit"/>
<text x="79" y="60" class="bit-text">0</text>
<!-- Exponent bits (8) -->
<rect x="93" y="45" width="18" height="30" class="exponent-bit"/>
<text x="102" y="60" class="bit-text">0</text>
<rect x="116" y="45" width="18" height="30" class="exponent-bit"/>
<text x="125" y="60" class="bit-text">1</text>
<rect x="139" y="45" width="18" height="30" class="exponent-bit"/>
<text x="148" y="60" class="bit-text">1</text>
<rect x="162" y="45" width="18" height="30" class="exponent-bit"/>
<text x="171" y="60" class="bit-text">1</text>
<rect x="185" y="45" width="18" height="30" class="exponent-bit"/>
<text x="194" y="60" class="bit-text">1</text>
<rect x="208" y="45" width="18" height="30" class="exponent-bit"/>
<text x="217" y="60" class="bit-text">1</text>
<rect x="231" y="45" width="18" height="30" class="exponent-bit"/>
<text x="240" y="60" class="bit-text">0</text>
<rect x="254" y="45" width="18" height="30" class="exponent-bit"/>
<text x="263" y="60" class="bit-text">1</text>
<!-- Mantissa bits (23) -->
<rect x="277" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="286" y="60" class="bit-text">1</text>
<rect x="300" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="309" y="60" class="bit-text">0</text>
<rect x="323" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="332" y="60" class="bit-text">0</text>
<rect x="346" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="355" y="60" class="bit-text">1</text>
<rect x="369" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="378" y="60" class="bit-text">0</text>
<rect x="392" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="401" y="60" class="bit-text">1</text>
<rect x="415" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="424" y="60" class="bit-text">0</text>
<rect x="438" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="447" y="60" class="bit-text">0</text>
<rect x="461" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="470" y="60" class="bit-text">1</text>
<rect x="484" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="493" y="60" class="bit-text">0</text>
<rect x="507" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="516" y="60" class="bit-text">1</text>
<rect x="530" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="539" y="60" class="bit-text">0</text>
<rect x="553" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="562" y="60" class="bit-text">1</text>
<rect x="576" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="585" y="60" class="bit-text">1</text>
<rect x="599" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="608" y="60" class="bit-text">1</text>
<rect x="622" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="631" y="60" class="bit-text">1</text>
<rect x="645" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="654" y="60" class="bit-text">0</text>
<rect x="668" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="677" y="60" class="bit-text">1</text>
<rect x="691" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="700" y="60" class="bit-text">0</text>
<rect x="714" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="723" y="60" class="bit-text">1</text>
<rect x="737" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="746" y="60" class="bit-text">0</text>
<rect x="760" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="769" y="60" class="bit-text">0</text>
<rect x="783" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="792" y="60" class="bit-text">0</text>
<text x="820" y="60" class="value-text">= 0.3952</text>
<!-- BF16 Format (16 bits: 1 + 8 + 7) -->
<text x="30" y="120" class="format-label">BF16</text>
<!-- Sign bit (1) -->
<rect x="70" y="105" width="18" height="30" class="sign-bit"/>
<text x="79" y="120" class="bit-text">0</text>
<!-- Exponent bits (8) -->
<rect x="93" y="105" width="18" height="30" class="exponent-bit"/>
<text x="102" y="120" class="bit-text">0</text>
<rect x="116" y="105" width="18" height="30" class="exponent-bit"/>
<text x="125" y="120" class="bit-text">1</text>
<rect x="139" y="105" width="18" height="30" class="exponent-bit"/>
<text x="148" y="120" class="bit-text">1</text>
<rect x="162" y="105" width="18" height="30" class="exponent-bit"/>
<text x="171" y="120" class="bit-text">1</text>
<rect x="185" y="105" width="18" height="30" class="exponent-bit"/>
<text x="194" y="120" class="bit-text">1</text>
<rect x="208" y="105" width="18" height="30" class="exponent-bit"/>
<text x="217" y="120" class="bit-text">1</text>
<rect x="231" y="105" width="18" height="30" class="exponent-bit"/>
<text x="240" y="120" class="bit-text">0</text>
<rect x="254" y="105" width="18" height="30" class="exponent-bit"/>
<text x="263" y="120" class="bit-text">1</text>
<!-- Mantissa bits (7) -->
<rect x="277" y="105" width="18" height="30" class="mantissa-bit"/>
<text x="286" y="120" class="bit-text">1</text>
<rect x="300" y="105" width="18" height="30" class="mantissa-bit"/>
<text x="309" y="120" class="bit-text">0</text>
<rect x="323" y="105" width="18" height="30" class="mantissa-bit"/>
<text x="332" y="120" class="bit-text">0</text>
<rect x="346" y="105" width="18" height="30" class="mantissa-bit"/>
<text x="355" y="120" class="bit-text">1</text>
<rect x="369" y="105" width="18" height="30" class="mantissa-bit"/>
<text x="378" y="120" class="bit-text">0</text>
<rect x="392" y="105" width="18" height="30" class="mantissa-bit"/>
<text x="401" y="120" class="bit-text">1</text>
<rect x="415" y="105" width="18" height="30" class="mantissa-bit"/>
<text x="424" y="120" class="bit-text">0</text>
<text x="820" y="120" class="value-text">≈ 0.3945</text>
<!-- FP16 Format (16 bits: 1 + 5 + 10) -->
<text x="30" y="180" class="format-label">FP16</text>
<!-- Sign bit (1) -->
<rect x="70" y="165" width="18" height="30" class="sign-bit"/>
<text x="79" y="180" class="bit-text">0</text>
<!-- Exponent bits (5) -->
<rect x="93" y="165" width="18" height="30" class="exponent-bit"/>
<text x="102" y="180" class="bit-text">0</text>
<rect x="116" y="165" width="18" height="30" class="exponent-bit"/>
<text x="125" y="180" class="bit-text">1</text>
<rect x="139" y="165" width="18" height="30" class="exponent-bit"/>
<text x="148" y="180" class="bit-text">1</text>
<rect x="162" y="165" width="18" height="30" class="exponent-bit"/>
<text x="171" y="180" class="bit-text">0</text>
<rect x="185" y="165" width="18" height="30" class="exponent-bit"/>
<text x="194" y="180" class="bit-text">1</text>
<!-- Mantissa bits (10) -->
<rect x="208" y="165" width="18" height="30" class="mantissa-bit"/>
<text x="217" y="180" class="bit-text">1</text>
<rect x="231" y="165" width="18" height="30" class="mantissa-bit"/>
<text x="240" y="180" class="bit-text">0</text>
<rect x="254" y="165" width="18" height="30" class="mantissa-bit"/>
<text x="263" y="180" class="bit-text">0</text>
<rect x="277" y="165" width="18" height="30" class="mantissa-bit"/>
<text x="286" y="180" class="bit-text">1</text>
<rect x="300" y="165" width="18" height="30" class="mantissa-bit"/>
<text x="309" y="180" class="bit-text">0</text>
<rect x="323" y="165" width="18" height="30" class="mantissa-bit"/>
<text x="332" y="180" class="bit-text">1</text>
<rect x="346" y="165" width="18" height="30" class="mantissa-bit"/>
<text x="355" y="180" class="bit-text">0</text>
<rect x="369" y="165" width="18" height="30" class="mantissa-bit"/>
<text x="378" y="180" class="bit-text">0</text>
<rect x="392" y="165" width="18" height="30" class="mantissa-bit"/>
<text x="401" y="180" class="bit-text">1</text>
<rect x="415" y="165" width="18" height="30" class="mantissa-bit"/>
<text x="424" y="180" class="bit-text">0</text>
<text x="820" y="180" class="value-text">≈ 0.3950</text>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1050 580" width="1050" height="580">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.arrow { stroke: #616161; stroke-width: 2; fill: none; marker-end: url(#arrowhead); }
.column-title { font-family: 'Segoe UI', Arial, sans-serif; font-size: 14px; font-weight: 600; text-anchor: middle; fill: #424242; }
.divider { stroke: #bdbdbd; stroke-width: 1.5; stroke-dasharray: 8,6; }
</style>
<marker id="arrowhead" markerWidth="6" markerHeight="6" refX="5" refY="2" orient="auto">
<polygon points="0 0, 6 2, 0 4" fill="#616161" />
</marker>
</defs>
<!-- Title -->
<text x="525" y="30" class="title">Master Weights Storage Approaches</text>
<!-- Vertical dividers (dashed lines) -->
<line x1="350" y1="50" x2="350" y2="560" class="divider"/>
<line x1="700" y1="50" x2="700" y2="560" class="divider"/>
<!-- Column 1: Low Precision Only -->
<text x="175" y="75" class="column-title">Low Precision Weights</text>
<text x="175" y="93" class="small-text">(no master weights)</text>
<!-- Model box -->
<rect x="60" y="145" width="230" height="90" rx="6" fill="#f5f5f5" stroke="#9e9e9e" stroke-width="1.5"/>
<text x="175" y="168" class="label">Model</text>
<rect x="80" y="183" width="190" height="40" class="hp" rx="4"/>
<text x="175" y="208" class="text">Weights (BF16/FP16)</text>
<!-- Arrow down -->
<path d="M 175 235 L 175 300" class="arrow"/>
<!-- Computation -->
<rect x="90" y="300" width="170" height="50" class="gemm" rx="6"/>
<text x="175" y="330" class="text">Forward/Backward</text>
<!-- Arrow down -->
<path d="M 175 350 L 175 415" class="arrow"/>
<!-- Optimizer box -->
<rect x="60" y="415" width="230" height="90" rx="6" fill="#f5f5f5" stroke="#9e9e9e" stroke-width="1.5"/>
<text x="175" y="438" class="label">Optimizer</text>
<rect x="80" y="453" width="190" height="40" class="fp32" rx="4"/>
<text x="175" y="478" class="text">State (FP32)</text>
<!-- Column 2: Master Weights in Model -->
<text x="525" y="75" class="column-title">Master Weights in Model</text>
<!-- Model box -->
<rect x="410" y="145" width="230" height="90" rx="6" fill="#f5f5f5" stroke="#9e9e9e" stroke-width="1.5"/>
<text x="525" y="168" class="label">Model</text>
<rect x="430" y="183" width="190" height="40" class="fp32" rx="4"/>
<text x="525" y="208" class="text">Weights (FP32)</text>
<!-- Arrow down with cast -->
<path d="M 525 235 L 525 300" class="arrow"/>
<rect x="465" y="255" width="120" height="26" class="quantize" rx="4"/>
<text x="525" y="273" class="small-text">cast to BF16/FP16</text>
<!-- Computation -->
<rect x="440" y="300" width="170" height="50" class="gemm" rx="6"/>
<text x="525" y="330" class="text">Forward/Backward</text>
<!-- Arrow down -->
<path d="M 525 350 L 525 415" class="arrow"/>
<!-- Optimizer box -->
<rect x="410" y="415" width="230" height="90" rx="6" fill="#f5f5f5" stroke="#9e9e9e" stroke-width="1.5"/>
<text x="525" y="438" class="label">Optimizer</text>
<rect x="430" y="453" width="190" height="40" class="fp32" rx="4"/>
<text x="525" y="478" class="text">State (FP32)</text>
<!-- Column 3: Master Weights in Optimizer -->
<text x="875" y="75" class="column-title">Master Weights in Optimizer</text>
<!-- Cast box above Model -->
<rect x="815" y="105" width="120" height="26" class="quantize" rx="4"/>
<text x="875" y="123" class="small-text">cast to BF16/FP16</text>
<!-- Arrow from cast to Model -->
<path d="M 875 131 L 875 145" class="arrow"/>
<!-- Model box -->
<rect x="760" y="145" width="230" height="90" rx="6" fill="#f5f5f5" stroke="#9e9e9e" stroke-width="1.5"/>
<text x="875" y="168" class="label">Model</text>
<rect x="780" y="183" width="190" height="40" class="hp" rx="4"/>
<text x="875" y="208" class="text">Weights (BF16/FP16)</text>
<!-- Arrow down -->
<path d="M 875 235 L 875 300" class="arrow"/>
<!-- Computation -->
<rect x="790" y="300" width="170" height="50" class="gemm" rx="6"/>
<text x="875" y="330" class="text">Forward/Backward</text>
<!-- Arrow down -->
<path d="M 875 350 L 875 415" class="arrow"/>
<!-- Optimizer box with State and Master -->
<rect x="760" y="415" width="230" height="140" rx="6" fill="#f5f5f5" stroke="#9e9e9e" stroke-width="1.5"/>
<text x="875" y="438" class="label">Optimizer</text>
<rect x="780" y="453" width="190" height="40" class="fp32" rx="4"/>
<text x="875" y="478" class="text">State (FP32)</text>
<rect x="780" y="503" width="190" height="40" class="fp32" rx="4"/>
<text x="875" y="528" class="text">Master (FP32)</text>
<!-- Arrow from Master to cast -->
<path d="M 970 523 L 1010 523 L 1010 118 L 935 118" class="arrow"/>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 850 480" width="850" height="480" style="display: block; margin: 0 auto;">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.arrow { stroke: #616161; stroke-width: 1.5; fill: none; marker-end: url(#arrowhead); }
.section-label { font-family: 'Segoe UI', Arial, sans-serif; font-size: 16px; font-weight: 600; fill: #424242; text-anchor: start; }
</style>
<marker id="arrowhead" markerWidth="3" markerHeight="3" refX="3" refY="1.5" orient="auto">
<polygon points="0 0, 3 1.5, 0 3" fill="#616161" />
</marker>
</defs>
<!-- Title -->
<text x="425" y="30" class="title" style="text-anchor: middle;">Transformer Layer – default precision of operation in low precision recipe</text>
<!-- Row 1: Input → Layer Norm → QKV Linear → QK^T → Softmax -->
<rect x="20" y="60" width="115" height="50" rx="5" class="hp"/>
<text x="77" y="90" class="text">Input</text>
<path d="M 135 85 L 158 85" class="arrow"/>
<rect x="158" y="60" width="115" height="50" rx="5" class="hp"/>
<text x="215" y="90" class="text">Layer Norm</text>
<path d="M 273 85 L 296 85" class="arrow"/>
<rect x="296" y="60" width="115" height="50" rx="5" class="gemm"/>
<text x="353" y="90" class="text">QKV Linear</text>
<path d="M 411 85 L 434 85" class="arrow"/>
<rect x="434" y="60" width="115" height="50" rx="5" class="hp"/>
<text x="491" y="90" class="text">QK^T</text>
<path d="M 549 85 L 572 85" class="arrow"/>
<rect x="572" y="60" width="115" height="50" rx="5" class="hp"/>
<text x="629" y="90" class="text">Softmax</text>
<!-- Row 2: Attn * V → Output Linear → Dropout + Add -->
<path d="M 629 110 L 629 145" class="arrow"/>
<rect x="572" y="145" width="115" height="50" rx="5" class="hp"/>
<text x="629" y="175" class="text">Scores * V</text>
<path d="M 572 170 L 549 170" class="arrow"/>
<rect x="434" y="145" width="115" height="50" rx="5" class="gemm"/>
<text x="491" y="175" class="text">Output Linear</text>
<path d="M 434 170 L 273 170" class="arrow"/>
<rect x="158" y="145" width="115" height="50" rx="5" class="hp"/>
<text x="215" y="175" class="text">Dropout + Add</text>
<!-- Row 3: Layer Norm → FFN Linear 1 → GELU → FFN Linear 2 → Output -->
<path d="M 215 195 L 215 230" class="arrow"/>
<rect x="158" y="230" width="115" height="50" rx="5" class="hp"/>
<text x="215" y="260" class="text">Layer Norm</text>
<path d="M 273 255 L 296 255" class="arrow"/>
<rect x="296" y="230" width="115" height="50" rx="5" class="gemm"/>
<text x="353" y="260" class="text">FFN Linear 1</text>
<path d="M 411 255 L 434 255" class="arrow"/>
<rect x="434" y="230" width="115" height="50" rx="5" class="hp"/>
<text x="491" y="260" class="text">GELU</text>
<path d="M 549 255 L 572 255" class="arrow"/>
<rect x="572" y="230" width="115" height="50" rx="5" class="gemm"/>
<text x="629" y="260" class="text">FFN Linear 2</text>
<path d="M 687 255 L 710 255" class="arrow"/>
<rect x="710" y="230" width="115" height="50" rx="5" class="hp"/>
<text x="767" y="260" class="text">Output</text>
<!-- Memory State Section -->
<text x="20" y="325" class="section-label">Memory State:</text>
<!-- Parameters -->
<rect x="20" y="340" width="180" height="45" rx="5" class="hp"/>
<text x="110" y="365" class="text">Parameters</text>
<!-- Gradients -->
<rect x="225" y="340" width="140" height="45" rx="5" class="hp"/>
<text x="295" y="365" class="text">Gradients</text>
<!-- Legend -->
<g transform="translate(20, 415)">
<!-- High Precision -->
<rect x="0" y="0" width="80" height="40" rx="5" class="hp"/>
<text x="95" y="23" class="text" style="text-anchor: start;">Higher Precision (FP32/BF16/FP16)</text>
<!-- Low Precision -->
<rect x="400" y="0" width="80" height="40" rx="5" class="gemm"/>
<text x="495" y="23" class="text" style="text-anchor: start;">Lower Precision (FP8, MXFP8 etc.)</text>
</g>
</svg>
..
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
Introduction
===================================
Transformer Engine accelerates deep learning on NVIDIA GPUs in several ways,
with low precision training being one of the most important.
This chapter introduces mixed precision training and FP8 support.
Training in BF16/FP16
---------------------
Deep learning traditionally uses 32-bit floating-point (FP32) numbers.
NVIDIA GPUs support lower precision formats—FP16 since Pascal, BF16 since Ampere—which offer higher throughput and lower memory usage.
Let's compare these formats.
.. raw:: html
:file: img/fp_formats_comparison.svg
*Figure 1: Comparison of FP32, BF16, and FP16 floating-point formats showing bit allocation for sign, exponent, and mantissa.*
The key differences between these formats are:
* **FP32** (32 bits total): 1 sign bit + 8 exponent bits + 23 mantissa bits – standard single-precision format
* **BF16** (16 bits total): 1 sign bit + 8 exponent bits + 7 mantissa bits – maintains FP32's exponent range but has reduced precision
* **FP16** (16 bits total): 1 sign bit + 5 exponent bits + 10 mantissa bits – reduced range but higher precision than BF16
BF16's advantage is that it shares the same exponent range as FP32,
making it easier to convert between the two formats without overflow/underflow issues.
FP16 offers better precision for smaller values but has a limited dynamic range,
which results in the need to perform loss scaling to avoid overflow/underflow—see `this paper on loss scaling <https://arxiv.org/pdf/1710.03740>`__ for more details.
**Mixed precision**
Not all operations should be run in reduced precision to preserve accuracy.
Modern deep learning frameworks use *mixed precision training*,
where different operations use different precisions based on their numerical properties:
* Matrix multiplications are compute-heavy and remain numerically stable at lower precision, making them ideal candidates for acceleration.
* Operations like layer normalization and softmax can work with low precision inputs and outputs, but may use high precision internally or for their weights.
* Operations like loss computation and exponentiation need high precision throughout.
**Master weights**
Another consideration in mixed precision training is how to store the model weights.
Lower precision formats like FP16 and BF16 have limited representational granularity,
which becomes problematic during gradient updates.
When a small gradient is added to a not so small weight stored in low precision,
the result may round back to the original value if the update falls below the format's precision threshold.
Moreover, some elements of the gradient itself can be too small to be represented in low precision,
especially after the accumulation from multiple GPUs in the data parallel training setting.
The solution is to maintain *master weights* in FP32.
During training, weights are cast to lower precision for forward and backward passes,
but the gradient updates are applied to the full-precision master copy.
This ensures that even small gradients accumulate correctly over time.
There are two common software approaches to storing master weights:
* *In the optimizer*:
The model holds low-precision weights,
while the optimizer maintains FP32 copies alongside momentum and other state.
During each step,
the optimizer updates its FP32 copy and casts the result back to the model's low-precision weights.
This approach makes it easier to shard master weights together with other optimizer state, for example in ZeRO optimizer.
Since the casting happens only during the optimizer step, this approach is also faster when optimizer runs less frequently than the model, e.g. when performing gradient accumulation or pipeline parallel training.
* *In the model*:
The model stores weights directly in FP32,
and they are cast to lower precision on-the-fly during forward and backward passes.
This approach works seamlessly with any standard optimizer, requiring no special support.
.. raw:: html
:file: img/master_weights_approaches.svg
*Figure 2: Three approaches to weight storage—low precision only (no master weights), master weights stored in the model, and master weights stored in the optimizer.*
.. tabs::
.. tab:: PyTorch
The PyTorch API of Transformer Engine provides several mechanisms to control precision:
* **Weight precision**: Use the ``params_dtype`` argument in any TE layer constructor.
* **Computation precision**: Use the ``torch.autocast`` context manager. When enabled, inputs are cast to the autocast dtype before computation.
* **Input dtype**: When ``torch.autocast`` is not used, the input tensor's dtype determines the computation precision. In this case, inputs and parameters must have matching dtypes.
.. literalinclude:: bf16_fp16_training_pytorch.py
:language: python
:start-after: # START_BF16_FP16_TRAINING
:end-before: # END_BF16_FP16_TRAINING
.. tab:: JAX
The JAX API of Transformer Engine provides two mechanisms to control precision:
* **Weight precision**: Use the ``dtype`` argument in any TE layer constructor.
* **Computation precision**: Determined by the dtype of the input tensor.
For training with master weights in FP32 and computation in BF16,
cast the input tensor to BF16 before passing it to the layer.
.. literalinclude:: bf16_fp16_training_jax.py
:language: python
:start-after: # START_BF16_FP16_TRAINING
:end-before: # END_BF16_FP16_TRAINING
Lower precisions
----------------
Transformer Engine's primary feature is supporting even lower precision than BF16/FP16, such as FP8, MXFP8, NVFP4, etc.
The logic of these precisions is more complicated than the logic of BF16/FP16 – they require scaling factors to
properly represent the full range of values in the tensor. Sometimes it is one scaling factor per tensor,
sometimes it is one scaling factor per block of values. A precision format combined with the logic for training
is called **a recipe**.
In this section we present common logic for all the recipes. Each one of them is described in more detail in a separate section later.
Let's now see how we can train in lower precisions in supported frameworks.
.. tabs::
.. tab:: PyTorch
The PyTorch API of Transformer Engine provides an ``autocast`` context manager to control precision.
It's similar to the ``torch.autocast`` context manager, but tailored for low precision training.
The most important argument is the ``recipe`` argument, which accepts objects inheriting from
:class:`~transformer_engine.common.recipe.Recipe`.
Forward computations need to be performed inside the ``autocast`` context manager,
while the ``.backward()`` call should be outside of it (it inherits the setting from the
corresponding forward pass).
Here is a basic example:
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Needs to be run on SM89+ (Ada or newer)
</div>
.. literalinclude:: autocast_pytorch.py
:language: python
:start-after: # START_AUTOCAST_BASIC
:end-before: # END_AUTOCAST_BASIC
You can use multiple recipes in the same model in the following ways:
**Sequential contexts** – apply different recipes to different parts of your model:
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Needs to be run on SM89+ (Ada or newer)
</div>
.. literalinclude:: autocast_pytorch.py
:language: python
:start-after: # START_AUTOCAST_SEQUENTIAL
:end-before: # END_AUTOCAST_SEQUENTIAL
**Nested contexts** – the inner context overrides the outer one for its scope:
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Needs to be run on SM89+ (Ada or newer)
</div>
.. literalinclude:: autocast_pytorch.py
:language: python
:start-after: # START_AUTOCAST_NESTED
:end-before: # END_AUTOCAST_NESTED
.. tab:: JAX
The JAX API of Transformer Engine provides an ``autocast`` context manager similar to PyTorch.
The key difference is that in JAX, model initialization must happen inside the ``autocast`` context
to properly capture quantization metadata in the parameter tree.
Here is a basic example:
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Needs to be run on SM89+ (Ada or newer)
</div>
.. literalinclude:: autocast_jax.py
:language: python
:start-after: # START_AUTOCAST_BASIC
:end-before: # END_AUTOCAST_BASIC
You can use multiple recipes in the same model in the following ways:
**Sequential contexts** – apply different recipes to different parts of your model:
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Needs to be run on SM89+ (Ada or newer)
</div>
.. literalinclude:: autocast_jax.py
:language: python
:start-after: # START_AUTOCAST_SEQUENTIAL
:end-before: # END_AUTOCAST_SEQUENTIAL
**Nested contexts** – the inner context overrides the outer one for its scope:
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Needs to be run on SM89+ (Ada or newer)
</div>
.. literalinclude:: autocast_jax.py
:language: python
:start-after: # START_AUTOCAST_NESTED
:end-before: # END_AUTOCAST_NESTED
.. note::
Python context managers like ``autocast`` may interact unexpectedly with JAX's JIT compilation.
For finer-grained control, consider passing the recipe directly to TE modules instead.
See the `TE JAX Integration notebook <https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/te_jax_integration.ipynb>`_
for details.
**Mixed precision with 8- or 4-bit precisions**
From now on, we will refer to FP8/MXFP8/NVFP4 etc. as *low precision*
and to FP32/BF16/FP16 as *high precision*. This terminology will be
used throughout the rest of the documentation.
Not all operations run in low precision:
- **Linear operations**: run in low precision.
- **Attention computations**: run in high precision by default (some recipes allow low precision as an option).
- **Other operations** (layer normalization, softmax, etc.): run in high precision.
Within high-precision operations, there are two categories:
- **Configurable precision**: most operations run in parameter precision (FP32/BF16/FP16) or the precision specified by ``torch.autocast``.
- **Fixed FP32 precision**: some operations, or parts of operations—such as the division in layernorm—always run in FP32, regardless of other settings.
.. raw:: html
:file: img/mixed_precision_operations.svg
*Figure 3: Default precision of operations in a TransformerLayer forward pass. Only linear operations are in lower precision. Dot product attention is shown as three separate operations (QK^T, Softmax, Scores * V), though in practice these may be fused into a single kernel.*
**Linear layer data flow**
Let's see how data flow of a linear layer works by default on a single H100 GPU with FP8 precision:
H100 (Hopper) architecture natively supports FP8 Matrix Multiplication only in **TN** layout (Transpose-NoTranspose),
so GEMM with tensors ``A`` and ``B`` returns ``B * A^T``.
*Forward pass*
* Input is quantized to FP8 – both ``input`` and ``input^T`` quantized versions are created.
* Weights are stored in high precision and quantized to low precision before the GEMM – both ``weight`` and ``weight^T`` quantized versions are created.
* FP8 GEMM with layout **TN** is run with ``weight`` and ``input`` tensors,
* Outputs – ``input * weight^T`` tensor – are returned in high precision.
*Backward pass*
* Output gradients are quantized to FP8 – both ``output_grad`` and ``output_grad^T`` quantized versions are created.
* FP8 GEMM with layout **TN** is performed with ``weight^T`` and ``output_grad`` tensors to compute input gradients.
* FP8 GEMM with layout **TN** is performed with ``input^T`` and ``output_grad^T`` tensors to compute weight gradients.
* Input gradients – ``output_grad * weight`` tensor – are returned in high precision.
* Weight gradients – ``output_grad^T * input`` tensor – are returned in high precision.
.. raw:: html
:file: img/fp8_linear_flow.svg
*Figure 4: Forward pass of a Linear layer with low precision data flow.*
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 60 500 380" width="100%" style="max-width: 500px;">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.title { font: bold 18px sans-serif; fill: #333; text-anchor: middle; }
.label { font: 14px sans-serif; fill: #333; text-anchor: middle; }
.small-text { font: 12px sans-serif; fill: #555; }
.dots-text { font: bold 24px sans-serif; fill: #333; text-anchor: middle; }
/* Tensor colors */
.fp8-tensor { fill: #87CEEB; stroke: #444; stroke-width: 2; }
.fp8-block { fill: #87CEEB; stroke: #555; stroke-width: 1.5; }
.fp8-block-alt { fill: #5F9FCC; stroke: #555; stroke-width: 1.5; }
/* Scaling factor colors */
.scale-factor { fill: #FFA500; stroke: #444; stroke-width: 2; }
.grid-line { stroke: #444; stroke-width: 2; }
.boundary-line { stroke: #444; stroke-width: 2; }
</style>
</defs>
<!-- MXFP8 Scaling -->
<g id="mxfp8-scaling">
<text x="250" y="85" class="title">MXFP8</text>
<text x="250" y="108" class="label" style="font-size: 12px;">(One scaling factor per 32 elements)</text>
<!-- FP8 Tensor split into many small blocks (40×10) -->
<g id="tensor-blocks">
<!-- White backgrounds for dots areas - cross pattern -->
<rect x="250.0" y="140.0" width="40" height="120" fill="#FFFFFF" stroke="none"/>
<rect x="130.0" y="190.0" width="240" height="30" fill="#FFFFFF" stroke="none"/>
<!-- Blocks ONLY where they don't overlap with white cross -->
<rect x="130" y="140" width="40" height="10" class="fp8-block"/>
<rect x="170" y="140" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="140" width="40" height="10" class="fp8-block"/>
<rect x="290" y="140" width="40" height="10" class="fp8-block"/>
<rect x="330" y="140" width="40" height="10" class="fp8-block-alt"/>
<rect x="130" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="170" y="150" width="40" height="10" class="fp8-block"/>
<rect x="210" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="290" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="330" y="150" width="40" height="10" class="fp8-block"/>
<rect x="130" y="160" width="40" height="10" class="fp8-block"/>
<rect x="170" y="160" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="160" width="40" height="10" class="fp8-block"/>
<rect x="290" y="160" width="40" height="10" class="fp8-block"/>
<rect x="330" y="160" width="40" height="10" class="fp8-block-alt"/>
<rect x="130" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="170" y="170" width="40" height="10" class="fp8-block"/>
<rect x="210" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="290" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="330" y="170" width="40" height="10" class="fp8-block"/>
<rect x="130" y="180" width="40" height="10" class="fp8-block"/>
<rect x="170" y="180" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="180" width="40" height="10" class="fp8-block"/>
<rect x="290" y="180" width="40" height="10" class="fp8-block"/>
<rect x="330" y="180" width="40" height="10" class="fp8-block-alt"/>
<!-- Three dots in VERTICAL white bar -->
<text x="270" y="167.5" class="dots-text"></text>
<text x="270" y="242.5" class="dots-text"></text>
<!-- Three dots in HORIZONTAL white bar -->
<text x="190" y="205" class="dots-text"></text>
<text x="330" y="205" class="dots-text"></text>
<!-- ONE diagonal dot at intersection -->
<text x="270" y="205" class="dots-text" transform="rotate(45 270 205)"></text>
<!-- Bottom rows (y >= 220 after horizontal white bar) -->
<rect x="130" y="220" width="40" height="10" class="fp8-block"/>
<rect x="170" y="220" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="220" width="40" height="10" class="fp8-block"/>
<rect x="290" y="220" width="40" height="10" class="fp8-block"/>
<rect x="330" y="220" width="40" height="10" class="fp8-block-alt"/>
<rect x="130" y="230" width="40" height="10" class="fp8-block-alt"/>
<rect x="170" y="230" width="40" height="10" class="fp8-block"/>
<rect x="210" y="230" width="40" height="10" class="fp8-block-alt"/>
<rect x="290" y="230" width="40" height="10" class="fp8-block-alt"/>
<rect x="330" y="230" width="40" height="10" class="fp8-block"/>
<rect x="130" y="240" width="40" height="10" class="fp8-block"/>
<rect x="170" y="240" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="240" width="40" height="10" class="fp8-block"/>
<rect x="290" y="240" width="40" height="10" class="fp8-block"/>
<rect x="330" y="240" width="40" height="10" class="fp8-block-alt"/>
<rect x="130" y="250" width="40" height="10" class="fp8-block-alt"/>
<rect x="170" y="250" width="40" height="10" class="fp8-block"/>
<rect x="210" y="250" width="40" height="10" class="fp8-block-alt"/>
<rect x="290" y="250" width="40" height="10" class="fp8-block-alt"/>
<rect x="330" y="250" width="40" height="10" class="fp8-block"/>
<!-- Main outline -->
<rect x="130.0" y="140.0" width="240" height="120" fill="none" stroke="#444" stroke-width="2"/>
</g>
<!-- Scaling factors tensor - 3+2 columns of 10px squares -->
<g id="scale-factors">
<!-- Orange background -->
<rect x="215" y="285" width="70" height="120" fill="#FFA500"/>
<!-- White backgrounds for dots areas - cross pattern -->
<rect x="245" y="285" width="20" height="120" fill="#FFFFFF" stroke="none"/>
<rect x="215" y="335" width="70" height="30" fill="#FFFFFF" stroke="none"/>
<!-- Grid lines showing 10x10 squares (3 left + 2 right columns) -->
<!-- Vertical lines every 10px (skipping white space) -->
<!-- Left 3 columns -->
<line x1="225" y1="285" x2="225" y2="335" class="grid-line" stroke-width="1"/>
<line x1="235" y1="285" x2="235" y2="335" class="grid-line" stroke-width="1"/>
<line x1="245" y1="285" x2="245" y2="335" class="grid-line" stroke-width="1"/>
<!-- Right 2 columns -->
<line x1="265" y1="285" x2="265" y2="335" class="grid-line" stroke-width="1"/>
<line x1="275" y1="285" x2="275" y2="335" class="grid-line" stroke-width="1"/>
<line x1="285" y1="285" x2="285" y2="335" class="grid-line" stroke-width="1"/>
<!-- Bottom sections -->
<line x1="225" y1="365" x2="225" y2="405" class="grid-line" stroke-width="1"/>
<line x1="235" y1="365" x2="235" y2="405" class="grid-line" stroke-width="1"/>
<line x1="245" y1="365" x2="245" y2="405" class="grid-line" stroke-width="1"/>
<line x1="265" y1="365" x2="265" y2="405" class="grid-line" stroke-width="1"/>
<line x1="275" y1="365" x2="275" y2="405" class="grid-line" stroke-width="1"/>
<line x1="285" y1="365" x2="285" y2="405" class="grid-line" stroke-width="1"/>
<!-- Horizontal lines every 10px -->
<line x1="215" y1="295" x2="245" y2="295" class="grid-line" stroke-width="1"/>
<line x1="265" y1="295" x2="285" y2="295" class="grid-line" stroke-width="1"/>
<line x1="215" y1="305" x2="245" y2="305" class="grid-line" stroke-width="1"/>
<line x1="265" y1="305" x2="285" y2="305" class="grid-line" stroke-width="1"/>
<line x1="215" y1="315" x2="245" y2="315" class="grid-line" stroke-width="1"/>
<line x1="265" y1="315" x2="285" y2="315" class="grid-line" stroke-width="1"/>
<line x1="215" y1="325" x2="245" y2="325" class="grid-line" stroke-width="1"/>
<line x1="265" y1="325" x2="285" y2="325" class="grid-line" stroke-width="1"/>
<!-- Top bottom boundaries -->
<line x1="215" y1="335" x2="245" y2="335" class="grid-line" stroke-width="1"/>
<line x1="265" y1="335" x2="285" y2="335" class="grid-line" stroke-width="1"/>
<line x1="215" y1="365" x2="245" y2="365" class="grid-line" stroke-width="1"/>
<line x1="265" y1="365" x2="285" y2="365" class="grid-line" stroke-width="1"/>
<line x1="215" y1="375" x2="245" y2="375" class="grid-line" stroke-width="1"/>
<line x1="265" y1="375" x2="285" y2="375" class="grid-line" stroke-width="1"/>
<line x1="215" y1="385" x2="245" y2="385" class="grid-line" stroke-width="1"/>
<line x1="265" y1="385" x2="285" y2="385" class="grid-line" stroke-width="1"/>
<line x1="215" y1="395" x2="245" y2="395" class="grid-line" stroke-width="1"/>
<line x1="265" y1="395" x2="285" y2="395" class="grid-line" stroke-width="1"/>
<!-- Bottom boundaries -->
<line x1="215" y1="405" x2="245" y2="405" class="grid-line" stroke-width="1"/>
<line x1="265" y1="405" x2="285" y2="405" class="grid-line" stroke-width="1"/>
<!-- Main outline -->
<rect x="215" y="285" width="70" height="120" fill="none" stroke="#444" stroke-width="2"/>
<!-- Three dots -->
<text x="255" y="312.5" class="dots-text" style="font-size: 14px;"></text>
<text x="255" y="387.5" class="dots-text" style="font-size: 14px;"></text>
<text x="230" y="350" class="dots-text" style="font-size: 14px;"></text>
<text x="275" y="350" class="dots-text" style="font-size: 14px;"></text>
<text x="255" y="350" class="dots-text" style="font-size: 14px;" transform="rotate(45 255 350)"></text>
</g>
<text x="250" y="430" class="small-text" text-anchor="middle">E8M0 scaling factors (one per 32 elements)</text>
</g>
</svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 650 450" width="100%" style="max-width: 650px;">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.title { font: bold 18px sans-serif; fill: #333; text-anchor: middle; }
.label { font: 14px sans-serif; fill: #333; text-anchor: middle; }
.small-text { font: 12px sans-serif; fill: #555; text-anchor: middle; }
.dots-text { font: bold 24px sans-serif; fill: #333; text-anchor: middle; }
.fp8-block { fill: #87CEEB; stroke: #555; stroke-width: 1.5; }
.fp8-block-alt { fill: #5F9FCC; stroke: #555; stroke-width: 1.5; }
.scale-block { fill: #FFA500; stroke: #555; stroke-width: 1.5; }
</style>
</defs>
<!-- ROWWISE SECTION -->
<text x="325" y="30" class="title">Rowwise (1x32 blocks)</text>
<!-- Rowwise Data Tensor -->
<g id="rowwise-tensor">
<text x="160" y="55" class="small-text">Data</text>
<rect x="40" y="70" width="40" height="10" class="fp8-block"/>
<rect x="80" y="70" width="40" height="10" class="fp8-block-alt"/>
<rect x="120" y="70" width="40" height="10" class="fp8-block"/>
<rect x="200" y="70" width="40" height="10" class="fp8-block-alt"/>
<rect x="240" y="70" width="40" height="10" class="fp8-block"/>
<rect x="40" y="80" width="40" height="10" class="fp8-block-alt"/>
<rect x="80" y="80" width="40" height="10" class="fp8-block"/>
<rect x="120" y="80" width="40" height="10" class="fp8-block-alt"/>
<rect x="200" y="80" width="40" height="10" class="fp8-block"/>
<rect x="240" y="80" width="40" height="10" class="fp8-block-alt"/>
<rect x="40" y="90" width="40" height="10" class="fp8-block"/>
<rect x="80" y="90" width="40" height="10" class="fp8-block-alt"/>
<rect x="120" y="90" width="40" height="10" class="fp8-block"/>
<rect x="200" y="90" width="40" height="10" class="fp8-block-alt"/>
<rect x="240" y="90" width="40" height="10" class="fp8-block"/>
<rect x="40" y="100" width="40" height="10" class="fp8-block-alt"/>
<rect x="80" y="100" width="40" height="10" class="fp8-block"/>
<rect x="120" y="100" width="40" height="10" class="fp8-block-alt"/>
<rect x="200" y="100" width="40" height="10" class="fp8-block"/>
<rect x="240" y="100" width="40" height="10" class="fp8-block-alt"/>
<rect x="40" y="110" width="40" height="10" class="fp8-block"/>
<rect x="80" y="110" width="40" height="10" class="fp8-block-alt"/>
<rect x="120" y="110" width="40" height="10" class="fp8-block"/>
<rect x="200" y="110" width="40" height="10" class="fp8-block-alt"/>
<rect x="240" y="110" width="40" height="10" class="fp8-block"/>
<rect x="40" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="80" y="150" width="40" height="10" class="fp8-block"/>
<rect x="120" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="200" y="150" width="40" height="10" class="fp8-block"/>
<rect x="240" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="40" y="160" width="40" height="10" class="fp8-block"/>
<rect x="80" y="160" width="40" height="10" class="fp8-block-alt"/>
<rect x="120" y="160" width="40" height="10" class="fp8-block"/>
<rect x="200" y="160" width="40" height="10" class="fp8-block-alt"/>
<rect x="240" y="160" width="40" height="10" class="fp8-block"/>
<rect x="40" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="80" y="170" width="40" height="10" class="fp8-block"/>
<rect x="120" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="200" y="170" width="40" height="10" class="fp8-block"/>
<rect x="240" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="40" y="180" width="40" height="10" class="fp8-block"/>
<rect x="80" y="180" width="40" height="10" class="fp8-block-alt"/>
<rect x="120" y="180" width="40" height="10" class="fp8-block"/>
<rect x="200" y="180" width="40" height="10" class="fp8-block-alt"/>
<rect x="240" y="180" width="40" height="10" class="fp8-block"/>
<rect x="40" y="70" width="240" height="120" fill="none" stroke="#444" stroke-width="2"/>
<text x="180" y="97.5" class="dots-text"></text>
<text x="180" y="172.5" class="dots-text"></text>
<text x="100" y="135" class="dots-text"></text>
<text x="240" y="135" class="dots-text"></text>
<text x="180" y="135" class="dots-text" transform="rotate(45 180 135)"></text>
</g>
<!-- Rowwise Scale Tensor -->
<g id="rowwise-scales">
<text x="485" y="55" class="small-text">Scales</text>
<!-- Rows 1-5 -->
<rect x="380" y="70" width="10" height="10" class="scale-block"/>
<rect x="390" y="70" width="10" height="10" class="scale-block"/>
<rect x="400" y="70" width="10" height="10" class="scale-block"/>
<rect x="450" y="70" width="10" height="10" class="scale-block"/>
<rect x="460" y="70" width="10" height="10" class="scale-block"/>
<rect x="380" y="80" width="10" height="10" class="scale-block"/>
<rect x="390" y="80" width="10" height="10" class="scale-block"/>
<rect x="400" y="80" width="10" height="10" class="scale-block"/>
<rect x="450" y="80" width="10" height="10" class="scale-block"/>
<rect x="460" y="80" width="10" height="10" class="scale-block"/>
<rect x="380" y="90" width="10" height="10" class="scale-block"/>
<rect x="390" y="90" width="10" height="10" class="scale-block"/>
<rect x="400" y="90" width="10" height="10" class="scale-block"/>
<rect x="450" y="90" width="10" height="10" class="scale-block"/>
<rect x="460" y="90" width="10" height="10" class="scale-block"/>
<rect x="380" y="100" width="10" height="10" class="scale-block"/>
<rect x="390" y="100" width="10" height="10" class="scale-block"/>
<rect x="400" y="100" width="10" height="10" class="scale-block"/>
<rect x="450" y="100" width="10" height="10" class="scale-block"/>
<rect x="460" y="100" width="10" height="10" class="scale-block"/>
<rect x="380" y="110" width="10" height="10" class="scale-block"/>
<rect x="390" y="110" width="10" height="10" class="scale-block"/>
<rect x="400" y="110" width="10" height="10" class="scale-block"/>
<rect x="450" y="110" width="10" height="10" class="scale-block"/>
<rect x="460" y="110" width="10" height="10" class="scale-block"/>
<!-- Gap rows -->
<rect x="380" y="150" width="10" height="10" class="scale-block"/>
<rect x="390" y="150" width="10" height="10" class="scale-block"/>
<rect x="400" y="150" width="10" height="10" class="scale-block"/>
<rect x="450" y="150" width="10" height="10" class="scale-block"/>
<rect x="460" y="150" width="10" height="10" class="scale-block"/>
<rect x="380" y="160" width="10" height="10" class="scale-block"/>
<rect x="390" y="160" width="10" height="10" class="scale-block"/>
<rect x="400" y="160" width="10" height="10" class="scale-block"/>
<rect x="450" y="160" width="10" height="10" class="scale-block"/>
<rect x="460" y="160" width="10" height="10" class="scale-block"/>
<rect x="380" y="170" width="10" height="10" class="scale-block"/>
<rect x="390" y="170" width="10" height="10" class="scale-block"/>
<rect x="400" y="170" width="10" height="10" class="scale-block"/>
<rect x="450" y="170" width="10" height="10" class="scale-block"/>
<rect x="460" y="170" width="10" height="10" class="scale-block"/>
<rect x="380" y="180" width="10" height="10" class="scale-block"/>
<rect x="390" y="180" width="10" height="10" class="scale-block"/>
<rect x="400" y="180" width="10" height="10" class="scale-block"/>
<rect x="450" y="180" width="10" height="10" class="scale-block"/>
<rect x="460" y="180" width="10" height="10" class="scale-block"/>
<rect x="380" y="70" width="90" height="120" fill="none" stroke="#444" stroke-width="2"/>
<text x="430" y="97.5" class="dots-text"></text>
<text x="430" y="172.5" class="dots-text"></text>
<text x="400" y="135" class="dots-text"></text>
<text x="460" y="135" class="dots-text"></text>
<text x="430" y="135" class="dots-text" transform="rotate(45 430 135)"></text>
</g>
<!-- COLUMNWISE SECTION -->
<text x="325" y="230" class="title">Columnwise (32x1 blocks)</text>
<!-- Columnwise Data Tensor -->
<g id="colwise-tensor">
<text x="160" y="255" class="small-text">Data</text>
<rect x="40" y="270" width="10" height="40" class="fp8-block"/>
<rect x="50" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="60" y="270" width="10" height="40" class="fp8-block"/>
<rect x="70" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="80" y="270" width="10" height="40" class="fp8-block"/>
<rect x="90" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="100" y="270" width="10" height="40" class="fp8-block"/>
<rect x="110" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="120" y="270" width="10" height="40" class="fp8-block"/>
<rect x="130" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="180" y="270" width="10" height="40" class="fp8-block"/>
<rect x="190" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="200" y="270" width="10" height="40" class="fp8-block"/>
<rect x="210" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="220" y="270" width="10" height="40" class="fp8-block"/>
<rect x="230" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="240" y="270" width="10" height="40" class="fp8-block"/>
<rect x="250" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="260" y="270" width="10" height="40" class="fp8-block"/>
<rect x="270" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="40" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="50" y="350" width="10" height="40" class="fp8-block"/>
<rect x="60" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="70" y="350" width="10" height="40" class="fp8-block"/>
<rect x="80" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="90" y="350" width="10" height="40" class="fp8-block"/>
<rect x="100" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="110" y="350" width="10" height="40" class="fp8-block"/>
<rect x="120" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="130" y="350" width="10" height="40" class="fp8-block"/>
<rect x="180" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="190" y="350" width="10" height="40" class="fp8-block"/>
<rect x="200" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="210" y="350" width="10" height="40" class="fp8-block"/>
<rect x="220" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="230" y="350" width="10" height="40" class="fp8-block"/>
<rect x="240" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="250" y="350" width="10" height="40" class="fp8-block"/>
<rect x="260" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="270" y="350" width="10" height="40" class="fp8-block"/>
<rect x="40" y="270" width="240" height="120" fill="none" stroke="#444" stroke-width="2"/>
<text x="160" y="296" class="dots-text"></text>
<text x="160" y="376" class="dots-text"></text>
<text x="90" y="336" class="dots-text" transform="rotate(90 90 336)"></text>
<text x="230" y="336" class="dots-text" transform="rotate(90 230 336)"></text>
<text x="160" y="336" class="dots-text" transform="rotate(45 160 336)"></text>
</g>
<!-- Columnwise Scale Tensor - TRANSPOSED -->
<g id="colwise-scales">
<text x="485" y="255" class="small-text">Scales</text>
<!-- Row 1 -->
<rect x="370" y="300" width="10" height="10" class="scale-block"/>
<rect x="380" y="300" width="10" height="10" class="scale-block"/>
<rect x="390" y="300" width="10" height="10" class="scale-block"/>
<rect x="400" y="300" width="10" height="10" class="scale-block"/>
<rect x="410" y="300" width="10" height="10" class="scale-block"/>
<rect x="420" y="300" width="10" height="10" class="scale-block"/>
<rect x="430" y="300" width="10" height="10" class="scale-block"/>
<rect x="440" y="300" width="10" height="10" class="scale-block"/>
<rect x="450" y="300" width="10" height="10" class="scale-block"/>
<rect x="460" y="300" width="10" height="10" class="scale-block"/>
<rect x="510" y="300" width="10" height="10" class="scale-block"/>
<rect x="520" y="300" width="10" height="10" class="scale-block"/>
<rect x="530" y="300" width="10" height="10" class="scale-block"/>
<rect x="540" y="300" width="10" height="10" class="scale-block"/>
<rect x="550" y="300" width="10" height="10" class="scale-block"/>
<rect x="560" y="300" width="10" height="10" class="scale-block"/>
<rect x="570" y="300" width="10" height="10" class="scale-block"/>
<rect x="580" y="300" width="10" height="10" class="scale-block"/>
<rect x="590" y="300" width="10" height="10" class="scale-block"/>
<rect x="600" y="300" width="10" height="10" class="scale-block"/>
<!-- Row 2 (gap) -->
<rect x="370" y="330" width="10" height="10" class="scale-block"/>
<rect x="380" y="330" width="10" height="10" class="scale-block"/>
<rect x="390" y="330" width="10" height="10" class="scale-block"/>
<rect x="400" y="330" width="10" height="10" class="scale-block"/>
<rect x="410" y="330" width="10" height="10" class="scale-block"/>
<rect x="420" y="330" width="10" height="10" class="scale-block"/>
<rect x="430" y="330" width="10" height="10" class="scale-block"/>
<rect x="440" y="330" width="10" height="10" class="scale-block"/>
<rect x="450" y="330" width="10" height="10" class="scale-block"/>
<rect x="460" y="330" width="10" height="10" class="scale-block"/>
<rect x="510" y="330" width="10" height="10" class="scale-block"/>
<rect x="520" y="330" width="10" height="10" class="scale-block"/>
<rect x="530" y="330" width="10" height="10" class="scale-block"/>
<rect x="540" y="330" width="10" height="10" class="scale-block"/>
<rect x="550" y="330" width="10" height="10" class="scale-block"/>
<rect x="560" y="330" width="10" height="10" class="scale-block"/>
<rect x="570" y="330" width="10" height="10" class="scale-block"/>
<rect x="580" y="330" width="10" height="10" class="scale-block"/>
<rect x="590" y="330" width="10" height="10" class="scale-block"/>
<rect x="600" y="330" width="10" height="10" class="scale-block"/>
<rect x="370" y="300" width="240" height="40" fill="none" stroke="#444" stroke-width="2"/>
<text x="490" y="320" class="dots-text"></text>
<text x="430" y="320" class="dots-text"></text>
<text x="560" y="320" class="dots-text"></text>
</g>
</svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 -15 900 445" width="100%" style="max-width: 900px;">
<style>
.scale-fill { fill: #FFA500; stroke: #444; stroke-width: 2; }
.scale-fill-nostroke { fill: #FFA500; stroke: none; }
.grid-line { stroke: #444; stroke-width: 2; fill: none; }
.dots-text { font: bold 24px sans-serif; fill: #333; text-anchor: middle; }
.label { font: 12px sans-serif; fill: #333; text-anchor: middle; }
.num-text { font: bold 12px sans-serif; fill: #333; text-anchor: middle; dominant-baseline: middle; }
.arrow-line { stroke: #444; stroke-width: 2; fill: none; }
.outer-border { fill: none; stroke: #444; stroke-width: 1.5; }
.inner-line { stroke: #444; stroke-width: 1; stroke-dasharray: 3,2; }
.num-text-small { font: bold 11px sans-serif; fill: #333; text-anchor: middle; dominant-baseline: middle; }
.dots-text-small { font: bold 14px sans-serif; fill: #333; text-anchor: middle; dominant-baseline: middle; }
</style>
<defs>
<marker id="arrowhead" markerWidth="10" markerHeight="10" refX="8" refY="5" orient="auto">
<polygon points="0 0, 10 5, 0 10" fill="#444" />
</marker>
</defs>
<!-- ======== PART 1: Linearization (from mxfp8_scale_linearize.svg) ======== -->
<g id="linearization">
<!-- Left: Scaling factors grid -->
<!-- Main rectangle with white background -->
<rect x="40" y="40" width="120" height="240" fill="#FFFFFF" stroke="#444" stroke-width="2"/>
<!-- Orange grid area (upper-left) -->
<rect x="40" y="40" width="60" height="180" fill="#FFA500" stroke="#444" stroke-width="2"/>
<!-- Grid lines for narrow blocks (3 columns x 3 rows) -->
<line x1="60" y1="40" x2="60" y2="220" class="grid-line"/>
<line x1="80" y1="40" x2="80" y2="220" class="grid-line"/>
<line x1="40" y1="100" x2="100" y2="100" class="grid-line"/>
<line x1="40" y1="160" x2="100" y2="160" class="grid-line"/>
<!-- Numbers in orange cells -->
<!-- Row 1 -->
<text x="50" y="70" class="num-text">1</text>
<text x="70" y="70" class="num-text">2</text>
<text x="90" y="70" class="num-text">3</text>
<!-- Row 2 -->
<text x="50" y="125" class="num-text">K</text>
<text x="50" y="137" class="num-text">+</text>
<text x="50" y="149" class="num-text">1</text>
<text x="70" y="125" class="num-text">K</text>
<text x="70" y="137" class="num-text">+</text>
<text x="70" y="149" class="num-text">2</text>
<text x="90" y="125" class="num-text">K</text>
<text x="90" y="137" class="num-text">+</text>
<text x="90" y="149" class="num-text">3</text>
<!-- Row 3 -->
<text x="50" y="180" class="num-text">2K</text>
<text x="50" y="192" class="num-text">+</text>
<text x="50" y="204" class="num-text">1</text>
<text x="70" y="180" class="num-text">2K</text>
<text x="70" y="192" class="num-text">+</text>
<text x="70" y="204" class="num-text">1</text>
<text x="90" y="180" class="num-text">2K</text>
<text x="90" y="192" class="num-text">+</text>
<text x="90" y="204" class="num-text">3</text>
<!-- Dots in white area (right side) -->
<text x="125" y="90" class="dots-text"></text>
<text x="125" y="150" class="dots-text"></text>
<text x="125" y="210" class="dots-text"></text>
<!-- Dots in white area (bottom) -->
<text x="60" y="260" class="dots-text"></text>
<text x="90" y="260" class="dots-text"></text>
<text x="125" y="260" class="dots-text"></text>
<!-- Arrow pointing to first block with label 128x4 -->
<text x="50" y="0" class="label" text-anchor="middle">128x4</text>
<path d="M 50 5 L 50 38" class="arrow-line" marker-end="url(#arrowhead)"/>
<!-- Arrow -->
<path d="M 200 150 L 300 150" class="arrow-line" marker-end="url(#arrowhead)"/>
<!-- Right: Linearized 1D array -->
<!-- Main rectangle with white background -->
<rect x="340" y="140" width="520" height="20" fill="#FFFFFF" stroke="#444" stroke-width="2"/>
<!-- Orange cells -->
<rect x="340" y="140" width="70" height="20" fill="#FFA500" stroke="#444" stroke-width="2"/>
<text x="375" y="150" class="num-text">1</text>
<rect x="410" y="140" width="70" height="20" fill="#FFA500" stroke="#444" stroke-width="2"/>
<text x="445" y="150" class="num-text">2</text>
<rect x="480" y="140" width="70" height="20" fill="#FFA500" stroke="#444" stroke-width="2"/>
<text x="515" y="150" class="dots-text"></text>
<rect x="550" y="140" width="70" height="20" fill="#FFA500" stroke="#444" stroke-width="2"/>
<text x="585" y="150" class="num-text">K + 1</text>
<rect x="620" y="140" width="70" height="20" fill="#FFA500" stroke="#444" stroke-width="2"/>
<text x="655" y="150" class="num-text">K + 2</text>
<!-- White area with dots -->
<text x="725" y="150" class="dots-text"></text>
<!-- Arrow pointing to first linearized block with label 1x512 -->
<text x="375" y="75" class="label" text-anchor="middle">1x512</text>
<path d="M 375 80 L 375 138" class="arrow-line" marker-end="url(#arrowhead)"/>
</g>
<!-- ======== Connection: Arrow from "1" to bottom left block with brace ======== -->
<g id="connection">
<!-- Label above brace -->
<text x="250" y="335" class="label">128 4-bit elements</text>
<!-- Curly brace on top of the bottom left block -->
<path d="M 115 355
Q 115 345, 125 345
L 245 345
Q 250 345, 250 340
Q 250 345, 255 345
L 375 345
Q 385 345, 385 355"
fill="none" stroke="#444" stroke-width="2"/>
<!-- Arrow from "1" cell down to the center of the brace -->
<path d="M 375 175 Q 375 260, 250 315" class="arrow-line" marker-end="url(#arrowhead)"/>
</g>
<!-- ======== PART 2: Swizzling (from mxfp8_swizzle_indices.svg) ======== -->
<!-- Offset by 330 (300 + 30px gap) -->
<g id="swizzling" transform="translate(100, 330)">
<!-- Left: Sequential indices -->
<g id="sequential">
<!-- Background -->
<rect x="15" y="35" width="270" height="30" class="scale-fill-nostroke"/>
<rect x="15" y="35" width="270" height="30" class="outer-border"/>
<!-- Dashed internal lines -->
<line x1="45" y1="35" x2="45" y2="65" class="inner-line"/>
<line x1="75" y1="35" x2="75" y2="65" class="inner-line"/>
<line x1="105" y1="35" x2="105" y2="65" class="inner-line"/>
<line x1="135" y1="35" x2="135" y2="65" class="inner-line"/>
<line x1="165" y1="35" x2="165" y2="65" class="inner-line"/>
<line x1="195" y1="35" x2="195" y2="65" class="inner-line"/>
<line x1="225" y1="35" x2="225" y2="65" class="inner-line"/>
<line x1="255" y1="35" x2="255" y2="65" class="inner-line"/>
<!-- Numbers -->
<text x="30" y="50" class="num-text-small">0</text>
<text x="60" y="50" class="num-text-small">1</text>
<text x="90" y="50" class="num-text-small">2</text>
<text x="120" y="50" class="num-text-small">3</text>
<text x="150" y="50" class="num-text-small">4</text>
<text x="180" y="50" class="num-text-small">5</text>
<text x="210" y="50" class="num-text-small">6</text>
<text x="240" y="50" class="num-text-small">7</text>
<text x="270" y="50" class="dots-text-small">...</text>
</g>
<!-- Arrow -->
<path d="M 300 50 L 340 50" class="arrow-line" marker-end="url(#arrowhead)"/>
<!-- Right: Swizzled indices -->
<g id="swizzled">
<!-- Background -->
<rect x="360" y="35" width="270" height="30" class="scale-fill-nostroke"/>
<rect x="360" y="35" width="270" height="30" class="outer-border"/>
<!-- Dashed internal lines -->
<line x1="390" y1="35" x2="390" y2="65" class="inner-line"/>
<line x1="420" y1="35" x2="420" y2="65" class="inner-line"/>
<line x1="450" y1="35" x2="450" y2="65" class="inner-line"/>
<line x1="480" y1="35" x2="480" y2="65" class="inner-line"/>
<line x1="510" y1="35" x2="510" y2="65" class="inner-line"/>
<line x1="540" y1="35" x2="540" y2="65" class="inner-line"/>
<line x1="570" y1="35" x2="570" y2="65" class="inner-line"/>
<line x1="600" y1="35" x2="600" y2="65" class="inner-line"/>
<!-- Numbers -->
<text x="375" y="50" class="num-text-small">0</text>
<text x="405" y="50" class="num-text-small">32</text>
<text x="435" y="50" class="num-text-small">64</text>
<text x="465" y="50" class="num-text-small">96</text>
<text x="495" y="50" class="num-text-small">1</text>
<text x="525" y="50" class="num-text-small">33</text>
<text x="555" y="50" class="num-text-small">65</text>
<text x="585" y="50" class="num-text-small">97</text>
<text x="615" y="50" class="dots-text-small">...</text>
</g>
</g>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1230 220" width="100%" style="max-width: 900px;">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
/* Diagram-specific styles */
.input-box { fill: #f3e5f5; stroke: #7b1fa2; stroke-width: 2.5; }
.mxfp8-box { fill: #e3f2fd; stroke: #1976d2; stroke-width: 2.5; }
.fp8-tile { fill: #bbdefb; stroke: #1565c0; stroke-width: 1.5; }
.scale-tile { fill: #a5d6a7; stroke: #388e3c; stroke-width: 1.5; }
.scale-swizzled { fill: #ffb74d; stroke: #e65100; stroke-width: 1.5; }
.swizzle-box { fill: #fff3e0; stroke: #f57c00; stroke-width: 2; }
.quantize-box { fill: #ede7f6; stroke: #5e35b1; stroke-width: 2; }
.comm-box { fill: #fff9c4; stroke: #f57f17; stroke-width: 2; }
.gemm-box { fill: #c8e6c9; stroke: #388e3c; stroke-width: 2; }
/* Arrow override */
.arrow { marker-end: url(#arrowhead); }
</style>
<!-- Arrow marker -->
<marker id="arrowhead" markerWidth="6" markerHeight="6" refX="5" refY="2" orient="auto">
<polygon points="0 0, 6 2, 0 4" fill="#616161" />
</marker>
</defs>
<!-- MXFP8 Complete Flow -->
<g id="complete-flow">
<!-- Step 0: Input Tensor -->
<g id="input-fp32-tensor">
<text x="80" y="25" class="text" text-anchor="middle" font-weight="600">Input Tensor</text>
<rect x="20" y="40" width="120" height="150" rx="6" class="input-box"/>
<text x="80" y="120" class="text" text-anchor="middle" fill="#fff" font-weight="600">FP32/BF16</text>
</g>
<!-- Arrow 0 -->
<path d="M 140 115 L 180 115" class="arrow"/>
<!-- Step 1: Quantize -->
<rect x="180" y="75" width="90" height="80" rx="6" class="quantize-box"/>
<text x="225" y="120" class="text" font-weight="600">Quantize</text>
<!-- Arrow 1 -->
<path d="M 270 115 L 310 115" class="arrow"/>
<!-- Step 2: MXFP8 Tensor with sub-tiles stacked vertically -->
<g id="mxfp8-tensor">
<text x="410" y="25" class="text" text-anchor="middle" font-weight="600">MXFP8 Tensor</text>
<rect x="310" y="40" width="200" height="150" rx="6" class="mxfp8-box"/>
<!-- Scales sub-tile (green) - on top -->
<rect x="330" y="55" width="160" height="40" rx="3" class="scale-tile"/>
<text x="410" y="80" class="text" text-anchor="middle" fill="#fff" font-weight="600">Scales</text>
<!-- FP8 Data sub-tile - on bottom -->
<rect x="330" y="105" width="160" height="70" rx="3" class="fp8-tile"/>
<text x="410" y="145" class="text" fill="#fff" font-weight="600">FP8 Data</text>
</g>
<!-- Arrow 2 -->
<path d="M 510 115 L 560 115" class="arrow"/>
<!-- Step 3: Communication -->
<rect x="560" y="75" width="130" height="80" rx="6" class="comm-box"/>
<text x="625" y="110" class="text" font-weight="600">Communication</text>
<text x="625" y="125" class="text" font-size="12">(All-Gather)</text>
<text x="625" y="140" class="text" font-size="12" font-style="italic">(Optional)</text>
<!-- Arrow 3 -->
<path d="M 690 115 L 740 115" class="arrow"/>
<!-- Step 4: Swizzle -->
<rect x="740" y="75" width="110" height="80" rx="6" class="swizzle-box"/>
<text x="795" y="120" class="text" font-weight="600">Swizzle</text>
<!-- Arrow 4 -->
<path d="M 850 115 L 900 115" class="arrow"/>
<!-- Step 5: MXFP8 Tensor with swizzled scales -->
<g id="swizzled-tensor">
<text x="980" y="25" class="text" text-anchor="middle" font-weight="600">MXFP8 Tensor</text>
<rect x="900" y="40" width="160" height="150" rx="6" class="mxfp8-box"/>
<!-- Swizzled Scales sub-tile (orange) - on top -->
<rect x="915" y="55" width="130" height="40" rx="3" class="scale-swizzled"/>
<text x="980" y="80" class="text" text-anchor="middle" fill="#fff" font-weight="600">Swizzle Scales</text>
<!-- FP8 Data sub-tile (unchanged) - on bottom -->
<rect x="915" y="105" width="130" height="70" rx="3" class="fp8-tile"/>
<text x="980" y="145" class="text" fill="#fff" font-weight="600">FP8 Data</text>
</g>
<!-- Arrow 5 -->
<path d="M 1060 115 L 1110 115" class="arrow"/>
<!-- Step 6: GEMM -->
<rect x="1110" y="75" width="110" height="80" rx="6" class="gemm-box"/>
<text x="1165" y="120" class="text" font-weight="600">GEMM</text>
</g>
</svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 700 300" width="100%" style="max-width: 700px;">
<style>
.tensor-fill { fill: #87CEEB; stroke: #444; stroke-width: 2; }
.scale-fill { fill: #FFA500; stroke: #444; stroke-width: 2; }
.grid-line { stroke: #444; stroke-width: 2; fill: none; }
.dots-text { font: bold 24px sans-serif; fill: #333; text-anchor: middle; }
.label { font: 12px sans-serif; fill: #333; text-anchor: middle; }
</style>
<!-- Left tensor (128x128 blocks) - FP8 tensor -->
<!-- Main rectangle with white background -->
<rect x="60" y="40" width="260" height="240" fill="#FFFFFF" stroke="#444" stroke-width="2"/>
<!-- Blue grid area (upper-left) -->
<rect x="60" y="40" width="180" height="180" fill="#87CEEB" stroke="#444" stroke-width="2"/>
<!-- Grid lines for 3x3 blocks -->
<line x1="120" y1="40" x2="120" y2="220" class="grid-line"/>
<line x1="180" y1="40" x2="180" y2="220" class="grid-line"/>
<line x1="60" y1="100" x2="240" y2="100" class="grid-line"/>
<line x1="60" y1="160" x2="240" y2="160" class="grid-line"/>
<!-- Dots in white area (right side) -->
<text x="280" y="90" class="dots-text"></text>
<text x="280" y="150" class="dots-text"></text>
<text x="280" y="210" class="dots-text"></text>
<!-- Dots in white area (bottom) -->
<text x="90" y="260" class="dots-text"></text>
<text x="150" y="260" class="dots-text"></text>
<text x="210" y="260" class="dots-text"></text>
<text x="280" y="260" class="dots-text"></text>
<!-- Label -->
<text x="190" y="20" class="label">FP8 Tensor (128×128 blocks)</text>
<!-- Right tensor (128x4 blocks) - Scaling factors (orange) -->
<!-- Main rectangle with white background -->
<rect x="480" y="40" width="120" height="240" fill="#FFFFFF" stroke="#444" stroke-width="2"/>
<!-- Orange grid area (upper-left) -->
<rect x="480" y="40" width="60" height="180" fill="#FFA500" stroke="#444" stroke-width="2"/>
<!-- Grid lines for narrow blocks (3 columns x 3 rows) -->
<line x1="500" y1="40" x2="500" y2="220" class="grid-line"/>
<line x1="520" y1="40" x2="520" y2="220" class="grid-line"/>
<line x1="480" y1="100" x2="540" y2="100" class="grid-line"/>
<line x1="480" y1="160" x2="540" y2="160" class="grid-line"/>
<!-- Dots in white area (right side) -->
<text x="565" y="90" class="dots-text"></text>
<text x="565" y="150" class="dots-text"></text>
<text x="565" y="210" class="dots-text"></text>
<!-- Dots in white area (bottom) -->
<text x="500" y="260" class="dots-text"></text>
<text x="530" y="260" class="dots-text"></text>
<text x="565" y="260" class="dots-text"></text>
<!-- Label -->
<text x="540" y="20" class="label">Scaling Factors (128×4 blocks)</text>
</svg>
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Check for Blackwell or newer GPU
from transformer_engine.jax.quantize import get_device_compute_capability
assert (
get_device_compute_capability() >= 100
), f"MXFP8 requires SM100 (Blackwell) or later, got SM{get_device_compute_capability()}"
# START_MXFP8_EXAMPLE
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
from transformer_engine.jax.flax import DenseGeneral
from transformer_engine.common.recipe import MXFP8BlockScaling, Format
# Create MXFP8 recipe
recipe = MXFP8BlockScaling(
fp8_format=Format.E4M3, # FP8 format (default: E4M3, E5M2 not supported)
)
with te.autocast(enabled=True, recipe=recipe):
# Initialize layer and data
layer = DenseGeneral(features=1024)
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16)
var_collect = layer.init(key, x)
# Forward and backward pass
def loss_fn(var_collect):
output = layer.apply(var_collect, x)
return output.sum()
loss, grads = jax.value_and_grad(loss_fn)(var_collect)
# END_MXFP8_EXAMPLE
..
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
MXFP8
=====
MXFP8 (Microscaling FP8) is an enhanced FP8 blockwise scaling recipe that leverages native hardware
acceleration on Blackwell GPUs (SM 10.0+). By using one scaling factor per 32 consecutive values
(rather than 128), MXFP8 delivers finer-grained quantization with improved numerical precision.
Data Format
-----------
The representation of an FP8 tensor element ``x`` in MXFP8 precision is given by:
.. code-block:: python
x = x_fp8 * s_block
where
* ``x_fp8`` is the FP8 value in E4M3 format,
* ``s_block`` is a local **E8M0** scaling factor shared by a block of 32 elements.
E8M0 is an 8-bit format with 8 exponent bits and 0 mantissa bits, representing only powers of 2.
**FP8 format**
Like FP8 Blockwise Scaling, E4M3 is used by default for both forward and backward passes.
The finer-grained scaling provides sufficient dynamic range without requiring the E5M2 format.
The ``fp8_format`` parameter also supports ``HYBRID`` mode (E4M3 for forward, E5M2 for backward).
Pure E5M2 training is not supported.
**Block size**
Block size is 32.
Blocks are one-dimensional, containing 32 consecutive values. No 2D scaling is performed.
There are some assumptions on the dimensions of the tensor:
* the tensor must have at least 2 dimensions,
* the last dimension must be divisible by 32,
* the product of all dimensions except the last must be divisible by 32.
**Scaling factors**
Scaling factors are stored as E8M0 (8 exponent bits, 0 mantissa bits), which inherently represents
powers of 2. This differs from FP8 Blockwise Scaling, which uses 32-bit floating point numbers
optionally constrained to powers of 2. Note that FP32 also has 8 exponent bits, so the representable
ranges are the same when the power-of-2 constraint is enabled.
Each block's scaling factor is computed through the following steps:
1. Find the maximum absolute value (``amax_block``) across all 32 elements in the block.
2. Compute the E8M0 biased exponent: ``e = float_to_e8m0(amax_block / max_fp8)``, where ``max_fp8 = 448``
(the maximum representable value in E4M3 format).
Since E8M0 and FP32 share the same exponent bias (127), ``float_to_e8m0`` simply extracts
the 8-bit exponent from the FP32 representation, rounding up if the mantissa is non-zero.
3. The scaling factor is ``s_block = 2^(e - 127)``.
This ensures that the largest value in each block fits within the FP8 representable range without overflow.
.. raw:: html
:file: img/fp8_1d_scaling.svg
*Figure 1. MXFP8 uses one E8M0 scaling factor per 32 consecutive elements, providing fine-grained
quantization and compact scaling factor representation.*
Handling transposes
-------------------
Blackwell architecture supports multiple FP8 GEMM layouts (TN, NT, NN), so columnwise usage
does not require explicit transposition. However, rowwise and columnwise quantizations are different:
- *Rowwise* - 1 scaling factor per 32 consecutive elements along a row (1×32 blocks).
- *Columnwise* - 1 scaling factor per 32 consecutive elements along a column (32×1 blocks).
Since the scaling factor blocks have different orientations, rowwise and columnwise MXFP8 tensors
are numerically different — one cannot derive one from the other. Both must be quantized
independently from the full-precision data.
.. raw:: html
:file: img/mxfp8_row_col.svg
*Figure 2. MXFP8 rowwise vs columnwise quantization layout.*
Distributed training
--------------------
**Scale synchronization**
The blockwise scaled tensor does not need any scale synchronization among the nodes.
This is because each scaling factor is local to its 32-element block,
unlike :doc:`FP8 Current <../fp8_current_scaling/fp8_current_scaling>`/:doc:`Delayed Scaling <../fp8_delayed_scaling/fp8_delayed_scaling>` where a single global scale applies to the entire tensor, even when sharded.
**Quantized all-gather**
MXFP8 all-gather is supported.
Examples
--------
Here's how to use MXFP8 recipe in PyTorch and JAX:
.. tabs::
.. tab:: PyTorch
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Requires SM100 (Blackwell) or later
</div>
.. literalinclude:: pytorch_mxfp8_example.py
:language: python
:start-after: # START_MXFP8_EXAMPLE
:end-before: # END_MXFP8_EXAMPLE
.. tab:: JAX
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Requires SM100 (Blackwell) or later
</div>
.. literalinclude:: jax_mxfp8_example.py
:language: python
:start-after: # START_MXFP8_EXAMPLE
:end-before: # END_MXFP8_EXAMPLE
Supported devices
-----------------
SM 10.0, SM 10.3
----
Developer Notes
---------------
This section contains implementation details that may be useful for developers
but are not required for using MXFP8 in practice.
Swizzling scaling factors
^^^^^^^^^^^^^^^^^^^^^^^^^
Like :doc:`FP8 Blockwise Scaling <../fp8_blockwise_scaling/fp8_blockwise_scaling>`, MXFP8 uses different data layouts for communication and computation.
MXFP8 GEMMs require scaling factors in a specific hardware layout
(see `cuBLAS documentation <https://docs.nvidia.com/cuda/cublas/index.html#block-scaling-factors-layout>`__).
The conversion to this GEMM-ready layout is called *swizzling*. When no communication is needed,
swizzling can be fused with quantization. When communication is required, swizzled scaling factors
cannot be communicated across devices, so Transformer Engine performs swizzling after communication,
just before each GEMM operation.
.. raw:: html
:file: img/mxfp8_swizzle_both_tensors.svg
*Figure 3. MXFP8 swizzling process: standard scaling factors are rearranged into the hardware-required layout.*
Blackwell Tensor Cores compute matrix multiplications using ``128x128`` tiles.
Scaling factors are stored in row-major order, but to process a tile, we need a ``128x4`` vertical
slice of scaling factors. In row-major storage, these vertical slices are scattered in memory
with gaps between each row. The hardware requires them to be stored contiguously.
.. raw:: html
:file: img/mxfp8_tensor_scaling_layout.svg
*Figure 4. FP8 tensor (left) is divided into 128x128 tiles. Each tile requires a 128x4 block of scaling factors (right). These vertical blocks are not contiguous in memory.*
Swizzling transforms the layout to meet hardware requirements by:
1. **Linearizing** the ``128x4`` blocks so they are stored contiguously one after another.
2. **Permuting** the 4-byte elements within each block.
Specifically, if we index the 128 4-byte elements in a scaling factor block as :math:`0, 1, \dots, 127`, the hardware expects them in the following interleaved order:
.. code-block:: text
0, 32, 64, 96, 1, 33, 65, 97, ..., k, 32 + k, 64 + k, 96 + k, ..., 31, 63, 95, 127
.. raw:: html
:file: img/mxfp8_scale_linearize_and_swizzle.svg
*Figure 5. Linearization and swizzling of scaling factors. The 2D grid of scaling factors is first flattened into a contiguous sequence of blocks (top), then the rows within each block are interleaved to match the hardware access pattern (bottom).*
For columnwise scaling factors, the process is analogous but with ``4x128`` horizontal blocks instead of ``128x4`` vertical blocks.
All-gather of columnwise tensors
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
All-gather of columnwise tensors is supported and necessary because:
- columnwise quantized tensors cannot be computed from rowwise quantized ones,
- gathering high-precision tensors is avoided in most cases for performance reasons.
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment