Commit 9df0c4a3 authored by wenjh's avatar wenjh
Browse files

Merge branch 'nv_main'

parents 0d874a4e f122b07d
<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 900 220">
<!-- Diagram: FP8 current-scaling quantization pipeline.
     High Precision Tensor -> [Quantize: Compute amax -> Apply Scale + Cast] -> FP8 Tensor.
     Each sub-step reads the input tensor once (two reads total), per the
     "1 tensor read" labels below. Shared box/text styles (.hp, .fp8,
     .quantize, .amax, .text, .title, .small-text) come from diagram-colors.css. -->
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.arrow {
stroke: #616161;
stroke-width: 2;
fill: none;
marker-end: url(#arrowhead-cast);
}
</style>
<marker id="arrowhead-cast" markerWidth="10" markerHeight="10" refX="8" refY="3" orient="auto" markerUnits="strokeWidth">
<polygon points="0 0, 10 3, 0 6" fill="#616161" />
</marker>
</defs>
<!-- Title -->
<text x="450" y="30" class="title" text-anchor="middle">FP8 quantization </text>
<!-- Step 1: High Precision Tensor -->
<rect x="80" y="80" width="140" height="70" class="hp" rx="6"/>
<text x="150" y="110" class="text" text-anchor="middle">High Precision</text>
<text x="150" y="130" class="text" text-anchor="middle">Tensor</text>
<!-- Arrow 1 -->
<path d="M 220 115 L 270 115" class="arrow"/>
<!-- Quantize container box (x 270..600, y 60..190); its label sits below the box -->
<rect x="270" y="60" width="330" height="130" class="quantize" rx="6"/>
<text x="435" y="205" class="text" style="font-weight: 600; font-size: 14px;" text-anchor="middle">Quantize</text>
<!-- Step 2: Compute Amax (sub-box) -->
<rect x="280" y="95" width="140" height="50" class="amax" rx="4"/>
<text x="350" y="118" class="text" style="font-weight: 600;" text-anchor="middle">Compute amax</text>
<text x="350" y="160" class="small-text" text-anchor="middle">1 tensor read</text>
<!-- Arrow 2 (inside quantize box) -->
<path d="M 420 120 L 450 120" class="arrow"/>
<!-- Step 3: Apply Scale + Cast (sub-box) -->
<rect x="450" y="95" width="140" height="50" class="quantize" rx="4"/>
<text x="520" y="115" class="text" style="font-weight: 600;" text-anchor="middle">Apply Scale</text>
<text x="520" y="130" class="text" style="font-weight: 600;" text-anchor="middle">+ Cast</text>
<text x="520" y="160" class="small-text" text-anchor="middle">1 tensor read</text>
<!-- Arrow 3 -->
<path d="M 600 115 L 650 115" class="arrow"/>
<!-- Step 4: FP8 Tensor -->
<rect x="650" y="80" width="140" height="70" class="fp8" rx="6"/>
<text x="720" y="110" class="text" text-anchor="middle">FP8</text>
<text x="720" y="130" class="text" text-anchor="middle">Tensor</text>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 950 170" width="950" height="170">
<!-- Diagram: quantization + all-gather pipeline for FP8 current scaling.
     High Precision Tensor -> Compute Amax -> Synchronize Amax -> Scale + Cast
     -> FP8 Tensor -> All-Gather -> FP8 Gathered Tensor.
     Shared styles come from diagram-colors.css; .allgather carries an inline
     fallback in case the shared CSS does not load. -->
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
/* Arrows */
.arrow { stroke: #616161; stroke-width: 2; fill: none; marker-end: url(#arrowhead-ag); }
/* All-gather operations - fallback if CSS doesn't load */
.allgather {
fill: #e1f5fe;
stroke: #039be5;
stroke-width: 2;
}
</style>
<marker id="arrowhead-ag" markerWidth="6" markerHeight="6" refX="5" refY="2" orient="auto">
<polygon points="0 0, 6 2, 0 4" fill="#616161" />
</marker>
</defs>
<!-- Title -->
<text x="475" y="30" class="title">Quantization + all gather for FP8 current scaling</text>
<!-- High Precision Tensor -->
<rect x="30" y="80" width="110" height="55" class="hp" rx="6"/>
<text x="85" y="103" class="text">High Precision</text>
<text x="85" y="120" class="text">Tensor</text>
<!-- Arrow -->
<path d="M 140 107 L 165 107" class="arrow"/>
<!-- Compute Amax -->
<rect x="165" y="80" width="100" height="55" class="amax" rx="6"/>
<text x="215" y="103" class="text">Compute</text>
<text x="215" y="120" class="text">Amax</text>
<!-- Arrow -->
<path d="M 265 107 L 290 107" class="arrow"/>
<!-- Synchronize Amax -->
<rect x="290" y="80" width="100" height="55" class="amax" rx="6"/>
<text x="340" y="103" class="text">Synchronize</text>
<text x="340" y="120" class="text">Amax</text>
<!-- Arrow -->
<path d="M 390 107 L 415 107" class="arrow"/>
<!-- Scale + Cast -->
<rect x="415" y="80" width="100" height="55" class="quantize" rx="6"/>
<text x="465" y="103" class="text">Scale +</text>
<text x="465" y="120" class="text">Cast</text>
<!-- Arrow -->
<path d="M 515 107 L 540 107" class="arrow"/>
<!-- FP8 Tensor (intermediate) -->
<rect x="540" y="80" width="100" height="55" class="fp8" rx="6"/>
<text x="590" y="103" class="text">FP8</text>
<text x="590" y="120" class="text">Tensor</text>
<!-- Arrow -->
<path d="M 640 107 L 665 107" class="arrow"/>
<!-- All-Gather -->
<rect x="665" y="80" width="100" height="55" class="allgather" rx="6"/>
<text x="715" y="112" class="text">All-Gather</text>
<!-- Arrow -->
<path d="M 765 107 L 790 107" class="arrow"/>
<!-- FP8 Gathered Tensor -->
<rect x="790" y="80" width="130" height="55" class="fp8" rx="6"/>
<text x="855" y="103" class="text">FP8 Gathered</text>
<text x="855" y="120" class="text">Tensor</text>
</svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 900 280">
<!-- Diagram: bit-layout comparison of FP16, BF16, FP8 E4M3 and FP8 E5M2
     encoding approximately the same value (~0.4). All four rows encode a
     biased exponent of -2 (verified below), so the rows differ only in
     mantissa precision. -->
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.sign-bit { fill: #9db4d0; stroke: #333; stroke-width: 1; }
.exponent-bit { fill: #d9a066; stroke: #333; stroke-width: 1; }
.mantissa-bit { fill: #a8d99c; stroke: #333; stroke-width: 1; }
.bit-text { fill: #000; text-anchor: middle; dominant-baseline: middle; font-size: 16px; }
.header-text { fill: #555; font-weight: normal; text-anchor: middle; font-size: 18px; }
.value-text { fill: #333; font-size: 18px; }
.format-label { fill: #333; font-weight: bold; text-anchor: middle; dominant-baseline: middle; font-size: 20px; }
</style>
</defs>
<!-- Header labels - centered -->
<text x="149" y="18" class="header-text">sign</text>
<text x="220" y="18" class="header-text">exponent</text>
<text x="420" y="18" class="header-text">mantissa</text>
<!-- FP16 Format (16 bits: 1 + 5 + 10) -->
<!-- decode: exp 01101 = 13, bias 15 => 2^-2; mantissa 1001010011 => 1.5810546875; value = 0.395264 (rounded) -->
<text x="60" y="60" class="format-label">FP16</text>
<!-- Sign bit (1) -->
<rect x="140" y="45" width="18" height="30" class="sign-bit"/>
<text x="149" y="60" class="bit-text">0</text>
<!-- Exponent bits (5) -->
<rect x="163" y="45" width="18" height="30" class="exponent-bit"/>
<text x="172" y="60" class="bit-text">0</text>
<rect x="186" y="45" width="18" height="30" class="exponent-bit"/>
<text x="195" y="60" class="bit-text">1</text>
<rect x="209" y="45" width="18" height="30" class="exponent-bit"/>
<text x="218" y="60" class="bit-text">1</text>
<rect x="232" y="45" width="18" height="30" class="exponent-bit"/>
<text x="241" y="60" class="bit-text">0</text>
<rect x="255" y="45" width="18" height="30" class="exponent-bit"/>
<text x="264" y="60" class="bit-text">1</text>
<!-- Mantissa bits (10) -->
<rect x="278" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="287" y="60" class="bit-text">1</text>
<rect x="301" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="310" y="60" class="bit-text">0</text>
<rect x="324" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="333" y="60" class="bit-text">0</text>
<rect x="347" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="356" y="60" class="bit-text">1</text>
<rect x="370" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="379" y="60" class="bit-text">0</text>
<rect x="393" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="402" y="60" class="bit-text">1</text>
<rect x="416" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="425" y="60" class="bit-text">0</text>
<rect x="439" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="448" y="60" class="bit-text">0</text>
<rect x="462" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="471" y="60" class="bit-text">1</text>
<rect x="485" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="494" y="60" class="bit-text">1</text>
<text x="540" y="60" class="value-text">= 0.395264</text>
<!-- BF16 Format (16 bits: 1 + 8 + 7) -->
<!-- decode: exp 01111101 = 125, bias 127 => 2^-2; mantissa 1001010 => 1.578125; value = 0.39453125 -->
<text x="60" y="120" class="format-label">BF16</text>
<!-- Sign bit (1) -->
<rect x="140" y="105" width="18" height="30" class="sign-bit"/>
<text x="149" y="120" class="bit-text">0</text>
<!-- Exponent bits (8) -->
<rect x="163" y="105" width="18" height="30" class="exponent-bit"/>
<text x="172" y="120" class="bit-text">0</text>
<rect x="186" y="105" width="18" height="30" class="exponent-bit"/>
<text x="195" y="120" class="bit-text">1</text>
<rect x="209" y="105" width="18" height="30" class="exponent-bit"/>
<text x="218" y="120" class="bit-text">1</text>
<rect x="232" y="105" width="18" height="30" class="exponent-bit"/>
<text x="241" y="120" class="bit-text">1</text>
<rect x="255" y="105" width="18" height="30" class="exponent-bit"/>
<text x="264" y="120" class="bit-text">1</text>
<rect x="278" y="105" width="18" height="30" class="exponent-bit"/>
<text x="287" y="120" class="bit-text">1</text>
<rect x="301" y="105" width="18" height="30" class="exponent-bit"/>
<text x="310" y="120" class="bit-text">0</text>
<rect x="324" y="105" width="18" height="30" class="exponent-bit"/>
<text x="333" y="120" class="bit-text">1</text>
<!-- Mantissa bits (7) -->
<rect x="347" y="105" width="18" height="30" class="mantissa-bit"/>
<text x="356" y="120" class="bit-text">1</text>
<rect x="370" y="105" width="18" height="30" class="mantissa-bit"/>
<text x="379" y="120" class="bit-text">0</text>
<rect x="393" y="105" width="18" height="30" class="mantissa-bit"/>
<text x="402" y="120" class="bit-text">0</text>
<rect x="416" y="105" width="18" height="30" class="mantissa-bit"/>
<text x="425" y="120" class="bit-text">1</text>
<rect x="439" y="105" width="18" height="30" class="mantissa-bit"/>
<text x="448" y="120" class="bit-text">0</text>
<rect x="462" y="105" width="18" height="30" class="mantissa-bit"/>
<text x="471" y="120" class="bit-text">1</text>
<rect x="485" y="105" width="18" height="30" class="mantissa-bit"/>
<text x="494" y="120" class="bit-text">0</text>
<text x="540" y="120" class="value-text">= 0.394531</text>
<!-- FP8 E4M3 Format (8 bits: 1 + 4 + 3) -->
<!-- decode: exp 0101 = 5, bias 7 => 2^-2; mantissa 101 => 1.625; value = 0.40625 -->
<text x="60" y="180" class="format-label">FP8 E4M3</text>
<!-- Sign bit (1) -->
<rect x="140" y="165" width="18" height="30" class="sign-bit"/>
<text x="149" y="180" class="bit-text">0</text>
<!-- Exponent bits (4) -->
<rect x="163" y="165" width="18" height="30" class="exponent-bit"/>
<text x="172" y="180" class="bit-text">0</text>
<rect x="186" y="165" width="18" height="30" class="exponent-bit"/>
<text x="195" y="180" class="bit-text">1</text>
<rect x="209" y="165" width="18" height="30" class="exponent-bit"/>
<text x="218" y="180" class="bit-text">0</text>
<rect x="232" y="165" width="18" height="30" class="exponent-bit"/>
<text x="241" y="180" class="bit-text">1</text>
<!-- Mantissa bits (3) -->
<rect x="255" y="165" width="18" height="30" class="mantissa-bit"/>
<text x="264" y="180" class="bit-text">1</text>
<rect x="278" y="165" width="18" height="30" class="mantissa-bit"/>
<text x="287" y="180" class="bit-text">0</text>
<rect x="301" y="165" width="18" height="30" class="mantissa-bit"/>
<text x="310" y="180" class="bit-text">1</text>
<text x="355" y="180" class="value-text">= 0.40625</text>
<!-- FP8 E5M2 Format (8 bits: 1 + 5 + 2) -->
<!-- decode: exp 01101 = 13, bias 15 => 2^-2; mantissa 10 => 1.5; value = 0.375 -->
<text x="60" y="240" class="format-label">FP8 E5M2</text>
<!-- Sign bit (1) -->
<rect x="140" y="225" width="18" height="30" class="sign-bit"/>
<text x="149" y="240" class="bit-text">0</text>
<!-- Exponent bits (5) -->
<rect x="163" y="225" width="18" height="30" class="exponent-bit"/>
<text x="172" y="240" class="bit-text">0</text>
<rect x="186" y="225" width="18" height="30" class="exponent-bit"/>
<text x="195" y="240" class="bit-text">1</text>
<rect x="209" y="225" width="18" height="30" class="exponent-bit"/>
<text x="218" y="240" class="bit-text">1</text>
<rect x="232" y="225" width="18" height="30" class="exponent-bit"/>
<text x="241" y="240" class="bit-text">0</text>
<rect x="255" y="225" width="18" height="30" class="exponent-bit"/>
<text x="264" y="240" class="bit-text">1</text>
<!-- Mantissa bits (2) -->
<rect x="278" y="225" width="18" height="30" class="mantissa-bit"/>
<text x="287" y="240" class="bit-text">1</text>
<rect x="301" y="225" width="18" height="30" class="mantissa-bit"/>
<text x="310" y="240" class="bit-text">0</text>
<text x="355" y="240" class="value-text">= 0.375</text>
</svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 900 380">
<!-- Diagram: three number lines illustrating per-tensor scaling.
     Top: original values with amax highlighted in red; middle: the same
     values scaled so that amax lands on the FP8 range max; bottom: values
     after casting to FP8, with nearby values merged to the same
     representable point to suggest FP8 granularity.
     NOTE(review): unlike the sibling diagrams, this file carries no XML
     declaration; confirm that is intentional for raw HTML inclusion. -->
<style>
@import url("../_static/css/diagram-colors.css");
.axis-line { stroke: #333; stroke-width: 2.5; }
.value-dot { fill: #2196f3; stroke: #1976d2; stroke-width: 1; }
.arrow { fill: #4caf50; }
.arrow-line { stroke: #4caf50; stroke-width: 3; }
.range-label { font-size: 14px; fill: #555; font-weight: 500; }
</style>
<!-- Top: Original values (before scaling) -->
<text x="450" y="55" class="section-title" text-anchor="middle">Original Tensor Values</text>
<!-- Top axis -->
<line x1="80" y1="85" x2="820" y2="85" class="axis-line"/>
<!-- Zero marker (center) -->
<line x1="450" y1="80" x2="450" y2="90" stroke="#333" stroke-width="2"/>
<text x="450" y="108" class="text" text-anchor="middle" font-size="12px">0</text>
<!-- Value dots (before scaling - irregular, not symmetric around zero) -->
<circle cx="118" cy="85" r="6" fill="#e53935" stroke="#c62828" stroke-width="2"/>
<circle cx="159" cy="85" r="5" class="value-dot"/>
<circle cx="167" cy="85" r="5" class="value-dot"/>
<circle cx="187" cy="85" r="5" class="value-dot"/>
<circle cx="199" cy="85" r="5" class="value-dot"/>
<circle cx="228" cy="85" r="5" class="value-dot"/>
<circle cx="326" cy="85" r="5" class="value-dot"/>
<circle cx="368" cy="85" r="5" class="value-dot"/>
<circle cx="442" cy="85" r="5" class="value-dot"/>
<circle cx="621" cy="85" r="5" class="value-dot"/>
<circle cx="649" cy="85" r="5" class="value-dot"/>
<circle cx="725" cy="85" r="5" class="value-dot"/>
<!-- amax label -->
<text x="118" y="70" class="text" fill="#e53935" font-weight="700" font-size="14px" text-anchor="middle">amax</text>
<!-- Original range bracket spanning all values -->
<line x1="118" y1="100" x2="118" y2="110" stroke="#666" stroke-width="1.5"/>
<line x1="118" y1="110" x2="725" y2="110" stroke="#666" stroke-width="1.5"/>
<line x1="725" y1="100" x2="725" y2="110" stroke="#666" stroke-width="1.5"/>
<text x="750" y="114" class="range-label" text-anchor="start">Original range</text>
<!-- Trapezoid showing compression from original range to FP8 range -->
<polygon points="118,115 725,115 650,165 250,165" fill="#e53935" opacity="0.2" stroke="#e53935" stroke-width="1.5"/>
<!-- Bottom: After scaling -->
<text x="450" y="190" class="section-title" text-anchor="middle">Scaled Values (fit FP8 range)</text>
<!-- Bottom axis -->
<line x1="80" y1="220" x2="820" y2="220" class="axis-line"/>
<!-- Zero marker (center) -->
<line x1="450" y1="215" x2="450" y2="225" stroke="#333" stroke-width="2"/>
<text x="450" y="238" class="text" text-anchor="middle" font-size="12px">0</text>
<!-- FP8 range bracket -->
<line x1="250" y1="245" x2="250" y2="255" stroke="#4caf50" stroke-width="1.5"/>
<line x1="250" y1="255" x2="650" y2="255" stroke="#4caf50" stroke-width="1.5"/>
<line x1="650" y1="245" x2="650" y2="255" stroke="#4caf50" stroke-width="1.5"/>
<text x="750" y="259" class="range-label" text-anchor="start" fill="#4caf50">FP8 range</text>
<!-- Value dots (after scaling - homogeneous scaling from zero, all fit into FP8 range) -->
<circle cx="250" cy="220" r="6" fill="#e53935" stroke="#c62828" stroke-width="2"/>
<text x="250" y="205" class="text" fill="#e53935" font-weight="700" font-size="12px" text-anchor="middle">- FP8 range max</text>
<circle cx="275" cy="220" r="5" class="value-dot"/>
<circle cx="280" cy="220" r="5" class="value-dot"/>
<circle cx="292" cy="220" r="5" class="value-dot"/>
<circle cx="299" cy="220" r="5" class="value-dot"/>
<circle cx="316" cy="220" r="5" class="value-dot"/>
<circle cx="375" cy="220" r="5" class="value-dot"/>
<circle cx="401" cy="220" r="5" class="value-dot"/>
<circle cx="445" cy="220" r="5" class="value-dot"/>
<circle cx="553" cy="220" r="5" class="value-dot"/>
<circle cx="569" cy="220" r="5" class="value-dot"/>
<circle cx="615" cy="220" r="5" class="value-dot"/>
<!-- Third line: After cast to FP8 (quantized values) -->
<text x="450" y="290" class="section-title" text-anchor="middle">Cast to FP8 (quantized values)</text>
<!-- Third axis -->
<line x1="80" y1="320" x2="820" y2="320" class="axis-line"/>
<!-- Zero marker (center) -->
<line x1="450" y1="315" x2="450" y2="325" stroke="#333" stroke-width="2"/>
<text x="450" y="338" class="text" text-anchor="middle" font-size="12px">0</text>
<!-- FP8 range bracket -->
<line x1="250" y1="345" x2="250" y2="355" stroke="#4caf50" stroke-width="1.5"/>
<line x1="250" y1="355" x2="650" y2="355" stroke="#4caf50" stroke-width="1.5"/>
<line x1="650" y1="345" x2="650" y2="355" stroke="#4caf50" stroke-width="1.5"/>
<text x="750" y="359" class="range-label" text-anchor="start" fill="#4caf50">FP8 range</text>
<!-- Quantized dots - merged close values to show FP8 granularity -->
<circle cx="250" cy="320" r="6" fill="#e53935" stroke="#c62828" stroke-width="2"/>
<!-- merged: 275+280 -->
<circle cx="278" cy="317" r="4.5" class="value-dot"/>
<circle cx="278" cy="323" r="4.5" class="value-dot"/>
<!-- merged: 292+299 -->
<circle cx="296" cy="317" r="4.5" class="value-dot"/>
<circle cx="296" cy="323" r="4.5" class="value-dot"/>
<circle cx="318" cy="320" r="5" class="value-dot"/>
<circle cx="378" cy="320" r="5" class="value-dot"/>
<circle cx="404" cy="320" r="5" class="value-dot"/>
<circle cx="450" cy="320" r="5" class="value-dot"/>
<!-- merged: 553+569 -->
<circle cx="562" cy="317" r="4.5" class="value-dot"/>
<circle cx="562" cy="323" r="4.5" class="value-dot"/>
<circle cx="615" cy="320" r="5" class="value-dot"/>
</svg>
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Doc example: FP8 current scaling with JAX.

The region between the START/END markers is pulled into the documentation
via ``literalinclude``.
"""
# START_CURRENT_SCALING_EXAMPLE
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
from transformer_engine.jax.flax import DenseGeneral
from transformer_engine.common.recipe import Float8CurrentScaling, Format

# Create FP8 Current Scaling recipe
# Available formats:
# - Format.HYBRID (default) -- E4M3 for forward pass, E5M2 for backward pass
# - Format.E4M3 -- E4M3 for both forward and backward pass
recipe = Float8CurrentScaling(fp8_format=Format.HYBRID)

with te.autocast(enabled=True, recipe=recipe):
    # Create and initialize layer
    layer = DenseGeneral(features=1024)
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16)
    var_collect = layer.init(key, x)

    # Forward and backward pass
    def loss_fn(var_collect):
        output = layer.apply(var_collect, x)
        return output.sum()

    loss, grads = jax.value_and_grad(loss_fn)(var_collect)
# END_CURRENT_SCALING_EXAMPLE
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Doc example: FP8 current scaling with PyTorch.

The region between the START/END markers is pulled into the documentation
via ``literalinclude``.
"""
# START_CURRENT_SCALING_EXAMPLE
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling, Format

# Create FP8 Current Scaling recipe
# Available formats:
# - Format.HYBRID (default) -- E4M3 for forward pass, E5M2 for backward pass
# - Format.E4M3 -- E4M3 for both forward and backward pass
recipe = Float8CurrentScaling(fp8_format=Format.HYBRID)

# Create a simple linear layer with bfloat16 parameters
layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)

# Forward and backward pass
inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda")
with te.autocast(enabled=True, recipe=recipe):
    output = layer(inp)
# Loss and backward run outside the autocast context in PyTorch.
loss = output.sum()
loss.backward()
# END_CURRENT_SCALING_EXAMPLE
..
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
FP8 Delayed Scaling
===================================
The FP8 Delayed Scaling recipe estimates scaling factors from historical amax values rather than computing them
for each tensor. Compared to the Current Scaling recipe,
this reduces the number of tensor reads per quantization from two to one,
improving memory-bandwidth efficiency.
Both this and :doc:`FP8 Current Scaling <../fp8_current_scaling/fp8_current_scaling>` recipe use
the same FP8 formats (E4M3/E5M2) with one FP32 scaling factor per tensor.
Reading the FP8 Current Scaling documentation first is recommended.
Quantization with delayed scaling factors
-----------------------------------------
FP8 Current Scaling requires two tensor reads per quantization: one to compute amax,
one to cast. FP8 Delayed Scaling eliminates the first read by predicting the scaling factor
from historical amax values - hence *delayed* (using past values) versus *current* (using present values).
The quantization process works as follows:
1. **Compute scaling factor from history** (no tensor read needed):
The scaling factor is derived from stored ``amax_history`` using the formula:
``scaling_factor = FP8_MAX / amax``
where ``amax`` is computed from history using either ``max`` (maximum over window, default) or ``most_recent`` algorithm.
2. **Quantize the tensor** (one tensor read):
Apply the scaling factor and cast to FP8. Values exceeding FP8 range are clipped.
3. **Update history**:
Record the actual amax from this quantization for future iterations.
Each module maintains an ``amax_history`` tensor of configurable length (``amax_history_len``)
for each quantized tensor.
.. raw:: html
:file: img/scaling_comparison.svg
*Figure 1. Comparison of FP8 Current Scaling and FP8 Delayed Scaling quantization processes.*
Amax History Management
-----------------------
The ``amax_history`` buffer acts as a sliding window of recent amax values.
Position 0 serves as a staging area for the current amax, while positions 1 to N-1
store the history from oldest to newest. Each quantization writes the observed amax
to position 0, and after the pass completes, the history is rotated:
.. code-block:: text
Before rotation: [amax_N, amax_1, amax_2, ..., amax_N-1] (amax_N = current, amax_1 = oldest)
After rotation: [0, amax_2, ..., amax_N-1, amax_N] (amax_1 dropped, amax_N appended)
The scaling factor is computed **before** the rotation, so it uses all ``amax_history_len`` values.
After the scale update, position 0 is zeroed so it is ready to stage the next iteration's amax.
The implementation differs between PyTorch and JAX:
.. tabs::
.. tab:: PyTorch
Each module creates two ``amax_history`` tensors, initialized to zero:
- Forward: shape ``(amax_history_len, num_gemms * 3)`` — three FP8 tensors per GEMM (input, weight, output)
- Backward: shape ``(amax_history_len, num_gemms * 2)`` — two FP8 tensors per GEMM (grad_output, grad_input)
When the autocast context exits, a single CUDA kernel processes all tensors at once —
performing amax reduction across GPUs and history rotation. This batched approach
minimizes kernel launch overhead compared to updating each tensor separately.
.. tab:: JAX
Each quantizer maintains its own ``amax_history`` with shape ``(amax_history_len,)``
and updates independently.
Here's how to use FP8 Delayed Scaling in PyTorch and JAX:
.. tabs::
.. tab:: PyTorch
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Requires SM89 (Ada) or later
</div>
.. literalinclude:: pytorch_delayed_scaling_example.py
:language: python
:start-after: # START_DELAYED_SCALING_EXAMPLE
:end-before: # END_DELAYED_SCALING_EXAMPLE
.. tab:: JAX
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Requires SM89 (Ada) or later
</div>
.. literalinclude:: jax_delayed_scaling_example.py
:language: python
:start-after: # START_DELAYED_SCALING_EXAMPLE
:end-before: # END_DELAYED_SCALING_EXAMPLE
Distributed Training
--------------------
FP8 Delayed Scaling uses the same data formats as FP8 Current Scaling - quantized all-gather is supported.
However, amax reduction works slightly differently in different frameworks.
.. tabs::
.. tab:: PyTorch
Amax reduction is controlled by two parameters:
- ``reduce_amax`` in the recipe: enables/disables amax reduction (required for sequence parallelism (SP) and context parallelism (CP))
- ``amax_reduction_group`` in ``autocast``: specifies the process group for reduction
We recommend reducing amax across all GPUs where the tensor is sharded,
including data parallel ranks.
.. literalinclude:: pytorch_delayed_scaling_distributed_example.py
:language: python
:start-after: # START_AMAX_REDUCTION_EXAMPLE
:end-before: # END_AMAX_REDUCTION_EXAMPLE
In data parallel training, some modules may not execute on certain ranks
(e.g., MoE experts that receive no tokens). This is handled as follows:
- **First iteration**: All modules must execute on all ranks to register
their ``amax_history`` tensors in the global buffer. Mismatched registration
would cause the ``all_reduce`` to hang due to different tensor sizes across ranks.
- **Subsequent iterations**: The ``autocast`` context must be entered and exited
on all ranks (this triggers the collective reduction). Individual modules can be
skipped - if no rank executes a module, its history is not rotated and scale
remains unchanged.
.. tab:: JAX
Amax reduction is always enabled and managed automatically.
Reduction scope: all parallelism axes except pipeline parallelism (TP, SP, DP/FSDP).
.. literalinclude:: jax_delayed_scaling_distributed_example.py
:language: python
:start-after: # START_AMAX_REDUCTION_EXAMPLE
:end-before: # END_AMAX_REDUCTION_EXAMPLE
Supported devices
-----------------
Ada and later (SM 8.9+)
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1000 420">
<!-- Diagram: side-by-side comparison of the two quantization flows.
     Left (Current Scaling): the tensor is read to compute amax, then read
     again to quantize. Right (Delayed Scaling): quantization reads amax from
     a stored amax history and writes the newly observed amax back into it
     (curved arrow), so only the quantization step reads the tensor. -->
<defs>
<style>
/* Common styles loaded from diagram-colors.css: .hp, .fp8, .quantize, .amax, .text, .title, .label, .box-orange, .box-dashed */
/* Diagram-specific styles for arrows */
.arrow {
stroke: #616161;
stroke-width: 2;
fill: none;
marker-end: url(#arrowhead);
}
</style>
<marker id="arrowhead" markerWidth="10" markerHeight="10" refX="8" refY="3" orient="auto" markerUnits="strokeWidth">
<polygon points="0 0, 10 3, 0 6" fill="#616161" />
</marker>
</defs>
<!-- Current Scaling Section -->
<text x="250" y="30" class="title">Current Scaling</text>
<!-- Tensor box -->
<rect x="150" y="60" width="200" height="60" class="hp" rx="5"/>
<text x="250" y="95" class="text">Tensor</text>
<!-- Arrow to amax computation -->
<path d="M 250 120 L 250 160" class="arrow"/>
<!-- Amax computation box -->
<rect x="150" y="160" width="200" height="60" class="amax" rx="5"/>
<text x="250" y="195" class="text">Amax Computation</text>
<!-- Arrow to quantization -->
<path d="M 250 220 L 250 260" class="arrow"/>
<!-- Quantization box -->
<rect x="125" y="260" width="250" height="60" class="quantize" rx="5"/>
<text x="250" y="285" class="text">Quantization</text>
<text x="250" y="305" class="label">(uses tensor + amax)</text>
<!-- Arrow to FP8 tensor -->
<path d="M 250 320 L 250 360" class="arrow"/>
<!-- FP8 Tensor result -->
<rect x="150" y="360" width="200" height="40" class="fp8" rx="5"/>
<text x="250" y="385" class="text">FP8 Tensor</text>
<!-- Delayed Scaling Section -->
<text x="750" y="30" class="title">Delayed Scaling</text>
<!-- Tensor box with amax history subbox -->
<rect x="650" y="60" width="200" height="80" class="hp" rx="5"/>
<text x="750" y="90" class="text">Tensor</text>
<!-- Amax history subbox (below tensor) -->
<rect x="660" y="110" width="180" height="25" class="box-orange box-dashed" rx="3"/>
<text x="750" y="127" class="label">amax history</text>
<!-- Arrow to quantization -->
<path d="M 750 140 L 750 180" class="arrow"/>
<text x="820" y="162" class="small-text" style="text-anchor: start;">read amax</text>
<!-- Quantization box -->
<rect x="625" y="180" width="250" height="80" class="quantize" rx="5"/>
<text x="750" y="210" class="text">Quantization</text>
<text x="750" y="230" class="label">(uses tensor + amax from history)</text>
<text x="750" y="250" class="label">(updates amax history)</text>
<!-- Arrow back to history (curved) -->
<path d="M 625 220 Q 590 220 590 127 L 660 127" class="arrow"/>
<text x="565" y="175" class="small-text" style="text-anchor: end;">update amax</text>
<!-- Arrow to FP8 tensor -->
<path d="M 750 260 L 750 300" class="arrow"/>
<!-- FP8 Tensor result -->
<rect x="650" y="300" width="200" height="40" class="fp8" rx="5"/>
<text x="750" y="325" class="text">FP8 Tensor</text>
</svg>
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Doc example: amax reduction with FP8 Delayed Scaling in JAX.

The region between the START/END markers is pulled into the documentation
via ``literalinclude``.
"""
# START_AMAX_REDUCTION_EXAMPLE
import transformer_engine.jax as te
from transformer_engine.common.recipe import DelayedScaling

# Amax reduction scope is managed internally
recipe = DelayedScaling(reduce_amax=True)  # Must be True in JAX
with te.autocast(enabled=True, recipe=recipe):
    # `layer`, `params`, and `inp` are assumed to be defined by the
    # surrounding training code; this snippet shows only the autocast usage.
    output = layer.apply(params, inp)
# END_AMAX_REDUCTION_EXAMPLE
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Doc example: FP8 delayed scaling with JAX.

The region between the START/END markers is pulled into the documentation
via ``literalinclude``; the capability check above the marker is excluded
from the rendered docs.
"""
from transformer_engine.jax.quantize import get_device_compute_capability

# Requires Ada (SM89) or newer for FP8 support
assert get_device_compute_capability() >= 89, "This example requires SM89 (Ada) or newer"

# START_DELAYED_SCALING_EXAMPLE
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
from transformer_engine.jax.flax import DenseGeneral
from transformer_engine.common.recipe import DelayedScaling

# Create FP8 Delayed Scaling recipe
recipe = DelayedScaling(
    margin=0,  # Margin for scaling factor computation (default: 0)
    amax_history_len=1024,  # Length of amax history window (default: 1024)
    amax_compute_algo="max",  # How to compute amax from history (default: "max")
)

with te.autocast(enabled=True, recipe=recipe):
    # Initialize layer and data
    layer = DenseGeneral(features=1024)
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16)
    var_collect = layer.init(key, x)

    # Forward and backward pass
    def loss_fn(var_collect):
        output = layer.apply(var_collect, x)
        return output.sum()

    loss, grads = jax.value_and_grad(loss_fn)(var_collect)
# END_DELAYED_SCALING_EXAMPLE
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Doc example: amax reduction across ranks with FP8 Delayed Scaling (PyTorch).

The region between the START/END markers is pulled into the documentation
via ``literalinclude``.
"""
# START_AMAX_REDUCTION_EXAMPLE
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Create process group for amax reduction (e.g., all 8 GPUs)
amax_reduction_group = dist.new_group(ranks=[0, 1, 2, 3, 4, 5, 6, 7])
recipe = DelayedScaling(reduce_amax=True)
# enabled=True is passed explicitly for consistency with the other autocast
# examples in this documentation set.
with te.autocast(enabled=True, recipe=recipe, amax_reduction_group=amax_reduction_group):
    # `model` and `inp` are assumed to be defined by the surrounding training
    # code; this snippet shows only the autocast usage.
    output = model(inp)
# END_AMAX_REDUCTION_EXAMPLE
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Doc example: FP8 delayed scaling with PyTorch.

The region between the START/END markers is pulled into the documentation
via ``literalinclude``; the capability check above the marker is excluded
from the rendered docs.
"""
import torch

# Requires Ada (SM89) or newer for FP8 support
assert torch.cuda.get_device_capability()[0] >= 9 or (
    torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9
), "This example requires SM89 (Ada) or newer"

# START_DELAYED_SCALING_EXAMPLE
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Create FP8 Delayed Scaling recipe
recipe = DelayedScaling(
    margin=0,  # Margin for scaling factor computation (default: 0)
    amax_history_len=1024,  # Length of amax history window (default: 1024)
    amax_compute_algo="max",  # How to compute amax from history (default: "max")
)

# Create a linear layer with bfloat16 parameters
layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)

# Forward and backward pass
inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda")
with te.autocast(enabled=True, recipe=recipe):
    output = layer(inp)
# Loss and backward run outside the autocast context in PyTorch.
loss = output.sum()
loss.backward()
# END_DELAYED_SCALING_EXAMPLE
..
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
Low precision training
===================================
.. toctree::
introduction/introduction.rst
performance_considerations/performance_considerations.rst
fp8_current_scaling/fp8_current_scaling.rst
fp8_delayed_scaling/fp8_delayed_scaling.rst
fp8_blockwise_scaling/fp8_blockwise_scaling.rst
mxfp8/mxfp8.rst
nvfp4/nvfp4.rst
\ No newline at end of file
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from transformer_engine.jax.quantize import get_device_compute_capability
# Requires Ada (SM89) or newer for FP8 support
# NOTE(review): assumes the capability is returned as a single integer
# (e.g. 89 for SM 8.9) rather than a (major, minor) tuple — confirm.
assert get_device_compute_capability() >= 89, "This example requires SM89 (Ada) or newer"
# START_AUTOCAST_BASIC
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
from transformer_engine.jax.flax import TransformerLayer
from transformer_engine.common.recipe import DelayedScaling, Format
# Set up recipe
recipe = DelayedScaling()
# Model initialization must happen inside autocast
with te.autocast(enabled=True, recipe=recipe):
    layer = TransformerLayer(
        hidden_size=1024,
        mlp_hidden_size=4096,
        num_attention_heads=16,
    )
    init_key, dropout_key = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(init_key, (32, 128, 1024), dtype=jnp.bfloat16)
    var_collect = layer.init({"params": init_key, "dropout": dropout_key}, x)
    # Forward and backward pass (both inside autocast for JAX)
    def loss_fn(var_collect):
        output = layer.apply(var_collect, x, rngs={"dropout": dropout_key})
        return output.sum()
    loss, grads = jax.value_and_grad(loss_fn)(var_collect)
# END_AUTOCAST_BASIC
# START_AUTOCAST_SEQUENTIAL
# Two back-to-back autocast regions, each with its own recipe.
encoder_recipe = DelayedScaling(fp8_format=Format.E4M3)
decoder_recipe = DelayedScaling(fp8_format=Format.HYBRID)
with te.autocast(enabled=True, recipe=encoder_recipe):
    encoder = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16)
    encoder_var_collect = encoder.init({"params": init_key, "dropout": dropout_key}, x)
    hidden = encoder.apply(encoder_var_collect, x, rngs={"dropout": dropout_key})
with te.autocast(enabled=True, recipe=decoder_recipe):
    decoder = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16)
    decoder_var_collect = decoder.init({"params": init_key, "dropout": dropout_key}, hidden)
    output = decoder.apply(decoder_var_collect, hidden, rngs={"dropout": dropout_key})
# END_AUTOCAST_SEQUENTIAL
# START_AUTOCAST_NESTED
# Nested autocast regions: the innermost enabled recipe wins.
outer_recipe = DelayedScaling(fp8_format=Format.E4M3)
inner_recipe = DelayedScaling(fp8_format=Format.HYBRID)
with te.autocast(enabled=True, recipe=outer_recipe):
    # layer1 uses outer_recipe
    layer1 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16)
    var_collect1 = layer1.init({"params": init_key, "dropout": dropout_key}, x)
    hidden = layer1.apply(var_collect1, x, rngs={"dropout": dropout_key})
    with te.autocast(enabled=True, recipe=inner_recipe):
        # layer2 uses inner_recipe (overrides outer)
        layer2 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16)
        var_collect2 = layer2.init({"params": init_key, "dropout": dropout_key}, hidden)
        hidden = layer2.apply(var_collect2, hidden, rngs={"dropout": dropout_key})
    # layer3 uses outer_recipe again
    layer3 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16)
    var_collect3 = layer3.init({"params": init_key, "dropout": dropout_key}, hidden)
    output = layer3.apply(var_collect3, hidden, rngs={"dropout": dropout_key})
# END_AUTOCAST_NESTED
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch

# Requires Ada (SM89) or newer for FP8 support
assert torch.cuda.get_device_capability() >= (8, 9), "This example requires SM89 (Ada) or newer"

# START_AUTOCAST_BASIC
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Default delayed-scaling recipe driving a single FP8-capable linear layer.
recipe = DelayedScaling()
linear = te.Linear(1024, 1024)
data = torch.randn(32, 1024, dtype=torch.float32, device="cuda")

# Only the forward pass runs inside the autocast region.
with te.autocast(enabled=True, recipe=recipe):
    out = linear(data)

# .backward() is called outside of autocast
loss = out.sum()
loss.backward()
# END_AUTOCAST_BASIC

# START_AUTOCAST_SEQUENTIAL
# Back-to-back autocast regions may use different recipes per module.
encoder_recipe = DelayedScaling(fp8_format=Format.E4M3)
decoder_recipe = DelayedScaling(fp8_format=Format.HYBRID)
encoder = te.Linear(1024, 1024)
decoder = te.Linear(1024, 1024)

with te.autocast(enabled=True, recipe=encoder_recipe):
    hidden = encoder(data)
with te.autocast(enabled=True, recipe=decoder_recipe):
    out = decoder(hidden)
# END_AUTOCAST_SEQUENTIAL

# START_AUTOCAST_NESTED
# Nested autocast regions: the innermost recipe takes precedence.
outer_recipe = DelayedScaling(fp8_format=Format.E4M3)
inner_recipe = DelayedScaling(fp8_format=Format.HYBRID)
layer1 = te.Linear(1024, 1024)
layer2 = te.Linear(1024, 1024)
layer3 = te.Linear(1024, 1024)

with te.autocast(enabled=True, recipe=outer_recipe):
    # layer1 uses outer_recipe
    act = layer1(data)
    with te.autocast(enabled=True, recipe=inner_recipe):
        # layer2 uses inner_recipe (overrides outer)
        act = layer2(act)
    # layer3 uses outer_recipe again
    out = layer3(act)
# END_AUTOCAST_NESTED
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# START_BF16_FP16_TRAINING
import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import TransformerLayer
def run_forward_backward(params_dtype, compute_dtype):
    """Run one forward/backward pass of a TransformerLayer at the given precisions.

    params_dtype: dtype passed to the layer's ``dtype`` argument.
        NOTE(review): in Flax modules ``dtype`` conventionally controls the
        computation dtype (``param_dtype`` controls storage) — confirm this
        matches the "master weights" wording of the calls below.
    compute_dtype: dtype of the input activations; the layer output is
        asserted to have this dtype.
    """
    # Create TransformerLayer
    layer = TransformerLayer(
        hidden_size=1024,
        mlp_hidden_size=4096,
        num_attention_heads=16,
        dtype=params_dtype,
    )
    # Initialize parameters
    init_key, dropout_key = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(init_key, (32, 128, 1024), dtype=compute_dtype)
    var_collect = layer.init({"params": init_key, "dropout": dropout_key}, x)
    # Forward and backward pass
    def loss_fn(var_collect):
        output = layer.apply(var_collect, x, rngs={"dropout": dropout_key})
        # The output precision follows the activation precision.
        assert output.dtype == compute_dtype
        return output.sum()
    loss, grads = jax.value_and_grad(loss_fn)(var_collect)
run_forward_backward(jnp.float32, jnp.float32) # high precision training
run_forward_backward(jnp.float32, jnp.bfloat16) # bfloat16 training with master weights in FP32
run_forward_backward(jnp.bfloat16, jnp.bfloat16) # bfloat16 training with weights in BF16
# END_BF16_FP16_TRAINING
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# START_BF16_FP16_TRAINING
import torch
import transformer_engine.pytorch as te
from contextlib import nullcontext


def run_forward_backward(params_dtype, autocast_precision, grad_scaler_enabled):
    """Run one forward/backward pass of a TransformerLayer.

    params_dtype: dtype the layer parameters (and input) are created in.
    autocast_precision: dtype used by torch.autocast for compute, or None to
        run without autocast (compute then happens in params_dtype).
    grad_scaler_enabled: wrap the loss in torch.amp.GradScaler, needed for
        FP16 training so small gradients do not underflow.
    """
    if grad_scaler_enabled:
        grad_scaler = torch.amp.GradScaler("cuda")
    layer = te.TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
        params_dtype=params_dtype,
    )
    x = torch.randn(32, 128, 1024, dtype=params_dtype, device="cuda")
    autocast_ctx = (
        torch.autocast(device_type="cuda", dtype=autocast_precision)
        if autocast_precision is not None
        else nullcontext()
    )
    with autocast_ctx:
        output = layer(x)
    # Bug fix: the previous `assert (a == b if cond else c)` parsed as
    # `(a == b) if cond else c`, so with autocast disabled it asserted the
    # (always truthy) dtype object and verified nothing.
    expected_dtype = autocast_precision if autocast_precision is not None else params_dtype
    assert output.dtype == expected_dtype
    loss = output.sum()
    if grad_scaler_enabled:
        # Scale the loss before backward; required for FP16 gradients.
        grad_scaler.scale(loss).backward()
    else:
        loss.backward()


run_forward_backward(torch.float32, torch.float32, False)  # high precision training
run_forward_backward(
    torch.float32, torch.bfloat16, False
)  # bfloat16 training with master weights in FP32
run_forward_backward(
    torch.float32, torch.float16, True
)  # fp16 training with master weights in FP32, needs loss scaling
run_forward_backward(
    torch.bfloat16, torch.bfloat16, False
)  # bfloat16 training with weights in BF16
# END_BF16_FP16_TRAINING
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 850 780" width="850" height="780">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
</style>
<marker id="arrowhead" markerWidth="6" markerHeight="6" refX="5" refY="2" orient="auto">
<polygon points="0 0, 6 2, 0 4" fill="#616161" />
</marker>
</defs>
<!-- Title -->
<text x="425" y="30" class="title">FP8 Linear Layer – Forward and Backward Pass</text>
<!-- Forward Pass Section -->
<text x="425" y="65" class="section-title" style="fill: #1565c0;">Forward Pass</text>
<!-- Forward: Input^T FP8 (top, saved for backward) -->
<rect x="270" y="70" width="80" height="50" class="fp8" rx="6"/>
<text x="310" y="100" class="text">Input<tspan baseline-shift="super" style="font-size: 9px;">T</tspan></text>
<!-- Forward: Input High Precision -->
<rect x="30" y="130" width="100" height="50" class="hp" rx="6"/>
<text x="80" y="160" class="text">Input</text>
<!-- Forward: Arrow -->
<path d="M 130 155 L 155 155" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<!-- Forward: Quantize Input -->
<rect x="155" y="130" width="90" height="50" class="quantize" rx="6"/>
<text x="200" y="160" class="text">Quantize</text>
<!-- Forward: Arrow to Input^T (going up) -->
<path d="M 245 140 L 270 110" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<!-- Forward: Arrow to Input -->
<path d="M 245 155 L 270 155" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<!-- Forward: Input FP8 -->
<rect x="270" y="130" width="80" height="50" class="fp8" rx="6"/>
<text x="310" y="160" class="text">Input</text>
<!-- Forward: Arrow from Input to GEMM -->
<path d="M 350 155 L 400 170" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<text x="370" y="145" class="text" style="font-size: 10px; font-weight: bold;">N</text>
<!-- Forward: Weights High Precision -->
<rect x="30" y="195" width="100" height="50" class="hp" rx="6"/>
<text x="80" y="225" class="text">Weight</text>
<!-- Forward: Arrow -->
<path d="M 130 220 L 155 220" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<!-- Forward: Quantize Weights -->
<rect x="155" y="195" width="90" height="50" class="quantize" rx="6"/>
<text x="200" y="225" class="text">Quantize</text>
<!-- Forward: Arrow to Weight -->
<path d="M 245 220 L 270 220" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<!-- Forward: Arrow to Weight^T (going down) -->
<path d="M 245 235 L 270 270" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<!-- Forward: Weights FP8 -->
<rect x="270" y="195" width="80" height="50" class="fp8" rx="6"/>
<text x="310" y="225" class="text">Weight</text>
<!-- Forward: Weight^T FP8 (bottom, saved for backward) -->
<rect x="270" y="255" width="80" height="50" class="fp8" rx="6"/>
<text x="310" y="285" class="text">Weight<tspan baseline-shift="super" style="font-size: 9px;">T</tspan></text>
<!-- Forward: Arrow from Weight to GEMM -->
<path d="M 350 220 L 400 200" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<text x="370" y="230" class="text" style="font-size: 10px; font-weight: bold;">T</text>
<!-- Forward: GEMM -->
<rect x="400" y="160" width="130" height="50" class="gemm" rx="6"/>
<text x="465" y="180" class="text" style="font-weight: 600;">FP8 GEMM</text>
<text x="465" y="200" class="text" style="font-size: 11px;">(TN)</text>
<!-- Forward: Arrow -->
<path d="M 530 185 L 580 185" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<!-- Forward: Output -->
<rect x="580" y="160" width="110" height="50" class="hp" rx="6"/>
<text x="635" y="190" class="text">Output</text>
<!-- Divider Line -->
<line x1="30" y1="310" x2="820" y2="310" stroke="#ddd" stroke-width="2"/>
<!-- Backward Pass Section -->
<text x="425" y="345" class="section-title" style="fill: #c62828;">Backward Pass</text>
<!-- Backward: Weight^T (from forward, top input to GEMM1) -->
<rect x="495" y="355" width="80" height="50" class="fp8" rx="6"/>
<text x="535" y="385" class="text">Weight<tspan baseline-shift="super" style="font-size: 9px;">T</tspan></text>
<!-- Backward: Output gradient High Precision -->
<rect x="30" y="480" width="130" height="50" class="hp" rx="6"/>
<text x="95" y="510" class="text">Output grad.</text>
<!-- Backward: Arrow -->
<path d="M 160 505 L 180 505" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<!-- Backward: Quantize Output gradient -->
<rect x="180" y="480" width="90" height="50" class="quantize" rx="6"/>
<text x="225" y="510" class="text">Quantize</text>
<!-- Backward: Arrow to Output grad (going up) -->
<path d="M 270 490 L 290 465" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<!-- Backward: Arrow to Output grad^T (going down) -->
<path d="M 270 520 L 290 545" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<!-- Backward: Output gradient FP8 (for input gradient) -->
<rect x="290" y="440" width="110" height="50" class="fp8" rx="6"/>
<text x="345" y="470" class="text">Output grad.</text>
<!-- Backward: Output gradient^T FP8 (for weight gradient) -->
<rect x="290" y="520" width="110" height="50" class="fp8" rx="6"/>
<text x="345" y="550" class="text">Output grad.<tspan baseline-shift="super" style="font-size: 9px;">T</tspan></text>
<!-- Backward: GEMM 1 (for input gradient) -->
<rect x="470" y="440" width="130" height="50" class="gemm" rx="6"/>
<text x="535" y="460" class="text" style="font-weight: 600;">FP8 GEMM</text>
<text x="535" y="480" class="text" style="font-size: 11px;">(TN)</text>
<!-- Backward: Input gradient -->
<rect x="640" y="440" width="130" height="50" class="hp" rx="6"/>
<text x="705" y="470" class="text">Input grad.</text>
<!-- Backward: GEMM 2 (for weight gradient) -->
<rect x="470" y="520" width="130" height="50" class="gemm" rx="6"/>
<text x="535" y="540" class="text" style="font-weight: 600;">FP8 GEMM</text>
<text x="535" y="560" class="text" style="font-size: 11px;">(TN)</text>
<!-- Backward: Weight gradient -->
<rect x="640" y="520" width="130" height="50" class="hp" rx="6"/>
<text x="705" y="550" class="text">Weight grad.</text>
<!-- Backward: Input^T (from forward, bottom input to GEMM2) -->
<rect x="495" y="605" width="80" height="50" class="fp8" rx="6"/>
<text x="535" y="635" class="text">Input<tspan baseline-shift="super" style="font-size: 9px;">T</tspan></text>
<!-- Backward: Arrows -->
<!-- Output gradient FP8 to top GEMM -->
<path d="M 400 465 L 470 465" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<text x="430" y="457" class="text" style="font-size: 10px; font-weight: bold;">N</text>
<!-- Weight^T to top GEMM -->
<path d="M 535 405 L 535 440" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<text x="543" y="427" class="text" style="font-size: 10px; font-weight: bold;">T</text>
<!-- Top GEMM to input gradient -->
<path d="M 600 465 L 640 465" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<!-- Output gradient^T FP8 to bottom GEMM -->
<path d="M 400 545 L 470 545" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<text x="430" y="537" class="text" style="font-size: 10px; font-weight: bold;">N</text>
<!-- Input^T to bottom GEMM -->
<path d="M 535 605 L 535 570" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<text x="543" y="597" class="text" style="font-size: 10px; font-weight: bold;">T</text>
<!-- Bottom GEMM to weight gradient -->
<path d="M 600 545 L 640 545" stroke="#616161" stroke-width="2" fill="none" marker-end="url(#arrowhead)"/>
<!-- Legend -->
<g transform="translate(30, 680)">
<!-- Higher Precision -->
<rect x="0" y="0" width="80" height="40" rx="5" class="hp"/>
<text x="95" y="23" class="text" style="text-anchor: start;">Higher Precision (FP32/BF16/FP16)</text>
<!-- Lower Precision -->
<rect x="380" y="0" width="80" height="40" rx="5" class="fp8"/>
<text x="475" y="23" class="text" style="text-anchor: start;">Lower Precision (FP8, MXFP8 etc.)</text>
</g>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 900 210">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.sign-bit { fill: #9db4d0; stroke: #333; stroke-width: 1; }
.exponent-bit { fill: #d9a066; stroke: #333; stroke-width: 1; }
.mantissa-bit { fill: #a8d99c; stroke: #333; stroke-width: 1; }
.bit-text { fill: #000; text-anchor: middle; dominant-baseline: middle; font-size: 16px; }
.header-text { fill: #555; font-weight: normal; text-anchor: middle; font-size: 18px; }
.value-text { fill: #333; font-size: 18px; }
.format-label { fill: #333; font-size: 20px; font-weight: bold; text-anchor: middle; dominant-baseline: middle; }
</style>
</defs>
<!-- Header labels - centered -->
<text x="79" y="18" class="header-text">sign</text>
<text x="173" y="18" class="header-text">exponent</text>
<text x="530" y="18" class="header-text">mantissa</text>
<!-- FP32 Format (32 bits: 1 + 8 + 23) -->
<text x="30" y="60" class="format-label">FP32</text>
<!-- Sign bit (1) -->
<rect x="70" y="45" width="18" height="30" class="sign-bit"/>
<text x="79" y="60" class="bit-text">0</text>
<!-- Exponent bits (8) -->
<rect x="93" y="45" width="18" height="30" class="exponent-bit"/>
<text x="102" y="60" class="bit-text">0</text>
<rect x="116" y="45" width="18" height="30" class="exponent-bit"/>
<text x="125" y="60" class="bit-text">1</text>
<rect x="139" y="45" width="18" height="30" class="exponent-bit"/>
<text x="148" y="60" class="bit-text">1</text>
<rect x="162" y="45" width="18" height="30" class="exponent-bit"/>
<text x="171" y="60" class="bit-text">1</text>
<rect x="185" y="45" width="18" height="30" class="exponent-bit"/>
<text x="194" y="60" class="bit-text">1</text>
<rect x="208" y="45" width="18" height="30" class="exponent-bit"/>
<text x="217" y="60" class="bit-text">1</text>
<rect x="231" y="45" width="18" height="30" class="exponent-bit"/>
<text x="240" y="60" class="bit-text">0</text>
<rect x="254" y="45" width="18" height="30" class="exponent-bit"/>
<text x="263" y="60" class="bit-text">1</text>
<!-- Mantissa bits (23) -->
<rect x="277" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="286" y="60" class="bit-text">1</text>
<rect x="300" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="309" y="60" class="bit-text">0</text>
<rect x="323" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="332" y="60" class="bit-text">0</text>
<rect x="346" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="355" y="60" class="bit-text">1</text>
<rect x="369" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="378" y="60" class="bit-text">0</text>
<rect x="392" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="401" y="60" class="bit-text">1</text>
<rect x="415" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="424" y="60" class="bit-text">0</text>
<rect x="438" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="447" y="60" class="bit-text">0</text>
<rect x="461" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="470" y="60" class="bit-text">1</text>
<rect x="484" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="493" y="60" class="bit-text">0</text>
<rect x="507" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="516" y="60" class="bit-text">1</text>
<rect x="530" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="539" y="60" class="bit-text">0</text>
<rect x="553" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="562" y="60" class="bit-text">1</text>
<rect x="576" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="585" y="60" class="bit-text">1</text>
<rect x="599" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="608" y="60" class="bit-text">1</text>
<rect x="622" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="631" y="60" class="bit-text">1</text>
<rect x="645" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="654" y="60" class="bit-text">0</text>
<rect x="668" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="677" y="60" class="bit-text">1</text>
<rect x="691" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="700" y="60" class="bit-text">0</text>
<rect x="714" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="723" y="60" class="bit-text">1</text>
<rect x="737" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="746" y="60" class="bit-text">0</text>
<rect x="760" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="769" y="60" class="bit-text">0</text>
<rect x="783" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="792" y="60" class="bit-text">0</text>
<text x="820" y="60" class="value-text">= 0.3952</text>
<!-- BF16 Format (16 bits: 1 + 8 + 7) -->
<text x="30" y="120" class="format-label">BF16</text>
<!-- Sign bit (1) -->
<rect x="70" y="105" width="18" height="30" class="sign-bit"/>
<text x="79" y="120" class="bit-text">0</text>
<!-- Exponent bits (8) -->
<rect x="93" y="105" width="18" height="30" class="exponent-bit"/>
<text x="102" y="120" class="bit-text">0</text>
<rect x="116" y="105" width="18" height="30" class="exponent-bit"/>
<text x="125" y="120" class="bit-text">1</text>
<rect x="139" y="105" width="18" height="30" class="exponent-bit"/>
<text x="148" y="120" class="bit-text">1</text>
<rect x="162" y="105" width="18" height="30" class="exponent-bit"/>
<text x="171" y="120" class="bit-text">1</text>
<rect x="185" y="105" width="18" height="30" class="exponent-bit"/>
<text x="194" y="120" class="bit-text">1</text>
<rect x="208" y="105" width="18" height="30" class="exponent-bit"/>
<text x="217" y="120" class="bit-text">1</text>
<rect x="231" y="105" width="18" height="30" class="exponent-bit"/>
<text x="240" y="120" class="bit-text">0</text>
<rect x="254" y="105" width="18" height="30" class="exponent-bit"/>
<text x="263" y="120" class="bit-text">1</text>
<!-- Mantissa bits (7) -->
<rect x="277" y="105" width="18" height="30" class="mantissa-bit"/>
<text x="286" y="120" class="bit-text">1</text>
<rect x="300" y="105" width="18" height="30" class="mantissa-bit"/>
<text x="309" y="120" class="bit-text">0</text>
<rect x="323" y="105" width="18" height="30" class="mantissa-bit"/>
<text x="332" y="120" class="bit-text">0</text>
<rect x="346" y="105" width="18" height="30" class="mantissa-bit"/>
<text x="355" y="120" class="bit-text">1</text>
<rect x="369" y="105" width="18" height="30" class="mantissa-bit"/>
<text x="378" y="120" class="bit-text">0</text>
<rect x="392" y="105" width="18" height="30" class="mantissa-bit"/>
<text x="401" y="120" class="bit-text">1</text>
<rect x="415" y="105" width="18" height="30" class="mantissa-bit"/>
<text x="424" y="120" class="bit-text">0</text>
<text x="820" y="120" class="value-text">≈ 0.3945</text>
<!-- FP16 Format (16 bits: 1 + 5 + 10) -->
<text x="30" y="180" class="format-label">FP16</text>
<!-- Sign bit (1) -->
<rect x="70" y="165" width="18" height="30" class="sign-bit"/>
<text x="79" y="180" class="bit-text">0</text>
<!-- Exponent bits (5) -->
<rect x="93" y="165" width="18" height="30" class="exponent-bit"/>
<text x="102" y="180" class="bit-text">0</text>
<rect x="116" y="165" width="18" height="30" class="exponent-bit"/>
<text x="125" y="180" class="bit-text">1</text>
<rect x="139" y="165" width="18" height="30" class="exponent-bit"/>
<text x="148" y="180" class="bit-text">1</text>
<rect x="162" y="165" width="18" height="30" class="exponent-bit"/>
<text x="171" y="180" class="bit-text">0</text>
<rect x="185" y="165" width="18" height="30" class="exponent-bit"/>
<text x="194" y="180" class="bit-text">1</text>
<!-- Mantissa bits (10) -->
<rect x="208" y="165" width="18" height="30" class="mantissa-bit"/>
<text x="217" y="180" class="bit-text">1</text>
<rect x="231" y="165" width="18" height="30" class="mantissa-bit"/>
<text x="240" y="180" class="bit-text">0</text>
<rect x="254" y="165" width="18" height="30" class="mantissa-bit"/>
<text x="263" y="180" class="bit-text">0</text>
<rect x="277" y="165" width="18" height="30" class="mantissa-bit"/>
<text x="286" y="180" class="bit-text">1</text>
<rect x="300" y="165" width="18" height="30" class="mantissa-bit"/>
<text x="309" y="180" class="bit-text">0</text>
<rect x="323" y="165" width="18" height="30" class="mantissa-bit"/>
<text x="332" y="180" class="bit-text">1</text>
<rect x="346" y="165" width="18" height="30" class="mantissa-bit"/>
<text x="355" y="180" class="bit-text">0</text>
<rect x="369" y="165" width="18" height="30" class="mantissa-bit"/>
<text x="378" y="180" class="bit-text">0</text>
<rect x="392" y="165" width="18" height="30" class="mantissa-bit"/>
<text x="401" y="180" class="bit-text">1</text>
<rect x="415" y="165" width="18" height="30" class="mantissa-bit"/>
<text x="424" y="180" class="bit-text">0</text>
<text x="820" y="180" class="value-text">≈ 0.3950</text>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1050 580" width="1050" height="580">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.arrow { stroke: #616161; stroke-width: 2; fill: none; marker-end: url(#arrowhead); }
.column-title { font-family: 'Segoe UI', Arial, sans-serif; font-size: 14px; font-weight: 600; text-anchor: middle; fill: #424242; }
.divider { stroke: #bdbdbd; stroke-width: 1.5; stroke-dasharray: 8,6; }
</style>
<marker id="arrowhead" markerWidth="6" markerHeight="6" refX="5" refY="2" orient="auto">
<polygon points="0 0, 6 2, 0 4" fill="#616161" />
</marker>
</defs>
<!-- Title -->
<text x="525" y="30" class="title">Master Weights Storage Approaches</text>
<!-- Vertical dividers (dashed lines) -->
<line x1="350" y1="50" x2="350" y2="560" class="divider"/>
<line x1="700" y1="50" x2="700" y2="560" class="divider"/>
<!-- Column 1: Low Precision Only -->
<text x="175" y="75" class="column-title">Low Precision Weights</text>
<text x="175" y="93" class="small-text">(no master weights)</text>
<!-- Model box -->
<rect x="60" y="145" width="230" height="90" rx="6" fill="#f5f5f5" stroke="#9e9e9e" stroke-width="1.5"/>
<text x="175" y="168" class="label">Model</text>
<rect x="80" y="183" width="190" height="40" class="hp" rx="4"/>
<text x="175" y="208" class="text">Weights (BF16/FP16)</text>
<!-- Arrow down -->
<path d="M 175 235 L 175 300" class="arrow"/>
<!-- Computation -->
<rect x="90" y="300" width="170" height="50" class="gemm" rx="6"/>
<text x="175" y="330" class="text">Forward/Backward</text>
<!-- Arrow down -->
<path d="M 175 350 L 175 415" class="arrow"/>
<!-- Optimizer box -->
<rect x="60" y="415" width="230" height="90" rx="6" fill="#f5f5f5" stroke="#9e9e9e" stroke-width="1.5"/>
<text x="175" y="438" class="label">Optimizer</text>
<rect x="80" y="453" width="190" height="40" class="fp32" rx="4"/>
<text x="175" y="478" class="text">State (FP32)</text>
<!-- Column 2: Master Weights in Model -->
<text x="525" y="75" class="column-title">Master Weights in Model</text>
<!-- Model box -->
<rect x="410" y="145" width="230" height="90" rx="6" fill="#f5f5f5" stroke="#9e9e9e" stroke-width="1.5"/>
<text x="525" y="168" class="label">Model</text>
<rect x="430" y="183" width="190" height="40" class="fp32" rx="4"/>
<text x="525" y="208" class="text">Weights (FP32)</text>
<!-- Arrow down with cast -->
<path d="M 525 235 L 525 300" class="arrow"/>
<rect x="465" y="255" width="120" height="26" class="quantize" rx="4"/>
<text x="525" y="273" class="small-text">cast to BF16/FP16</text>
<!-- Computation -->
<rect x="440" y="300" width="170" height="50" class="gemm" rx="6"/>
<text x="525" y="330" class="text">Forward/Backward</text>
<!-- Arrow down -->
<path d="M 525 350 L 525 415" class="arrow"/>
<!-- Optimizer box -->
<rect x="410" y="415" width="230" height="90" rx="6" fill="#f5f5f5" stroke="#9e9e9e" stroke-width="1.5"/>
<text x="525" y="438" class="label">Optimizer</text>
<rect x="430" y="453" width="190" height="40" class="fp32" rx="4"/>
<text x="525" y="478" class="text">State (FP32)</text>
<!-- Column 3: Master Weights in Optimizer -->
<text x="875" y="75" class="column-title">Master Weights in Optimizer</text>
<!-- Cast box above Model -->
<rect x="815" y="105" width="120" height="26" class="quantize" rx="4"/>
<text x="875" y="123" class="small-text">cast to BF16/FP16</text>
<!-- Arrow from cast to Model -->
<path d="M 875 131 L 875 145" class="arrow"/>
<!-- Model box -->
<rect x="760" y="145" width="230" height="90" rx="6" fill="#f5f5f5" stroke="#9e9e9e" stroke-width="1.5"/>
<text x="875" y="168" class="label">Model</text>
<rect x="780" y="183" width="190" height="40" class="hp" rx="4"/>
<text x="875" y="208" class="text">Weights (BF16/FP16)</text>
<!-- Arrow down -->
<path d="M 875 235 L 875 300" class="arrow"/>
<!-- Computation -->
<rect x="790" y="300" width="170" height="50" class="gemm" rx="6"/>
<text x="875" y="330" class="text">Forward/Backward</text>
<!-- Arrow down -->
<path d="M 875 350 L 875 415" class="arrow"/>
<!-- Optimizer box with State and Master -->
<rect x="760" y="415" width="230" height="140" rx="6" fill="#f5f5f5" stroke="#9e9e9e" stroke-width="1.5"/>
<text x="875" y="438" class="label">Optimizer</text>
<rect x="780" y="453" width="190" height="40" class="fp32" rx="4"/>
<text x="875" y="478" class="text">State (FP32)</text>
<rect x="780" y="503" width="190" height="40" class="fp32" rx="4"/>
<text x="875" y="528" class="text">Master (FP32)</text>
<!-- Arrow from Master to cast -->
<path d="M 970 523 L 1010 523 L 1010 118 L 935 118" class="arrow"/>
</svg>
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment