Commit 3ceb248e authored by Paweł Gadziński, committed by GitHub

More detailed documentation for recipes (#2343)



* Code drop: Update recipes documentation and remove custom recipes from low precision training
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Fix SVG css import path for diagrams
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Refactor low_precision_training docs: remove optimizers, fix imports, add GPU checks

Changes:
- Remove optimizer code from all recipe examples (keep only forward/backward)
- Fix Format imports (use Format.E4M3 instead of string 'E4M3')
- Fix params_dtype for PyTorch examples (add params_dtype=torch.bfloat16)
- Add GPU capability assertions before START blocks for blockwise/mxfp8/nvfp4
- Fix JAX imports (Float8CurrentScaling from common.recipe, NVFP4BlockScaling)
- Add global_shard_guard for TransformerLayer examples in JAX
- Fix fused_layers_jax.py return tuple unpacking
- Update memory_usage JAX examples with dynamic GPU measurement
- Remove memory_usage_3_jax (JAX doesn't support FP8 weight storage)
- Update performance_considerations.rst for JAX differences
- Delete unused .out files and fp8_autocast_jax.py
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix JAX memory usage .out files with correct output
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* responded to comments
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* applied suggestions from greptile
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* year change
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* jax compute capability fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent c3769cb7
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
# Check for Blackwell or newer GPU
major, minor = torch.cuda.get_device_capability()
assert major >= 10, f"MXFP8 requires SM100 (Blackwell) or later, got SM{major}{minor}"
# START_MXFP8_EXAMPLE
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import MXFP8BlockScaling, Format
# Create MXFP8 recipe
recipe = MXFP8BlockScaling(
    fp8_format=Format.E4M3,  # E4M3 (default) or HYBRID; pure E5M2 not supported
)
# Create a linear layer with bfloat16 parameters
layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
# Forward and backward pass
inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda")
with te.autocast(enabled=True, recipe=recipe):
    output = layer(inp)
loss = output.sum()
loss.backward()
# END_MXFP8_EXAMPLE
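
For intuition about what the MXFP8 recipe in the example above does to a tensor, the following is a minimal pure-Python sketch of MX-style block scaling: every 32 consecutive values share one power-of-two scale, chosen so that the largest magnitude in the block fits the E4M3 range (maximum 448). It does not call Transformer Engine. The helper name `mx_block_scales` and the ceiling-based choice of exponent are assumptions made for illustration; real kernels may round the exponent differently and also cast the scaled values to FP8.

```python
import math

E4M3_MAX = 448.0   # largest finite E4M3 magnitude
BLOCK = 32         # MX block size: 32 elements share one scale


def mx_block_scales(values):
    """Return one power-of-two scale per block of 32 values (hypothetical helper)."""
    scales = []
    for start in range(0, len(values), BLOCK):
        amax = max(abs(v) for v in values[start:start + BLOCK])
        if amax == 0.0:
            scales.append(1.0)  # all-zero block: any scale works
            continue
        # Smallest power of two such that amax / scale <= E4M3_MAX.
        exponent = math.ceil(math.log2(amax / E4M3_MAX))
        scales.append(2.0 ** exponent)
    return scales


# Three blocks with very different magnitudes.
values = [0.01 * i for i in range(64)] + [1000.0] + [0.0] * 31
scales = mx_block_scales(values)
scaled = [v / scales[i // BLOCK] for i, v in enumerate(values)]
```

Because each block gets its own scale, the block containing 1000.0 does not force the small values in the other blocks toward zero.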
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1100 140" width="1100" height="140">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.title { font: bold 16px sans-serif; fill: #333; text-anchor: middle; }
.text { font: 13px sans-serif; fill: #333; text-anchor: middle; }
/* Arrows */
.arrow { stroke: #616161; stroke-width: 2; fill: none; marker-end: url(#arrowhead); }
/* High precision tensor */
.hp {
fill: #e8f5e9;
stroke: #43a047;
stroke-width: 2;
}
/* Amax operations */
.amax {
fill: #fff3e0;
stroke: #ff9800;
stroke-width: 2;
}
/* Quantize operations */
.quantize {
fill: #fce4ec;
stroke: #e91e63;
stroke-width: 2;
}
/* NVFP4 tensor */
.nvfp4 {
fill: #87CEEB;
stroke: #444;
stroke-width: 2;
}
/* All-gather operations */
.allgather {
fill: #e1f5fe;
stroke: #039be5;
stroke-width: 2;
}
</style>
<marker id="arrowhead" markerWidth="6" markerHeight="6" refX="5" refY="2" orient="auto">
<polygon points="0 0, 6 2, 0 4" fill="#616161" />
</marker>
</defs>
<!-- Title -->
<text x="550" y="30" class="title">Quantization + All-Gather for NVFP4</text>
<!-- High Precision Tensor -->
<rect x="20" y="70" width="100" height="55" class="hp" rx="6"/>
<text x="70" y="93" class="text">High Precision</text>
<text x="70" y="110" class="text">Tensor</text>
<!-- Arrow -->
<path d="M 120 97 L 145 97" class="arrow"/>
<!-- Compute Amax -->
<rect x="145" y="70" width="100" height="55" class="amax" rx="6"/>
<text x="195" y="93" class="text">Compute</text>
<text x="195" y="110" class="text">Amax</text>
<!-- Arrow -->
<path d="M 245 97 L 270 97" class="arrow"/>
<!-- Synchronize Amax -->
<rect x="270" y="70" width="100" height="55" class="amax" rx="6"/>
<text x="320" y="93" class="text">Synchronize</text>
<text x="320" y="110" class="text">Amax</text>
<!-- Arrow -->
<path d="M 370 97 L 395 97" class="arrow"/>
<!-- Compute s_global -->
<rect x="395" y="70" width="100" height="55" class="amax" rx="6"/>
<text x="445" y="93" class="text">Compute</text>
<text x="445" y="110" class="text">s_global</text>
<!-- Arrow -->
<path d="M 495 97 L 520 97" class="arrow"/>
<!-- Scale + Cast -->
<rect x="520" y="70" width="100" height="55" class="quantize" rx="6"/>
<text x="570" y="86" class="text">Scale + Cast</text>
<text x="570" y="103" class="text">(s_block,</text>
<text x="570" y="118" class="text">s_global)</text>
<!-- Arrow -->
<path d="M 620 97 L 645 97" class="arrow"/>
<!-- NVFP4 Tensor (intermediate) -->
<rect x="645" y="70" width="100" height="55" class="nvfp4" rx="6"/>
<text x="695" y="93" class="text">NVFP4</text>
<text x="695" y="110" class="text">Tensor</text>
<!-- Arrow -->
<path d="M 745 97 L 770 97" class="arrow"/>
<!-- All-Gather -->
<rect x="770" y="70" width="100" height="55" class="allgather" rx="6"/>
<text x="820" y="102" class="text">All-Gather</text>
<!-- Arrow -->
<path d="M 870 97 L 895 97" class="arrow"/>
<!-- NVFP4 Gathered Tensor -->
<rect x="895" y="70" width="130" height="55" class="nvfp4" rx="6"/>
<text x="960" y="93" class="text">NVFP4 Gathered</text>
<text x="960" y="110" class="text">Tensor</text>
</svg>
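
The pipeline in the diagram above (compute amax, synchronize amax, compute s_global, scale and cast, all-gather) can be sketched in pure Python with two simulated ranks. This is an illustration under stated assumptions, not Transformer Engine code: the formula `s_global = global_amax / (448 * 6)` follows the usual NVFP4 two-level scheme but is not spelled out in the diagram, block scales are kept as Python floats instead of being cast to E4M3, and the helper names are hypothetical.

```python
FP4_MAX = 6.0      # largest E2M1 magnitude
FP8_MAX = 448.0    # largest E4M3 magnitude (range available to s_block)
BLOCK = 16         # NVFP4 block size
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def round_e2m1(x):
    """Round to the nearest representable E2M1 magnitude, keeping the sign."""
    mag = min(E2M1_GRID, key=lambda g: abs(g - abs(x)))
    return -mag if x < 0 else mag


def quantize_shard(shard, s_global):
    """Scale + cast one local shard using per-block scales and the shared s_global."""
    data, s_blocks = [], []
    for start in range(0, len(shard), BLOCK):
        block = shard[start:start + BLOCK]
        block_amax = max(abs(v) for v in block)
        # Block scale relative to s_global (assumed formula; float, not E4M3).
        s_block = (block_amax / FP4_MAX) / s_global if block_amax else 1.0
        s_blocks.append(s_block)
        data.extend(round_e2m1(v / (s_block * s_global)) for v in block)
    return data, s_blocks


# Two simulated ranks, each holding a shard of the same logical tensor.
shards = [[0.1 * i for i in range(32)], [-2.0 * i for i in range(32)]]

local_amax = [max(abs(v) for v in s) for s in shards]      # Compute Amax
global_amax = max(local_amax)                              # Synchronize Amax (max-reduce)
s_global = global_amax / (FP8_MAX * FP4_MAX)               # Compute s_global

quantized = [quantize_shard(s, s_global) for s in shards]  # Scale + Cast on each rank

# All-Gather: concatenate the already-quantized data and block scales.
gathered_data = [v for data, _ in quantized for v in data]
gathered_scales = [s for _, scales in quantized for s in scales]

dequantized = [
    q * gathered_scales[i // BLOCK] * s_global for i, q in enumerate(gathered_data)
]
```

Synchronizing amax before quantization means every rank uses the same `s_global`, so the quantized shards can be concatenated directly without requantizing after the gather.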
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 60 500 450" width="100%" style="max-width: 500px;">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.title { font: bold 18px sans-serif; fill: #333; text-anchor: middle; }
.label { font: 14px sans-serif; fill: #333; text-anchor: middle; }
.small-text { font: 12px sans-serif; fill: #555; }
.dots-text { font: bold 24px sans-serif; fill: #333; text-anchor: middle; }
/* Tensor colors */
.fp8-tensor { fill: #87CEEB; stroke: #444; stroke-width: 2; }
.fp8-block { fill: #87CEEB; stroke: #555; stroke-width: 1.5; }
.fp8-block-alt { fill: #5F9FCC; stroke: #555; stroke-width: 1.5; }
/* Scaling factor colors */
.scale-factor { fill: #FFA500; stroke: #444; stroke-width: 2; }
.global-scale { fill: #FF6B6B; stroke: #444; stroke-width: 2; }
.grid-line { stroke: #444; stroke-width: 2; }
.boundary-line { stroke: #444; stroke-width: 2; }
</style>
</defs>
<!-- NVFP4 Scaling -->
<g id="nvfp4-scaling">
<text x="250" y="85" class="title">NVFP4 Hierarchical Scaling</text>
<text x="250" y="108" class="label" style="font-size: 12px;">(Block scaling + Global scaling)</text>
<!-- NVFP4 Tensor split into many small blocks (40×10) -->
<g id="tensor-blocks">
<!-- White backgrounds for dots areas - cross pattern -->
<rect x="250.0" y="140.0" width="40" height="120" fill="#FFFFFF" stroke="none"/>
<rect x="130.0" y="190.0" width="240" height="30" fill="#FFFFFF" stroke="none"/>
<!-- Blocks ONLY where they don't overlap with white cross -->
<rect x="130" y="140" width="40" height="10" class="fp8-block"/>
<rect x="170" y="140" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="140" width="40" height="10" class="fp8-block"/>
<rect x="290" y="140" width="40" height="10" class="fp8-block"/>
<rect x="330" y="140" width="40" height="10" class="fp8-block-alt"/>
<rect x="130" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="170" y="150" width="40" height="10" class="fp8-block"/>
<rect x="210" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="290" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="330" y="150" width="40" height="10" class="fp8-block"/>
<rect x="130" y="160" width="40" height="10" class="fp8-block"/>
<rect x="170" y="160" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="160" width="40" height="10" class="fp8-block"/>
<rect x="290" y="160" width="40" height="10" class="fp8-block"/>
<rect x="330" y="160" width="40" height="10" class="fp8-block-alt"/>
<rect x="130" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="170" y="170" width="40" height="10" class="fp8-block"/>
<rect x="210" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="290" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="330" y="170" width="40" height="10" class="fp8-block"/>
<rect x="130" y="180" width="40" height="10" class="fp8-block"/>
<rect x="170" y="180" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="180" width="40" height="10" class="fp8-block"/>
<rect x="290" y="180" width="40" height="10" class="fp8-block"/>
<rect x="330" y="180" width="40" height="10" class="fp8-block-alt"/>
<!-- Three dots in VERTICAL white bar -->
<text x="270" y="167.5" class="dots-text"></text>
<text x="270" y="242.5" class="dots-text"></text>
<!-- Three dots in HORIZONTAL white bar -->
<text x="190" y="205" class="dots-text"></text>
<text x="330" y="205" class="dots-text"></text>
<!-- ONE diagonal dot at intersection -->
<text x="270" y="205" class="dots-text" transform="rotate(45 270 205)"></text>
<!-- Bottom rows (y >= 220 after horizontal white bar) -->
<rect x="130" y="220" width="40" height="10" class="fp8-block"/>
<rect x="170" y="220" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="220" width="40" height="10" class="fp8-block"/>
<rect x="290" y="220" width="40" height="10" class="fp8-block"/>
<rect x="330" y="220" width="40" height="10" class="fp8-block-alt"/>
<rect x="130" y="230" width="40" height="10" class="fp8-block-alt"/>
<rect x="170" y="230" width="40" height="10" class="fp8-block"/>
<rect x="210" y="230" width="40" height="10" class="fp8-block-alt"/>
<rect x="290" y="230" width="40" height="10" class="fp8-block-alt"/>
<rect x="330" y="230" width="40" height="10" class="fp8-block"/>
<rect x="130" y="240" width="40" height="10" class="fp8-block"/>
<rect x="170" y="240" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="240" width="40" height="10" class="fp8-block"/>
<rect x="290" y="240" width="40" height="10" class="fp8-block"/>
<rect x="330" y="240" width="40" height="10" class="fp8-block-alt"/>
<rect x="130" y="250" width="40" height="10" class="fp8-block-alt"/>
<rect x="170" y="250" width="40" height="10" class="fp8-block"/>
<rect x="210" y="250" width="40" height="10" class="fp8-block-alt"/>
<rect x="290" y="250" width="40" height="10" class="fp8-block-alt"/>
<rect x="330" y="250" width="40" height="10" class="fp8-block"/>
<!-- Main outline -->
<rect x="130.0" y="140.0" width="240" height="120" fill="none" stroke="#444" stroke-width="2"/>
</g>
<!-- Scaling factors tensor - 3+2 columns of 10px squares -->
<g id="scale-factors">
<!-- Orange background -->
<rect x="215" y="285" width="70" height="120" fill="#FFA500"/>
<!-- White backgrounds for dots areas - cross pattern -->
<rect x="245" y="285" width="20" height="120" fill="#FFFFFF" stroke="none"/>
<rect x="215" y="335" width="70" height="30" fill="#FFFFFF" stroke="none"/>
<!-- Grid lines showing 10x10 squares (3 left + 2 right columns) -->
<!-- Vertical lines every 10px (skipping white space) -->
<!-- Left 3 columns -->
<line x1="225" y1="285" x2="225" y2="335" class="grid-line" stroke-width="1"/>
<line x1="235" y1="285" x2="235" y2="335" class="grid-line" stroke-width="1"/>
<line x1="245" y1="285" x2="245" y2="335" class="grid-line" stroke-width="1"/>
<!-- Right 2 columns -->
<line x1="265" y1="285" x2="265" y2="335" class="grid-line" stroke-width="1"/>
<line x1="275" y1="285" x2="275" y2="335" class="grid-line" stroke-width="1"/>
<line x1="285" y1="285" x2="285" y2="335" class="grid-line" stroke-width="1"/>
<!-- Bottom sections -->
<line x1="225" y1="365" x2="225" y2="405" class="grid-line" stroke-width="1"/>
<line x1="235" y1="365" x2="235" y2="405" class="grid-line" stroke-width="1"/>
<line x1="245" y1="365" x2="245" y2="405" class="grid-line" stroke-width="1"/>
<line x1="265" y1="365" x2="265" y2="405" class="grid-line" stroke-width="1"/>
<line x1="275" y1="365" x2="275" y2="405" class="grid-line" stroke-width="1"/>
<line x1="285" y1="365" x2="285" y2="405" class="grid-line" stroke-width="1"/>
<!-- Horizontal lines every 10px -->
<line x1="215" y1="295" x2="245" y2="295" class="grid-line" stroke-width="1"/>
<line x1="265" y1="295" x2="285" y2="295" class="grid-line" stroke-width="1"/>
<line x1="215" y1="305" x2="245" y2="305" class="grid-line" stroke-width="1"/>
<line x1="265" y1="305" x2="285" y2="305" class="grid-line" stroke-width="1"/>
<line x1="215" y1="315" x2="245" y2="315" class="grid-line" stroke-width="1"/>
<line x1="265" y1="315" x2="285" y2="315" class="grid-line" stroke-width="1"/>
<line x1="215" y1="325" x2="245" y2="325" class="grid-line" stroke-width="1"/>
<line x1="265" y1="325" x2="285" y2="325" class="grid-line" stroke-width="1"/>
<!-- Top bottom boundaries -->
<line x1="215" y1="335" x2="245" y2="335" class="grid-line" stroke-width="1"/>
<line x1="265" y1="335" x2="285" y2="335" class="grid-line" stroke-width="1"/>
<line x1="215" y1="365" x2="245" y2="365" class="grid-line" stroke-width="1"/>
<line x1="265" y1="365" x2="285" y2="365" class="grid-line" stroke-width="1"/>
<line x1="215" y1="375" x2="245" y2="375" class="grid-line" stroke-width="1"/>
<line x1="265" y1="375" x2="285" y2="375" class="grid-line" stroke-width="1"/>
<line x1="215" y1="385" x2="245" y2="385" class="grid-line" stroke-width="1"/>
<line x1="265" y1="385" x2="285" y2="385" class="grid-line" stroke-width="1"/>
<line x1="215" y1="395" x2="245" y2="395" class="grid-line" stroke-width="1"/>
<line x1="265" y1="395" x2="285" y2="395" class="grid-line" stroke-width="1"/>
<!-- Bottom boundaries -->
<line x1="215" y1="405" x2="245" y2="405" class="grid-line" stroke-width="1"/>
<line x1="265" y1="405" x2="285" y2="405" class="grid-line" stroke-width="1"/>
<!-- Main outline -->
<rect x="215" y="285" width="70" height="120" fill="none" stroke="#444" stroke-width="2"/>
<!-- Three dots -->
<text x="255" y="312.5" class="dots-text" style="font-size: 14px;"></text>
<text x="255" y="387.5" class="dots-text" style="font-size: 14px;"></text>
<text x="230" y="350" class="dots-text" style="font-size: 14px;"></text>
<text x="275" y="350" class="dots-text" style="font-size: 14px;"></text>
<text x="255" y="350" class="dots-text" style="font-size: 14px;" transform="rotate(45 255 350)"></text>
</g>
<text x="250" y="430" class="small-text" text-anchor="middle">E4M3 scaling factors (one per 16 elements)</text>
<!-- Global Scaling Factor -->
<g id="global-scale" transform="translate(350, 320)">
<rect x="10" y="10" width="20" height="20" class="global-scale"/>
<text x="20" y="60" class="small-text" text-anchor="middle">Global Scale (FP32)</text>
<text x="20" y="75" class="small-text" text-anchor="middle">(one per tensor)</text>
<text x="-20" y="25" class="dots-text" style="font-size: 20px;">+</text>
</g>
</g>
</svg>
\ No newline at end of file
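
The storage cost implied by the hierarchical scaling diagram above can be worked out with a few lines of arithmetic. The sketch below assumes the element count is a multiple of the block size and ignores any padding or alignment of the scaling-factor tensor; the function name is hypothetical.

```python
def nvfp4_bits_per_element(num_elements, block=16):
    """Effective storage cost in bits per element for one NVFP4 tensor."""
    data_bits = 4 * num_elements                    # 4-bit E2M1 payload
    block_scale_bits = 8 * (num_elements // block)  # one E4M3 scale per 16 elements
    global_scale_bits = 32                          # one FP32 scale per tensor
    return (data_bits + block_scale_bits + global_scale_bits) / num_elements


bits = nvfp4_bits_per_element(1024 * 1024)
```

For large tensors the global scale is negligible and the cost approaches 4.5 bits per element: 4 bits of data plus 8 bits of block scale shared by 16 elements.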
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 650 520" width="100%" style="max-width: 650px;">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.title { font: bold 18px sans-serif; fill: #333; text-anchor: middle; }
.small-text { font: 12px sans-serif; fill: #555; text-anchor: middle; }
.dots-text { font: bold 24px sans-serif; fill: #333; text-anchor: middle; }
.fp8-block { fill: #87CEEB; stroke: #555; stroke-width: 1.5; }
.fp8-block-alt { fill: #5F9FCC; stroke: #555; stroke-width: 1.5; }
.scale-block { fill: #FFA500; stroke: #555; stroke-width: 1.5; }
.global-scale { fill: #FF6B00; stroke: #444; stroke-width: 2; }
</style>
</defs>
<!-- ROWWISE SECTION -->
<text x="325" y="30" class="title">Rowwise (1×16 blocks)</text>
<!-- Rowwise Data [A, B] - 240x120 -->
<g id="rowwise-tensor">
<text x="160" y="55" class="small-text">Data [A, B]</text>
<!-- Top-left -->
<rect x="40" y="70" width="40" height="10" class="fp8-block"/>
<rect x="80" y="70" width="40" height="10" class="fp8-block-alt"/>
<rect x="40" y="80" width="40" height="10" class="fp8-block-alt"/>
<rect x="80" y="80" width="40" height="10" class="fp8-block"/>
<rect x="40" y="90" width="40" height="10" class="fp8-block"/>
<rect x="80" y="90" width="40" height="10" class="fp8-block-alt"/>
<!-- Top-right -->
<rect x="200" y="70" width="40" height="10" class="fp8-block-alt"/>
<rect x="240" y="70" width="40" height="10" class="fp8-block"/>
<rect x="200" y="80" width="40" height="10" class="fp8-block"/>
<rect x="240" y="80" width="40" height="10" class="fp8-block-alt"/>
<rect x="200" y="90" width="40" height="10" class="fp8-block-alt"/>
<rect x="240" y="90" width="40" height="10" class="fp8-block"/>
<!-- Bottom-left -->
<rect x="40" y="160" width="40" height="10" class="fp8-block-alt"/>
<rect x="80" y="160" width="40" height="10" class="fp8-block"/>
<rect x="40" y="170" width="40" height="10" class="fp8-block"/>
<rect x="80" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="40" y="180" width="40" height="10" class="fp8-block-alt"/>
<rect x="80" y="180" width="40" height="10" class="fp8-block"/>
<!-- Bottom-right -->
<rect x="200" y="160" width="40" height="10" class="fp8-block"/>
<rect x="240" y="160" width="40" height="10" class="fp8-block-alt"/>
<rect x="200" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="240" y="170" width="40" height="10" class="fp8-block"/>
<rect x="200" y="180" width="40" height="10" class="fp8-block"/>
<rect x="240" y="180" width="40" height="10" class="fp8-block-alt"/>
<text x="160" y="87" class="dots-text"></text>
<text x="160" y="177" class="dots-text"></text>
<text x="80" y="135" class="dots-text"></text>
<text x="240" y="135" class="dots-text"></text>
<rect x="40" y="70" width="240" height="120" fill="none" stroke="#444" stroke-width="2"/>
</g>
<!-- Rowwise s_block [A, B/16] - 90x120 -->
<g id="rowwise-scales">
<text x="425" y="55" class="small-text">s_block [A, B/16]</text>
<!-- Top-left -->
<rect x="380" y="70" width="10" height="10" class="scale-block"/>
<rect x="390" y="70" width="10" height="10" class="scale-block"/>
<rect x="380" y="80" width="10" height="10" class="scale-block"/>
<rect x="390" y="80" width="10" height="10" class="scale-block"/>
<rect x="380" y="90" width="10" height="10" class="scale-block"/>
<rect x="390" y="90" width="10" height="10" class="scale-block"/>
<!-- Top-right -->
<rect x="450" y="70" width="10" height="10" class="scale-block"/>
<rect x="460" y="70" width="10" height="10" class="scale-block"/>
<rect x="450" y="80" width="10" height="10" class="scale-block"/>
<rect x="460" y="80" width="10" height="10" class="scale-block"/>
<rect x="450" y="90" width="10" height="10" class="scale-block"/>
<rect x="460" y="90" width="10" height="10" class="scale-block"/>
<!-- Bottom-left -->
<rect x="380" y="160" width="10" height="10" class="scale-block"/>
<rect x="390" y="160" width="10" height="10" class="scale-block"/>
<rect x="380" y="170" width="10" height="10" class="scale-block"/>
<rect x="390" y="170" width="10" height="10" class="scale-block"/>
<rect x="380" y="180" width="10" height="10" class="scale-block"/>
<rect x="390" y="180" width="10" height="10" class="scale-block"/>
<!-- Bottom-right -->
<rect x="450" y="160" width="10" height="10" class="scale-block"/>
<rect x="460" y="160" width="10" height="10" class="scale-block"/>
<rect x="450" y="170" width="10" height="10" class="scale-block"/>
<rect x="460" y="170" width="10" height="10" class="scale-block"/>
<rect x="450" y="180" width="10" height="10" class="scale-block"/>
<rect x="460" y="180" width="10" height="10" class="scale-block"/>
<text x="425" y="87" class="dots-text"></text>
<text x="425" y="177" class="dots-text"></text>
<text x="390" y="135" class="dots-text"></text>
<text x="460" y="135" class="dots-text"></text>
<rect x="380" y="70" width="90" height="120" fill="none" stroke="#444" stroke-width="2"/>
</g>
<!-- Rowwise s_global -->
<g id="rowwise-global">
<text x="545" y="55" class="small-text">s_global</text>
<rect x="535" y="120" width="20" height="20" class="global-scale"/>
</g>
<!-- COLUMNWISE SECTION -->
<text x="325" y="230" class="title">Columnwise (16×1 blocks) — transposed storage</text>
<!-- Columnwise Data [B, A] - 120x240 (transposed) -->
<g id="colwise-tensor">
<text x="100" y="255" class="small-text">Data [B, A]</text>
<!-- Top-left blocks -->
<rect x="40" y="270" width="40" height="10" class="fp8-block"/>
<rect x="40" y="280" width="40" height="10" class="fp8-block-alt"/>
<rect x="40" y="290" width="40" height="10" class="fp8-block"/>
<rect x="40" y="300" width="40" height="10" class="fp8-block-alt"/>
<rect x="40" y="310" width="40" height="10" class="fp8-block"/>
<rect x="40" y="320" width="40" height="10" class="fp8-block-alt"/>
<!-- Top-right blocks -->
<rect x="120" y="270" width="40" height="10" class="fp8-block-alt"/>
<rect x="120" y="280" width="40" height="10" class="fp8-block"/>
<rect x="120" y="290" width="40" height="10" class="fp8-block-alt"/>
<rect x="120" y="300" width="40" height="10" class="fp8-block"/>
<rect x="120" y="310" width="40" height="10" class="fp8-block-alt"/>
<rect x="120" y="320" width="40" height="10" class="fp8-block"/>
<!-- Bottom-left blocks (stick to bottom: box ends at y=510) -->
<rect x="40" y="450" width="40" height="10" class="fp8-block-alt"/>
<rect x="40" y="460" width="40" height="10" class="fp8-block"/>
<rect x="40" y="470" width="40" height="10" class="fp8-block-alt"/>
<rect x="40" y="480" width="40" height="10" class="fp8-block"/>
<rect x="40" y="490" width="40" height="10" class="fp8-block-alt"/>
<rect x="40" y="500" width="40" height="10" class="fp8-block"/>
<!-- Bottom-right blocks -->
<rect x="120" y="450" width="40" height="10" class="fp8-block"/>
<rect x="120" y="460" width="40" height="10" class="fp8-block-alt"/>
<rect x="120" y="470" width="40" height="10" class="fp8-block"/>
<rect x="120" y="480" width="40" height="10" class="fp8-block-alt"/>
<rect x="120" y="490" width="40" height="10" class="fp8-block"/>
<rect x="120" y="500" width="40" height="10" class="fp8-block-alt"/>
<text x="100" y="307" class="dots-text"></text>
<text x="100" y="487" class="dots-text"></text>
<text x="60" y="395" class="dots-text"></text>
<text x="140" y="395" class="dots-text"></text>
<rect x="40" y="270" width="120" height="240" fill="none" stroke="#444" stroke-width="2"/>
</g>
<!-- Columnwise s_block [B, A/16] - 45x240, aligned with rowwise scales at x=380 -->
<g id="colwise-scales">
<text x="402" y="255" class="small-text">s_block [B, A/16]</text>
<!-- Top-left -->
<rect x="380" y="270" width="10" height="10" class="scale-block"/>
<rect x="380" y="280" width="10" height="10" class="scale-block"/>
<rect x="380" y="290" width="10" height="10" class="scale-block"/>
<rect x="380" y="300" width="10" height="10" class="scale-block"/>
<rect x="380" y="310" width="10" height="10" class="scale-block"/>
<rect x="380" y="320" width="10" height="10" class="scale-block"/>
<!-- Top-right -->
<rect x="415" y="270" width="10" height="10" class="scale-block"/>
<rect x="415" y="280" width="10" height="10" class="scale-block"/>
<rect x="415" y="290" width="10" height="10" class="scale-block"/>
<rect x="415" y="300" width="10" height="10" class="scale-block"/>
<rect x="415" y="310" width="10" height="10" class="scale-block"/>
<rect x="415" y="320" width="10" height="10" class="scale-block"/>
<!-- Bottom-left -->
<rect x="380" y="450" width="10" height="10" class="scale-block"/>
<rect x="380" y="460" width="10" height="10" class="scale-block"/>
<rect x="380" y="470" width="10" height="10" class="scale-block"/>
<rect x="380" y="480" width="10" height="10" class="scale-block"/>
<rect x="380" y="490" width="10" height="10" class="scale-block"/>
<rect x="380" y="500" width="10" height="10" class="scale-block"/>
<!-- Bottom-right -->
<rect x="415" y="450" width="10" height="10" class="scale-block"/>
<rect x="415" y="460" width="10" height="10" class="scale-block"/>
<rect x="415" y="470" width="10" height="10" class="scale-block"/>
<rect x="415" y="480" width="10" height="10" class="scale-block"/>
<rect x="415" y="490" width="10" height="10" class="scale-block"/>
<rect x="415" y="500" width="10" height="10" class="scale-block"/>
<text x="402" y="307" class="dots-text"></text>
<text x="402" y="487" class="dots-text"></text>
<text x="387" y="395" class="dots-text"></text>
<text x="420" y="395" class="dots-text"></text>
<rect x="380" y="270" width="45" height="240" fill="none" stroke="#444" stroke-width="2"/>
</g>
<!-- Columnwise s_global, aligned with rowwise at x=535 -->
<g id="colwise-global">
<text x="545" y="255" class="small-text">s_global</text>
<rect x="535" y="380" width="20" height="20" class="global-scale"/>
</g>
</svg>
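
The shapes annotated in the rowwise and columnwise diagram above can be summarized in a small helper. This is a sketch of the logical shapes only, assuming both dimensions are multiples of 16 and ignoring any padding of the scaling factors; `nvfp4_layouts` is a hypothetical name.

```python
def nvfp4_layouts(a, b, block=16):
    """Shapes of data and block scales for a logical [a, b] tensor."""
    assert a % block == 0 and b % block == 0, "sketch assumes divisible dims"
    return {
        # Rowwise: 1x16 blocks along each row of the [a, b] data.
        "rowwise": {"data": (a, b), "s_block": (a, b // block)},
        # Columnwise: the transpose is stored, so 16x1 blocks of the original
        # tensor become 1x16 blocks along the rows of the [b, a] buffer.
        "columnwise": {"data": (b, a), "s_block": (b, a // block)},
    }


layouts = nvfp4_layouts(128, 256)
```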
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 900 220">
<defs>
<style>
.sign-bit { fill: #9db4d0; stroke: #333; stroke-width: 1; }
.exponent-bit { fill: #d9a066; stroke: #333; stroke-width: 1; }
.mantissa-bit { fill: #a8d99c; stroke: #333; stroke-width: 1; }
.bit-text { fill: #000; text-anchor: middle; dominant-baseline: middle; font-size: 16px; }
.header-text { fill: #555; font-weight: normal; text-anchor: middle; font-size: 18px; }
.value-text { fill: #333; font-size: 18px; }
.format-label { fill: #333; font-weight: bold; text-anchor: middle; dominant-baseline: middle; font-size: 20px; }
</style>
</defs>
<!-- FP8 E4M3 Format (8 bits: 1 + 4 + 3) -->
<text x="60" y="60" class="format-label">FP8 E4M3</text>
<!-- Sign bit (1) -->
<rect x="140" y="45" width="18" height="30" class="sign-bit"/>
<text x="149" y="60" class="bit-text">0</text>
<!-- Exponent bits (4) -->
<rect x="163" y="45" width="18" height="30" class="exponent-bit"/>
<text x="172" y="60" class="bit-text">1</text>
<rect x="186" y="45" width="18" height="30" class="exponent-bit"/>
<text x="195" y="60" class="bit-text">0</text>
<rect x="209" y="45" width="18" height="30" class="exponent-bit"/>
<text x="218" y="60" class="bit-text">0</text>
<rect x="232" y="45" width="18" height="30" class="exponent-bit"/>
<text x="241" y="60" class="bit-text">0</text>
<!-- Mantissa bits (3) -->
<rect x="255" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="264" y="60" class="bit-text">1</text>
<rect x="278" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="287" y="60" class="bit-text">1</text>
<rect x="301" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="310" y="60" class="bit-text">1</text>
<text x="355" y="60" class="value-text">(1 sign, 4 exp, 3 mantissa)</text>
<!-- FP8 E5M2 Format (8 bits: 1 + 5 + 2) -->
<text x="60" y="120" class="format-label">FP8 E5M2</text>
<!-- Sign bit (1) -->
<rect x="140" y="105" width="18" height="30" class="sign-bit"/>
<text x="149" y="120" class="bit-text">0</text>
<!-- Exponent bits (5) -->
<rect x="163" y="105" width="18" height="30" class="exponent-bit"/>
<text x="172" y="120" class="bit-text">1</text>
<rect x="186" y="105" width="18" height="30" class="exponent-bit"/>
<text x="195" y="120" class="bit-text">0</text>
<rect x="209" y="105" width="18" height="30" class="exponent-bit"/>
<text x="218" y="120" class="bit-text">0</text>
<rect x="232" y="105" width="18" height="30" class="exponent-bit"/>
<text x="241" y="120" class="bit-text">0</text>
<rect x="255" y="105" width="18" height="30" class="exponent-bit"/>
<text x="264" y="120" class="bit-text">0</text>
<!-- Mantissa bits (2) -->
<rect x="278" y="105" width="18" height="30" class="mantissa-bit"/>
<text x="287" y="120" class="bit-text">1</text>
<rect x="301" y="105" width="18" height="30" class="mantissa-bit"/>
<text x="310" y="120" class="bit-text">1</text>
<text x="355" y="120" class="value-text">(1 sign, 5 exp, 2 mantissa)</text>
<!-- NVFP4 E2M1 Format (4 bits: 1 + 2 + 1) -->
<text x="60" y="180" class="format-label">NVFP4</text>
<!-- Sign bit (1) -->
<rect x="140" y="165" width="18" height="30" class="sign-bit"/>
<text x="149" y="180" class="bit-text">0</text>
<!-- Exponent bits (2) -->
<rect x="163" y="165" width="18" height="30" class="exponent-bit"/>
<text x="172" y="180" class="bit-text">1</text>
<rect x="186" y="165" width="18" height="30" class="exponent-bit"/>
<text x="195" y="180" class="bit-text">0</text>
<!-- Mantissa bits (1) -->
<rect x="209" y="165" width="18" height="30" class="mantissa-bit"/>
<text x="218" y="180" class="bit-text">1</text>
<text x="355" y="180" class="value-text">(1 sign, 2 exp, 1 mantissa)</text>
</svg>
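
The bit patterns drawn in the format diagram above can be decoded by hand or with a short helper. The sketch below handles only normal values (non-zero exponent field) and does not treat special encodings such as NaN or infinity; the exponent biases (7 for E4M3, 15 for E5M2, 1 for E2M1) are the standard ones for these formats, and `decode_minifloat` is a hypothetical name.

```python
def decode_minifloat(bits, exp_bits, man_bits, bias):
    """Decode a normal value from a sign/exponent/mantissa bit string."""
    sign = -1.0 if bits[0] == "1" else 1.0
    exponent = int(bits[1:1 + exp_bits], 2)
    mantissa = int(bits[1 + exp_bits:], 2)
    assert exponent != 0, "subnormals are not handled in this sketch"
    return sign * 2.0 ** (exponent - bias) * (1 + mantissa / 2 ** man_bits)


# Bit patterns exactly as drawn in the diagram.
e4m3 = decode_minifloat("01000111", exp_bits=4, man_bits=3, bias=7)   # 0 | 1000 | 111
e5m2 = decode_minifloat("01000011", exp_bits=5, man_bits=2, bias=15)  # 0 | 10000 | 11
e2m1 = decode_minifloat("0101", exp_bits=2, man_bits=1, bias=1)       # 0 | 10 | 1
```

The three patterns decode to 3.75, 3.5, and 3.0 respectively, which shows how fewer mantissa bits coarsen the representable values around the same magnitude.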
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 900 340">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
/* Diagram-specific styles */
.input { fill: #e3f2fd; stroke: #1976d2; stroke-width: 2; }
.grad { fill: #fce4ec; stroke: #d81b60; stroke-width: 2; }
.rht { fill: #fff3e0; stroke: #f57c00; stroke-width: 2; }
.output { fill: #e8f5e9; stroke: #388e3c; stroke-width: 2; }
.divider { stroke: #bdbdbd; stroke-width: 2; stroke-dasharray: 6,4; }
/* Arrow override */
.arrow { marker-end: url(#arrowhead); }
</style>
<!-- Arrow marker -->
<marker id="arrowhead" markerWidth="8" markerHeight="8" refX="7" refY="3" orient="auto">
<polygon points="0 0, 8 3, 0 6" fill="#616161" />
</marker>
</defs>
<!-- Title -->
<text x="450" y="30" class="title">Random Hadamard Transform for WGRAD GEMM</text>
<!-- Divider -->
<line x1="450" y1="50" x2="450" y2="320" class="divider"/>
<!-- ═══════════════ LEFT SIDE: Without RHT ═══════════════ -->
<g id="without-rht">
<text x="225" y="70" class="section-title">Without RHT</text>
<!-- Top row: Activations → Quantize → GEMM -->
<!-- Activations -->
<rect x="40" y="100" width="90" height="45" rx="6" class="input"/>
<text x="85" y="127" class="text">Activations</text>
<!-- Arrow -->
<path d="M 130 122 L 175 122" class="arrow"/>
<!-- Quantize -->
<rect x="175" y="100" width="80" height="45" rx="6" class="quantize"/>
<text x="215" y="127" class="text">Quantize</text>
<!-- Arrow to GEMM -->
<path d="M 255 122 L 300 122" class="arrow"/>
<!-- GEMM -->
<rect x="300" y="80" width="85" height="90" rx="6" class="gemm"/>
<text x="342" y="118" class="text">WGRAD</text>
<text x="342" y="138" class="text">GEMM</text>
<!-- Bottom row: Output Grad → Quantize → GEMM -->
<!-- Output Grad -->
<rect x="40" y="170" width="90" height="45" rx="6" class="grad"/>
<text x="85" y="197" class="text">Output Grad</text>
<!-- Arrow -->
<path d="M 130 192 L 175 192" class="arrow"/>
<!-- Quantize 2 -->
<rect x="175" y="170" width="80" height="45" rx="6" class="quantize"/>
<text x="215" y="197" class="text">Quantize</text>
<!-- Arrow to GEMM (diagonal) -->
<path d="M 255 192 L 300 155" class="arrow"/>
<!-- Arrow from GEMM to output -->
<path d="M 342 170 L 342 245" class="arrow"/>
<!-- Weight Grad -->
<rect x="300" y="245" width="85" height="45" rx="6" class="output"/>
<text x="342" y="272" class="text">Weight Grad</text>
</g>
<!-- ═══════════════ RIGHT SIDE: With RHT ═══════════════ -->
<g id="with-rht">
<text x="675" y="70" class="section-title">With RHT</text>
<!-- Top row: Activations → RHT → Quantize → GEMM -->
<!-- Activations -->
<rect x="460" y="100" width="90" height="45" rx="6" class="input"/>
<text x="505" y="127" class="text">Activations</text>
<!-- Arrow -->
<path d="M 550 122 L 575 122" class="arrow"/>
<!-- RHT -->
<rect x="575" y="100" width="50" height="45" rx="6" class="rht"/>
<text x="600" y="127" class="text">RHT</text>
<!-- Arrow -->
<path d="M 625 122 L 650 122" class="arrow"/>
<!-- Quantize -->
<rect x="650" y="100" width="80" height="45" rx="6" class="quantize"/>
<text x="690" y="127" class="text">Quantize</text>
<!-- Arrow to GEMM -->
<path d="M 730 122 L 775 122" class="arrow"/>
<!-- GEMM -->
<rect x="775" y="80" width="85" height="90" rx="6" class="gemm"/>
<text x="817" y="118" class="text">WGRAD</text>
<text x="817" y="138" class="text">GEMM</text>
<!-- Bottom row: Output Grad → RHT → Quantize → GEMM -->
<!-- Output Grad -->
<rect x="460" y="170" width="90" height="45" rx="6" class="grad"/>
<text x="505" y="197" class="text">Output Grad</text>
<!-- Arrow -->
<path d="M 550 192 L 575 192" class="arrow"/>
<!-- RHT 2 -->
<rect x="575" y="170" width="50" height="45" rx="6" class="rht"/>
<text x="600" y="197" class="text">RHT</text>
<!-- Arrow -->
<path d="M 625 192 L 650 192" class="arrow"/>
<!-- Quantize 2 -->
<rect x="650" y="170" width="80" height="45" rx="6" class="quantize"/>
<text x="690" y="197" class="text">Quantize</text>
<!-- Arrow to GEMM (diagonal) -->
<path d="M 730 192 L 775 155" class="arrow"/>
<!-- Arrow from GEMM to output -->
<path d="M 817 170 L 817 245" class="arrow"/>
<!-- Weight Grad -->
<rect x="775" y="245" width="85" height="45" rx="6" class="output"/>
<text x="817" y="272" class="text">Weight Grad</text>
</g>
</svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1000 400">
<defs>
<style>
.axis { stroke: #333; stroke-width: 3; stroke-linecap: round; }
.tick { stroke: #333; stroke-width: 2; }
.tick-label { font-family: sans-serif; font-size: 22px; text-anchor: middle; fill: #333; font-weight: bold; }
.title { font-family: sans-serif; font-size: 24px; font-weight: bold; text-anchor: middle; fill: #333; }
.sub-label { font-family: sans-serif; font-size: 16px; text-anchor: middle; fill: #555; }
.value-point { fill: #e74c3c; stroke: #333; stroke-width: 2; }
.value-label { font-family: sans-serif; font-size: 22px; text-anchor: middle; fill: #e74c3c; font-weight: bold; }
.divider { stroke: #ccc; stroke-width: 2; stroke-dasharray: 5,5; }
.bar-bg { fill: #eee; stroke: #555; stroke-width: 1.5; }
.bar-fill-blue { fill: #87CEEB; stroke: #555; stroke-width: 1.5; }
.bar-fill-green { fill: #a8d99c; stroke: #555; stroke-width: 1.5; }
.percentage-text { font-family: sans-serif; font-weight: bold; font-size: 20px; }
</style>
</defs>
<!-- Divider -->
<line x1="500" y1="20" x2="500" y2="380" class="divider" />
<!-- LEFT SIDE: Deterministic Rounding -->
<g transform="translate(0,0)">
<text x="250" y="40" class="title">Round to Nearest</text>
<!-- Axis -->
<line x1="50" y1="150" x2="450" y2="150" class="axis" />
<!-- v1 -->
<line x1="100" y1="140" x2="100" y2="160" class="tick" />
<text x="100" y="185" class="tick-label">v₁</text>
<!-- v2 -->
<line x1="400" y1="140" x2="400" y2="160" class="tick" />
<text x="400" y="185" class="tick-label">v₂</text>
<!-- x (at 40% distance) -->
<circle cx="220" cy="150" r="8" class="value-point" />
<text x="220" y="120" class="value-label">x</text>
<!-- Visuals for deterministic rounding: Bars -->
<g transform="translate(50, 230)">
<!-- Bar for v1 (100%) -->
<text x="50" y="-10" class="sub-label">Round to v₁</text>
<rect x="10" y="0" width="80" height="100" rx="4" class="bar-bg" />
<rect x="10" y="0" width="80" height="100" rx="4" class="bar-fill-green" />
<text x="50" y="56" fill="#000" text-anchor="middle" class="percentage-text">100%</text>
<!-- Bar for v2 (0%) -->
<text x="350" y="-10" class="sub-label">Round to v₂</text>
<rect x="310" y="0" width="80" height="100" rx="4" class="bar-bg" />
<!-- 0% filled, so just bg visible -->
<text x="350" y="56" fill="#666" text-anchor="middle" class="percentage-text">0%</text>
</g>
</g>
<!-- RIGHT SIDE: Stochastic -->
<g transform="translate(500,0)">
<text x="250" y="40" class="title">Stochastic Rounding</text>
<!-- Axis -->
<line x1="50" y1="150" x2="450" y2="150" class="axis" />
<!-- v1 -->
<line x1="100" y1="140" x2="100" y2="160" class="tick" />
<text x="100" y="185" class="tick-label">v₁</text>
<!-- v2 -->
<line x1="400" y1="140" x2="400" y2="160" class="tick" />
<text x="400" y="185" class="tick-label">v₂</text>
<!-- x (at 40% distance) -->
<circle cx="220" cy="150" r="8" class="value-point" />
<text x="220" y="120" class="value-label">x</text>
<!-- Visuals for Stochastic: Bars -->
<g transform="translate(50, 230)">
<!-- Bar for v1 -->
<text x="50" y="-10" class="sub-label">Round to v₁</text>
<rect x="10" y="0" width="80" height="100" rx="4" class="bar-bg" />
<rect x="10" y="40" width="80" height="60" rx="4" class="bar-fill-blue" />
<text x="50" y="80" fill="#000" text-anchor="middle" class="percentage-text">60%</text>
<!-- Bar for v2 -->
<text x="350" y="-10" class="sub-label">Round to v₂</text>
<rect x="310" y="0" width="80" height="100" rx="4" class="bar-bg" />
<rect x="310" y="60" width="80" height="40" rx="4" class="bar-fill-blue" />
<text x="350" y="90" fill="#000" text-anchor="middle" class="percentage-text">40%</text>
</g>
</g>
</svg>
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Check for Blackwell or newer GPU
from transformer_engine.jax.quantize import get_device_compute_capability
assert (
get_device_compute_capability() >= 100
), f"NVFP4 requires SM100 (Blackwell) or later, got SM{get_device_compute_capability()}"
# START_NVFP4_EXAMPLE
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
from transformer_engine.jax.flax import DenseGeneral
from transformer_engine.common.recipe import NVFP4BlockScaling
# Define NVFP4 recipe
# 2D weight quantization and RHT are enabled by default
recipe = NVFP4BlockScaling()
# To disable features, use:
# recipe = NVFP4BlockScaling(disable_rht=True, disable_2d_quantization=True)
with te.autocast(enabled=True, recipe=recipe):
# Initialize layer and data
layer = DenseGeneral(features=1024)
key, sr_key = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16)
# NVFP4 requires sr_rng for stochastic rounding
rngs = {"sr_rng": sr_key}
var_collect = layer.init({"params": key, "sr_rng": sr_key}, x)
# Forward and backward pass
def loss_fn(var_collect):
output = layer.apply(var_collect, x, rngs=rngs)
return output.sum()
loss, grads = jax.value_and_grad(loss_fn)(var_collect)
# END_NVFP4_EXAMPLE
..
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
NVFP4
===================================
NVFP4 is the first 4-bit recipe introduced in Transformer Engine;
please refer to the `NVFP4 paper <https://arxiv.org/abs/2509.25149>`__ for more details.
It is more complex than the previous recipes: apart from the new data format,
it introduces several features that improve training stability.
Data Format
----------------------
The NVFP4 datatype consists of 1 sign bit, 2 exponent bits, and 1 mantissa bit (E2M1).
It can represent values in the range [-6, 6]; the representable magnitudes are 0, 0.5, 1, 1.5, 2, 3, 4, and 6.
NVFP4 uses a hierarchical block scaling approach where multiple scaling factors are combined to recover the high precision value.
.. raw:: html
:file: img/nvfp4_vs_fp8.svg
*Figure 1. Bit layout comparison between standard FP8 formats (E4M3 and E5M2) and NVFP4 (E2M1).*
The representation of an NVFP4 tensor element ``x`` is given by:
.. code-block:: python
x = x_e2m1 * s_block * s_global
where
* ``x_e2m1`` is the 4-bit value,
* ``s_block`` is a local **FP8 E4M3** scaling factor shared by a block of 16 consecutive elements,
* ``s_global`` is a global **FP32** scaling factor applied to the entire tensor.
**Scaling Factor Computation**
The scaling factors are computed as follows:
1. Global scaling factor (``s_global``):
.. code-block:: python
s_global = global_amax / (fp8_max * fp4_max)
# where:
# - global_amax: maximum absolute value across the entire tensor
# - fp8_max: maximum representable value in FP8 E4M3 (448.0)
# - fp4_max: maximum representable value in NVFP4 E2M1 (6.0)
2. Block scaling factor (``s_block``):
.. code-block:: python
s_block = (block_amax / fp4_max) / s_global
# where:
# - block_amax: maximum absolute value within the block
# - fp4_max: maximum representable value in NVFP4 E2M1 (6.0)
# - s_block is stored in FP8 E4M3 format
.. raw:: html
:file: img/nvfp4_hierarchical_scaling.svg
*Figure 2. NVFP4 hierarchical scaling structure showing the combination of block-level and global scaling factors.*
This hierarchical structure uses fine-grained block scaling to cover the tensor's dynamic range,
while the FP4 values only need to cover the dynamic range within a block. The global scaling factor
aligns values to the representable range of the E4M3 × E2M1 combination.
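The two-level scaling above can be sketched in NumPy. This is an illustration of the formulas only, not the Transformer Engine kernels: the helper names ``round_to_e2m1`` and ``nvfp4_quantize_1d`` are hypothetical, ``s_block`` is kept in full precision instead of being rounded to FP8 E4M3, ties are not broken to even, and the input is assumed to contain no all-zero blocks.

```python
import numpy as np

# Magnitudes representable in E2M1 (1 sign, 2 exponent, 1 mantissa bit).
E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
FP4_MAX = 6.0    # largest E2M1 magnitude
FP8_MAX = 448.0  # largest FP8 E4M3 magnitude
BLOCK = 16       # number of elements sharing one s_block


def round_to_e2m1(values):
    """Round each value to the nearest E2M1 magnitude, keeping the sign."""
    mags = np.minimum(np.abs(values), FP4_MAX)
    idx = np.abs(mags[..., None] - E2M1_GRID).argmin(axis=-1)
    return np.sign(values) * E2M1_GRID[idx]


def nvfp4_quantize_1d(x):
    """Illustrative 1D block quantization of a 2D array (hypothetical helper)."""
    rows, cols = x.shape
    blocks = x.reshape(rows, cols // BLOCK, BLOCK)
    # Global scale: maps the tensor amax onto fp8_max * fp4_max.
    s_global = np.abs(x).max() / (FP8_MAX * FP4_MAX)
    # Block scale: maps each block amax onto fp4_max, expressed relative to s_global.
    block_amax = np.abs(blocks).max(axis=-1, keepdims=True)
    s_block = (block_amax / FP4_MAX) / s_global
    x_e2m1 = round_to_e2m1(blocks / (s_block * s_global))
    return x_e2m1, s_block, s_global


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 64)).astype(np.float32)
x_e2m1, s_block, s_global = nvfp4_quantize_1d(x)

# Dequantize with x = x_e2m1 * s_block * s_global.
x_hat = (x_e2m1 * s_block * s_global).reshape(x.shape)
```

In this sketch the block containing the tensor-wide maximum gets ``s_block`` of about 448, which is why ``s_global`` is divided by ``fp8_max * fp4_max``: it keeps every block scale within the E4M3 range.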
**2D weight scaling**
NVFP4 block scaling can be:
* one-dimensional: each block of 16 consecutive elements shares a scaling factor,
* two-dimensional: each block of 16x16 elements shares a scaling factor.
By default, NVFP4 uses 2D scaling for weights and 1D scaling for activations and gradients.
Set ``disable_2d_quantization=True`` in the recipe configuration to force 1D scaling for weights as well (activations and gradients always use 1D).
The motivation for using 2D scaling for weights is to ensure that rowwise and columnwise
quantized tensors are numerically equivalent.
Please refer to the `NVFP4 paper <https://arxiv.org/abs/2509.25149>`__ for more details.
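The transpose property can be checked with a small NumPy experiment. This is a simplified sketch rather than the library implementation: ``fake_quant`` is a hypothetical helper that uses a single full-precision scale per block (no ``s_global``, no FP8 rounding of the scale) and simple nearest rounding.

```python
import numpy as np

E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
FP4_MAX = 6.0
BLOCK = 16


def round_to_e2m1(values):
    """Round each value to the nearest E2M1 magnitude, keeping the sign."""
    mags = np.minimum(np.abs(values), FP4_MAX)
    idx = np.abs(mags[..., None] - E2M1_GRID).argmin(axis=-1)
    return np.sign(values) * E2M1_GRID[idx]


def fake_quant(w, block_shape):
    """Quantize then dequantize `w` with one scale per block of `block_shape`."""
    br, bc = block_shape
    rows, cols = w.shape
    tiles = w.reshape(rows // br, br, cols // bc, bc)
    scale = np.abs(tiles).max(axis=(1, 3), keepdims=True) / FP4_MAX
    return (round_to_e2m1(tiles / scale) * scale).reshape(rows, cols)


rng = np.random.default_rng(0)
w = rng.standard_normal((64, 64))

# 1D: rowwise blocks of W versus rowwise blocks of W^T (columnwise blocks of W).
mismatch_1d = np.abs(fake_quant(w, (1, BLOCK)) - fake_quant(w.T, (1, BLOCK)).T).max()

# 2D: a 16x16 block covers the same elements in both orientations.
mismatch_2d = np.abs(fake_quant(w, (BLOCK, BLOCK)) - fake_quant(w.T, (BLOCK, BLOCK)).T).max()
```

With 16x16 blocks both orientations see the same block maxima, so ``mismatch_2d`` is exactly zero, while the 1D variant generally produces different values for the rowwise and columnwise copies.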
Stochastic Rounding
-------------------
Stochastic rounding is applied when casting scaled values to NVFP4 format. Instead of deterministic rounding
(round-to-nearest, with ties broken to even), each scaled value is probabilistically rounded to one of the two
nearest representable NVFP4 values. The probability of rounding to a given value decreases linearly with
the distance to that value, which ensures that the expected value of the quantized
tensor equals the original value, eliminating systematic quantization bias during training.
Stochastic rounding is hardware-accelerated using native GPU instructions introduced with the
Blackwell architecture.
.. raw:: html
:file: img/stochastic_rounding.svg
*Figure 3. Stochastic rounding illustration. Given a value* ``x`` *to be quantized, and the two nearest
representable NVFP4 values* ``v1`` *(lower) and* ``v2`` *(higher), deterministic rounding always
rounds to the nearest value, while stochastic rounding probabilistically rounds to either value.
If* ``x`` *is 40% of the way from* ``v1`` *to* ``v2``, *there is a 60% chance of rounding to* ``v1``
*and a 40% chance of rounding to* ``v2``.
Stochastic rounding is enabled only for gradients. It can be disabled by setting
``disable_stochastic_rounding=True`` in the recipe configuration.
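The unbiasedness argument can be illustrated with a software emulation. The sketch below is an assumption-laden NumPy model, not the hardware instruction: ``stochastic_round_e2m1`` is a hypothetical helper that works on already-scaled values and draws its randomness from a NumPy generator.

```python
import numpy as np

E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def stochastic_round_e2m1(values, rng):
    """Round to one of the two neighbouring E2M1 values, unbiased in expectation."""
    mags = np.minimum(np.abs(values), E2M1_GRID[-1])
    # Index of the lower neighbour v1; the upper neighbour v2 is the next grid entry.
    lo_idx = np.searchsorted(E2M1_GRID, mags, side="right") - 1
    lo_idx = np.clip(lo_idx, 0, len(E2M1_GRID) - 2)
    v1, v2 = E2M1_GRID[lo_idx], E2M1_GRID[lo_idx + 1]
    # Probability of rounding up grows linearly from 0 at v1 to 1 at v2.
    p_up = (mags - v1) / (v2 - v1)
    rounded = np.where(rng.random(mags.shape) < p_up, v2, v1)
    return np.sign(values) * rounded


rng = np.random.default_rng(0)
x = np.full(100_000, 2.4)  # 40% of the way from v1 = 2 to v2 = 3
samples = stochastic_round_e2m1(x, rng)

frac_up = (samples == 3.0).mean()  # close to 0.4
mean_sr = samples.mean()           # close to 2.4
```

Round-to-nearest would map every copy of 2.4 to 2.0, a systematic error of -0.4, whereas the stochastic average stays close to the original value.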
Random Hadamard Transform
--------------------------
Random Hadamard Transform (RHT) applies an orthogonal rotation to the tensor **before quantization**,
smoothing outliers in the tensor distributions and making them easier to represent accurately in NVFP4.
RHT is applied to columnwise quantization of inputs and gradients, which are operands
for the **wgrad GEMM**. This GEMM is particularly sensitive
to quantization errors, hence the additional outlier smoothing.
RHT is supported only for BF16 inputs/gradients.
The transform is defined as:
.. math::
x' = x H
where :math:`H` is the RHT matrix defined below. The quantization scale factor is computed
from the rotated tensor :math:`x'`.
**Hadamard matrix**
The :math:`d \times d` Hadamard matrix has elements :math:`\pm 1` and satisfies :math:`H_d H_d^T = d I`.
When normalized by :math:`1/\sqrt{d}`, the matrix becomes orthogonal and can be applied
to both operands of a matrix multiplication:
.. math::
C = (AH)(H^T B) = AB
where the transforms cancel within the dot-product since :math:`H H^T = I`.
**Sign matrix**
In the RHT implementation, a :math:`d`-dimensional diagonal sign matrix :math:`S_d` is applied
together with the Hadamard matrix:
.. math::
H = \frac{1}{\sqrt{d}} S_d H_d
where the diagonal entries of :math:`S_d` take values in :math:`\{-1, 1\}` and flip the signs of different rows of :math:`H_d`.
As described in the paper, a single random sign vector is shared across all linear layers throughout training.
In the implementation, this vector is fixed and the RHT matrix is computed once at initialization and cached.
**Tiled implementation**
The Hadamard transform is performed in a tiled approach along the last dimension of the tensor.
For an :math:`m \times k` tensor, the data is reshaped to :math:`(mk/d) \times d`
and multiplied by the :math:`d \times d` matrix :math:`H`. In this implementation, :math:`d = 16`.
.. raw:: html
:file: img/rht.svg
*Figure 4. WGRAD GEMM pipeline comparison: without RHT (left) and with RHT applied (right).*
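The cancellation of the transform inside the GEMM, and its effect on outliers, can be sketched in NumPy. The code below is illustrative only: the helper names are hypothetical, the Hadamard matrix is built with the Sylvester construction, the sign vector is drawn from a NumPy generator rather than being the fixed vector used by the library, quantization is omitted, and both operands are assumed to be laid out with the contraction (token) dimension last.

```python
import numpy as np

D = 16  # Hadamard tile size


def hadamard(d):
    """Sylvester construction of a d x d Hadamard matrix (d a power of two)."""
    h = np.array([[1.0]])
    while h.shape[0] < d:
        h = np.block([[h, h], [h, -h]])
    return h


def make_rht_matrix(d, rng):
    """H = (1 / sqrt(d)) * S_d * H_d with a random diagonal sign matrix S_d."""
    signs = rng.choice([-1.0, 1.0], size=d)
    return (np.diag(signs) @ hadamard(d)) / np.sqrt(d)


def apply_rht(x, h):
    """Multiply consecutive tiles of length d along the last dimension by h."""
    d = h.shape[0]
    return (x.reshape(-1, d) @ h).reshape(x.shape)


rng = np.random.default_rng(0)
H = make_rht_matrix(D, rng)

act_t = rng.standard_normal((32, 64))   # [in_features, tokens]
grad_t = rng.standard_normal((48, 64))  # [out_features, tokens]
act_t[3, 10] = 100.0                    # inject a single outlier

# The rotations cancel in the dot product because H @ H.T is the identity.
wgrad_ref = grad_t @ act_t.T
wgrad_rht = apply_rht(grad_t, H) @ apply_rht(act_t, H).T

# The outlier is spread over its 16-element tile, lowering the amax to quantize.
amax_before = np.abs(act_t).max()
amax_after = np.abs(apply_rht(act_t, H)).max()
```

In exact arithmetic the two weight gradients are identical; the benefit appears only once quantization is inserted after the rotation, because the rotated operands have a smaller maximum relative to their typical values.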
Handling transposes
-------------------
Like :doc:`MXFP8 <../mxfp8/mxfp8>`, NVFP4 requires both rowwise and columnwise quantized tensors
for different GEMM operands. Unlike MXFP8, which supports multiple layouts (TN, NT, NN),
**NVFP4 GEMM only supports the TN layout**.
NVFP4 stores columnwise data and scaling factors in a **transposed layout**:
- **Rowwise**: data ``[A, B]`` with 1×16 horizontal blocks, ``scales`` shape ``[A, B/16]``
- **Columnwise**: data ``[B, A]`` (transposed) with 1×16 horizontal blocks, ``scales`` shape ``[B, A/16]``
Scale tensors are padded for hardware alignment: first dimension to a multiple of 128,
second dimension to a multiple of 4 (e.g. rowwise: ``[roundup(A, 128), roundup(B/16, 4)]``).
.. raw:: html
:file: img/nvfp4_row_col.svg
*Figure 5. NVFP4 rowwise vs columnwise quantization layout. Unlike MXFP8, columnwise scales are stored transposed.*
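The shape rules above can be written as a small helper. This is a hypothetical utility for illustration, not a Transformer Engine function; it assumes that a dimension which is not a multiple of 16 is rounded up to the next full block.

```python
def round_up(value, multiple):
    """Round `value` up to the nearest multiple of `multiple`."""
    return -(-value // multiple) * multiple


def nvfp4_scale_shapes(a, b, block=16):
    """Padded scale shapes for a logical [a, b] tensor (hypothetical helper)."""
    # Rowwise: data [a, b], one scale per 1x16 block along b.
    rowwise = (round_up(a, 128), round_up(-(-b // block), 4))
    # Columnwise: data stored transposed as [b, a], one scale per 1x16 block along a.
    columnwise = (round_up(b, 128), round_up(-(-a // block), 4))
    return rowwise, columnwise


rowwise_shape, columnwise_shape = nvfp4_scale_shapes(144, 48)
# rowwise_shape == (256, 4), columnwise_shape == (128, 12)
```

For a ``[144, 48]`` tensor the unpadded scale shapes would be ``[144, 3]`` and ``[48, 9]``; padding brings them to ``[256, 4]`` and ``[128, 12]``.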
Distributed training
--------------------
**Amax reduction**
Block scaling factors (``s_block``) do not require synchronization between nodes,
as each scaling factor is local to its block of 16 elements.
However, the global scaling factor (``s_global``) requires amax synchronization for gathered tensors.
For tensors that are gathered (e.g., input and gradient in sequence parallelism),
amax reduction is performed before quantization.
If before synchronization there was ``amax_1`` on node 1,
``amax_2`` on node 2, etc., after synchronization there will be ``max(amax_1, amax_2, ...)`` on all nodes.
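The effect of the reduction can be mimicked on a single process. The sketch below only simulates the nodes with array shards and uses Python's ``max`` in place of a real max all-reduce collective; it is an illustration, not the distributed implementation.

```python
import numpy as np

FP4_MAX, FP8_MAX = 6.0, 448.0

rng = np.random.default_rng(0)
full = rng.standard_normal((8, 64))
shards = np.split(full, 4, axis=0)  # pretend each shard lives on one node

# Each node computes its local amax, then all nodes agree on the maximum.
local_amax = [np.abs(shard).max() for shard in shards]
synced_amax = max(local_amax)       # stands in for a max all-reduce

# Every node derives the same s_global, so gathered shards share one scale.
s_global = synced_amax / (FP8_MAX * FP4_MAX)
```

Because the synchronized amax equals the amax of the gathered tensor, quantizing each shard with this ``s_global`` yields scaling factors that are consistent with one another after the shards are gathered.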
**Quantized all-gather**
NVFP4 all-gather is supported.
.. raw:: html
:file: img/nvfp4_all_gather.svg
*Figure 6. Quantization and all-gather flow for NVFP4 showing amax synchronization and hierarchical scaling.*
Examples
--------
Here's how to use the NVFP4 recipe in PyTorch and JAX. The examples show how to configure features such as 2D weight quantization and the Random Hadamard Transform (RHT):
.. tabs::
.. tab:: PyTorch
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Requires SM100 (Blackwell) or later
</div>
.. literalinclude:: pytorch_nvfp4_example.py
:language: python
:start-after: # START_NVFP4_EXAMPLE
:end-before: # END_NVFP4_EXAMPLE
.. tab:: JAX
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Requires SM100 (Blackwell) or later
</div>
.. literalinclude:: jax_nvfp4_example.py
:language: python
:start-after: # START_NVFP4_EXAMPLE
:end-before: # END_NVFP4_EXAMPLE
Supported devices
-----------------
* **Training**: SM 10.0, SM 10.3
* **Inference**: SM 10.0+
----
Developer Notes
---------------
This section contains implementation details that may be useful for developers
but are not required for using NVFP4 in practice.
Swizzling scaling factors
^^^^^^^^^^^^^^^^^^^^^^^^^
NVFP4 requires swizzling of block scaling factors (``s_block``) before GEMM operations,
similar to :doc:`MXFP8 <../mxfp8/mxfp8>`. Key differences:
- Block size is 16 (vs 32 for MXFP8)
- Both rowwise and columnwise scaling factors are swizzled, but thanks to the transposed
columnwise layout, a single rowwise swizzle kernel handles both cases.
- Scaling factors are stored as FP8 E4M3 (vs E8M0 for MXFP8)
All-gather of columnwise tensors
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
All-gather of columnwise tensors is supported. To enable quantized all-gather,
all nodes must use the same ``s_global``, which is computed from the synchronized global amax.
This is automatically enabled for column-parallel and row-parallel linear layers.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
# Check for Blackwell or newer GPU
major, minor = torch.cuda.get_device_capability()
assert major >= 10, f"NVFP4 requires SM100 (Blackwell) or later, got SM{major}{minor}"
# START_NVFP4_EXAMPLE
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import NVFP4BlockScaling
# Define NVFP4 recipe
# 2D weight quantization and RHT are enabled by default
recipe = NVFP4BlockScaling()
# To disable features, use:
# recipe = NVFP4BlockScaling(disable_rht=True, disable_2d_quantization=True)
# Create a linear layer with bfloat16 parameters
layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
# Forward and backward pass
inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda")
with te.autocast(enabled=True, recipe=recipe):
output = layer(inp)
loss = output.sum()
loss.backward()
# END_NVFP4_EXAMPLE
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+
# START_FUSED_LAYERS
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
from transformer_engine.jax.flax import LayerNorm, DenseGeneral, LayerNormDenseGeneral
from transformer_engine.common.recipe import DelayedScaling
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16)
# Example 1: Separate LayerNorm and DenseGeneral layers
layer_norm = LayerNorm()
dense = DenseGeneral(features=1024)
# Initialize parameters
ln_params = layer_norm.init(key, x)
dense_params = dense.init(key, x)
# Two separate operations
normalized = layer_norm.apply(ln_params, x)
output_separate = dense.apply(dense_params, normalized)
# Example 2: Fused LayerNormDenseGeneral layer
fused_layer = LayerNormDenseGeneral(features=1024)
# Initialize and apply with FP8 autocast
recipe = DelayedScaling()
with te.autocast(enabled=True, recipe=recipe):
fused_params = fused_layer.init(key, x)
output_fused, _ = fused_layer.apply(fused_params, x) # Returns (output, ln_output)
# The fused layer is more efficient as it combines LayerNorm and quantization
# END_FUSED_LAYERS
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+
cc = torch.cuda.get_device_capability()
assert (cc[0] == 8 and cc[1] >= 9) or cc[0] == 9, "This example requires SM89 (Ada) or SM90 (Hopper)"
# START_FUSED_LAYERS
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling
# Example 1: Separate LayerNorm and Linear layers
layer_norm = te.LayerNorm(1024)
linear = te.Linear(1024, 1024)
inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda")
# Two separate operations: LayerNorm produces FP32, then Linear quantizes it
normalized = layer_norm(inp)
output_separate = linear(normalized)
# Example 2: Fused LayerNormLinear layer
fused_layer = te.LayerNormLinear(1024, 1024, params_dtype=torch.bfloat16)
# Single operation: LayerNorm output is directly quantized
recipe = DelayedScaling()
with te.autocast(enabled=True, recipe=recipe):
output_fused = fused_layer(inp)
# The fused layer is more efficient as it avoids redundant quantization
# END_FUSED_LAYERS
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 850 640" width="850" height="640">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
/* Diagram-specific overrides */
.hp { fill: #f3e5f5; stroke: #7b1fa2; stroke-width: 2.5; }
.gemm { fill: #ffe8cc; stroke: #f57c00; stroke-width: 3; }
.quantize { fill: #c8e6c9; stroke: #43a047; stroke-width: 2.5; }
.text { font-weight: 500; }
.small-text { fill: #616161; }
.title { font-weight: 700; fill: #1a1a1a; }
.arrow { stroke: #424242; stroke-width: 2.5; fill: none; marker-end: url(#arrowhead); }
</style>
<marker id="arrowhead" markerWidth="6" markerHeight="6" refX="5" refY="2" orient="auto">
<polygon points="0 0, 6 2, 0 4" fill="#424242" />
</marker>
</defs>
<!-- Title -->
<text x="425" y="30" class="title">LayerNorm + Linear: Separate vs Fused</text>
<!-- Divider Line -->
<line x1="425" y1="50" x2="425" y2="620" stroke="#bdbdbd" stroke-width="2.5" stroke-dasharray="5,5"/>
<!-- LEFT: Separate Layers -->
<text x="212" y="65" class="section-title">Scenario 1: Separate Layers</text>
<!-- Input -->
<rect x="137" y="95" width="150" height="55" class="hp" rx="8"/>
<text x="212" y="128" class="text">Input</text>
<!-- Arrow -->
<path d="M 212 150 L 212 180" class="arrow"/>
<!-- LayerNorm -->
<rect x="137" y="180" width="150" height="55" class="layernorm" rx="8"/>
<text x="212" y="213" class="text" style="font-weight: 600;">LayerNorm</text>
<!-- Arrow -->
<path d="M 212 235 L 212 260" class="arrow"/>
<!-- Output FP32 -->
<rect x="147" y="260" width="130" height="50" class="hp" rx="8"/>
<text x="212" y="290" class="text">Output</text>
<!-- Arrow -->
<path d="M 212 310 L 212 330" class="arrow"/>
<!-- Linear wrapper -->
<rect x="122" y="330" width="180" height="190" class="gemm" rx="10"/>
<text x="212" y="355" class="text" style="font-weight: 600;">Linear</text>
<!-- Quantize inside -->
<rect x="142" y="368" width="140" height="30" class="quantize" rx="8"/>
<text x="212" y="388" class="text">Quantize</text>
<!-- Arrow -->
<path d="M 212 398 L 212 418" class="arrow"/>
<!-- FP8 tensor -->
<rect x="147" y="418" width="130" height="35" class="fp8" rx="8"/>
<text x="212" y="441" class="text">FP8 tensor</text>
<!-- Arrow -->
<path d="M 212 453 L 212 473" class="arrow"/>
<!-- Simplified computation inside -->
<rect x="142" y="473" width="140" height="30" class="computation" rx="8"/>
<text x="212" y="493" class="title" style="font-weight: 600;">...</text>
<!-- Arrow out of Linear -->
<path d="M 212 520 L 212 545" class="arrow"/>
<!-- Output -->
<rect x="137" y="545" width="150" height="50" class="hp" rx="8"/>
<text x="212" y="575" class="text">Output</text>
<!-- RIGHT: Fused Layer -->
<text x="637" y="65" class="section-title">Scenario 2: Fused Layer</text>
<!-- Input -->
<rect x="562" y="95" width="150" height="55" class="hp" rx="8"/>
<text x="637" y="128" class="text">Input</text>
<!-- Arrow -->
<path d="M 637 150 L 637 185" class="arrow"/>
<!-- LayerNormLinear Fused -->
<rect x="517" y="185" width="240" height="220" class="fused" rx="10"/>
<text x="637" y="212" class="text" style="font-weight: 600;">LayerNormLinear</text>
<!-- Inside fused block -->
<!-- LayerNorm -->
<rect x="537" y="235" width="200" height="40" class="layernorm" rx="8"/>
<text x="637" y="260" class="text">LayerNorm + Quantize</text>
<!-- Arrow -->
<path d="M 637 275 L 637 295" class="arrow"/>
<!-- FP8 tensor -->
<rect x="562" y="295" width="150" height="35" class="fp8" rx="8"/>
<text x="637" y="318" class="text">FP8 tensor</text>
<!-- Arrow -->
<path d="M 637 330 L 637 350" class="arrow"/>
<!-- Simplified computation -->
<rect x="567" y="350" width="140" height="35" class="computation" rx="8"/>
<text x="637" y="373" class="title" style="font-weight: 600;">...</text>
<!-- Arrow out of fused block -->
<path d="M 637 405 L 637 545" class="arrow"/>
<!-- Output -->
<rect x="562" y="545" width="150" height="50" class="hp" rx="8"/>
<text x="637" y="575" class="text">Output</text>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 340 280" width="100%" style="max-width: 340px;">
<defs>
<style>
.title { font: bold 12px sans-serif; fill: #333; text-anchor: middle; }
.label { font: 10px sans-serif; fill: #333; text-anchor: middle; }
.small-text { font: 9px sans-serif; fill: #555; }
.operator { font: bold 14px sans-serif; fill: #333; }
.equals { font: bold 14px sans-serif; fill: #333; }
/* Matrix colors */
.matrix-cell { fill: #bbdefb; stroke: #1565c0; stroke-width: 1; }
.highlight-row { fill: #c8e6c9; stroke: #1565c0; stroke-width: 1; }
.highlight-col { fill: #fff3e0; stroke: #1565c0; stroke-width: 1; }
.highlight-result { fill: #ffcc80; stroke: #1565c0; stroke-width: 1; }
</style>
</defs>
<!-- FIRST SCENARIO: A × B -->
<!-- Title -->
<text x="170" y="16" class="title">NN GEMM</text>
<!-- Matrix A (left) - centered -->
<g id="matrix-a1" transform="translate(40, 52)">
<text x="28" y="-14" class="label">A</text>
<!-- Row 0 -->
<rect x="0" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="0" width="14" height="14" class="matrix-cell"/>
<!-- Row 1 (highlighted - row i) -->
<rect x="0" y="14" width="14" height="14" class="highlight-row"/>
<rect x="14" y="14" width="14" height="14" class="highlight-row"/>
<rect x="28" y="14" width="14" height="14" class="highlight-row"/>
<rect x="42" y="14" width="14" height="14" class="highlight-row"/>
<!-- Row 2 -->
<rect x="0" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="28" width="14" height="14" class="matrix-cell"/>
<!-- Row 3 -->
<rect x="0" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="42" width="14" height="14" class="matrix-cell"/>
<!-- Label -->
<text x="28" y="70" class="small-text" style="fill: #388e3c; font-weight: bold; text-anchor: middle;">rowwise</text>
</g>
<!-- Matrix B (middle) - 4x4, centered -->
<g id="matrix-b1" transform="translate(136, 52)">
<text x="28" y="-14" class="label">B</text>
<!-- Row 0 -->
<rect x="0" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="0" width="14" height="14" class="highlight-col"/>
<rect x="42" y="0" width="14" height="14" class="matrix-cell"/>
<!-- Row 1 -->
<rect x="0" y="14" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="14" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="14" width="14" height="14" class="highlight-col"/>
<rect x="42" y="14" width="14" height="14" class="matrix-cell"/>
<!-- Row 2 -->
<rect x="0" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="28" width="14" height="14" class="highlight-col"/>
<rect x="42" y="28" width="14" height="14" class="matrix-cell"/>
<!-- Row 3 -->
<rect x="0" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="42" width="14" height="14" class="highlight-col"/>
<rect x="42" y="42" width="14" height="14" class="matrix-cell"/>
<!-- Label -->
<text x="28" y="70" class="small-text" style="fill: #f57c00; font-weight: bold; text-anchor: middle;">columnwise</text>
</g>
<!-- Result matrix 4x4 (right), centered -->
<g id="matrix-c1" transform="translate(232, 52)">
<text x="28" y="-14" class="label" style="font-weight: bold;">A×B</text>
<!-- Row 0 -->
<rect x="0" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="0" width="14" height="14" class="matrix-cell"/>
<!-- Row 1 -->
<rect x="0" y="14" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="14" width="14" height="14" class="matrix-cell"/>
<!-- Highlighted result cell -->
<rect x="28" y="14" width="14" height="14" class="highlight-result"/>
<rect x="42" y="14" width="14" height="14" class="matrix-cell"/>
<!-- Row 2 -->
<rect x="0" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="28" width="14" height="14" class="matrix-cell"/>
<!-- Row 3 -->
<rect x="0" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="42" width="14" height="14" class="matrix-cell"/>
</g>
<!-- SECOND SCENARIO: A × B^T -->
<!-- Title -->
<text x="170" y="152" class="title">TN GEMM</text>
<!-- Matrix A (left), centered -->
<g id="matrix-a2" transform="translate(40, 188)">
<text x="28" y="-14" class="label">A</text>
<!-- Row 0 -->
<rect x="0" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="0" width="14" height="14" class="matrix-cell"/>
<!-- Row 1 (highlighted - row i) -->
<rect x="0" y="14" width="14" height="14" class="highlight-row"/>
<rect x="14" y="14" width="14" height="14" class="highlight-row"/>
<rect x="28" y="14" width="14" height="14" class="highlight-row"/>
<rect x="42" y="14" width="14" height="14" class="highlight-row"/>
<!-- Row 2 -->
<rect x="0" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="28" width="14" height="14" class="matrix-cell"/>
<!-- Row 3 -->
<rect x="0" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="42" width="14" height="14" class="matrix-cell"/>
<!-- Label -->
<text x="28" y="70" class="small-text" style="fill: #388e3c; font-weight: bold; text-anchor: middle;">rowwise</text>
</g>
<!-- Matrix B (middle, shown with rowwise access for B^T) - 4x4, centered -->
<g id="matrix-b2" transform="translate(136, 188)">
<text x="28" y="-14" class="label">B</text>
<!-- Row 0 -->
<rect x="0" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="0" width="14" height="14" class="matrix-cell"/>
<!-- Row 1 (highlighted - rowwise access for transposed) -->
<rect x="0" y="14" width="14" height="14" class="highlight-row"/>
<rect x="14" y="14" width="14" height="14" class="highlight-row"/>
<rect x="28" y="14" width="14" height="14" class="highlight-row"/>
<rect x="42" y="14" width="14" height="14" class="highlight-row"/>
<!-- Row 2 -->
<rect x="0" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="28" width="14" height="14" class="matrix-cell"/>
<!-- Row 3 -->
<rect x="0" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="42" width="14" height="14" class="matrix-cell"/>
<!-- Label -->
<text x="28" y="70" class="small-text" style="fill: #388e3c; font-weight: bold; text-anchor: middle;">rowwise</text>
</g>
<!-- Result matrix 4x4 (right), centered -->
<g id="matrix-c2" transform="translate(232, 188)">
<text x="28" y="-14" class="label" style="font-weight: bold;">A×B<tspan baseline-shift="super" font-size="7">T</tspan></text>
<!-- Row 0 -->
<rect x="0" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="0" width="14" height="14" class="matrix-cell"/>
<!-- Row 1 -->
<rect x="0" y="14" width="14" height="14" class="matrix-cell"/>
<!-- Highlighted result cell -->
<rect x="14" y="14" width="14" height="14" class="highlight-result"/>
<rect x="28" y="14" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="14" width="14" height="14" class="matrix-cell"/>
<!-- Row 2 -->
<rect x="0" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="28" width="14" height="14" class="matrix-cell"/>
<!-- Row 3 -->
<rect x="0" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="42" width="14" height="14" class="matrix-cell"/>
</g>
</svg>
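The highlighted cells in the diagram above can be read as one dot product: entry (i, j) of A×Bᵀ pairs row i of A with row j of B, so both operands are read rowwise. The following is a minimal pure-Python sketch of that access pattern. The matrix values and the helper name `a_times_b_transposed` are made up for illustration and are not part of Transformer Engine.

```python
def a_times_b_transposed(a, b):
    """Compute A x B^T by pairing each row of A with each row of B."""
    return [
        [sum(x * y for x, y in zip(row_a, row_b)) for row_b in b]
        for row_a in a
    ]


# Illustrative 4x4 inputs (arbitrary values, chosen only for the example).
A = [[r * 4 + c for c in range(4)] for r in range(4)]
B = [[(r + 1) * (c + 1) for c in range(4)] for r in range(4)]

C = a_times_b_transposed(A, B)

# The highlighted result cell in the diagram: row 1, column 1.
i, j = 1, 1
highlighted = sum(x * y for x, y in zip(A[i], B[j]))

# Cross-check against an explicit transpose followed by a textbook matmul.
B_T = [list(col) for col in zip(*B)]
C_ref = [
    [sum(A[r][k] * B_T[k][c] for k in range(4)) for c in range(4)]
    for r in range(4)
]

print(C[i][j], highlighted, C == C_ref)
```

Because only rows of A and rows of B are touched, no columnwise read of B is needed for this product.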
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 700 220" width="100%" style="max-width: 700px;">
<defs>
<style>
.title { font: bold 14px sans-serif; fill: #333; text-anchor: middle; }
.label { font: 11px sans-serif; fill: #333; text-anchor: middle; }
.small-text { font: 10px sans-serif; fill: #555; text-anchor: middle; }
.cell-num { font: 9px sans-serif; fill: #1565c0; text-anchor: middle; dominant-baseline: middle; }
/* Matrix colors */
.matrix-cell { fill: #bbdefb; stroke: #1565c0; stroke-width: 1; }
</style>
</defs>
<!-- HOPPER SECTION (left) -->
<text x="175" y="25" class="title">FP8 tensor on Hopper</text>
<!-- Rowwise tensor -->
<g id="hopper-rowwise" transform="translate(50, 50)">
<text x="40" y="-8" class="label">rowwise</text>
<rect x="0" y="0" width="20" height="20" class="matrix-cell"/>
<text x="10" y="10" class="cell-num">0</text>
<rect x="20" y="0" width="20" height="20" class="matrix-cell"/>
<text x="30" y="10" class="cell-num">1</text>
<rect x="40" y="0" width="20" height="20" class="matrix-cell"/>
<text x="50" y="10" class="cell-num">2</text>
<rect x="60" y="0" width="20" height="20" class="matrix-cell"/>
<text x="70" y="10" class="cell-num">3</text>
<rect x="0" y="20" width="20" height="20" class="matrix-cell"/>
<text x="10" y="30" class="cell-num">4</text>
<rect x="20" y="20" width="20" height="20" class="matrix-cell"/>
<text x="30" y="30" class="cell-num">5</text>
<rect x="40" y="20" width="20" height="20" class="matrix-cell"/>
<text x="50" y="30" class="cell-num">6</text>
<rect x="60" y="20" width="20" height="20" class="matrix-cell"/>
<text x="70" y="30" class="cell-num">7</text>
<rect x="0" y="40" width="20" height="20" class="matrix-cell"/>
<text x="10" y="50" class="cell-num">8</text>
<rect x="20" y="40" width="20" height="20" class="matrix-cell"/>
<text x="30" y="50" class="cell-num">9</text>
<rect x="40" y="40" width="20" height="20" class="matrix-cell"/>
<text x="50" y="50" class="cell-num">10</text>
<rect x="60" y="40" width="20" height="20" class="matrix-cell"/>
<text x="70" y="50" class="cell-num">11</text>
</g>
<!-- Columnwise tensor (transposed) -->
<g id="hopper-colwise" transform="translate(200, 50)">
<text x="32" y="-8" class="label">columnwise</text>
<rect x="0" y="0" width="20" height="20" class="matrix-cell"/>
<text x="10" y="10" class="cell-num">0</text>
<rect x="20" y="0" width="20" height="20" class="matrix-cell"/>
<text x="30" y="10" class="cell-num">4</text>
<rect x="40" y="0" width="20" height="20" class="matrix-cell"/>
<text x="50" y="10" class="cell-num">8</text>
<rect x="0" y="20" width="20" height="20" class="matrix-cell"/>
<text x="10" y="30" class="cell-num">1</text>
<rect x="20" y="20" width="20" height="20" class="matrix-cell"/>
<text x="30" y="30" class="cell-num">5</text>
<rect x="40" y="20" width="20" height="20" class="matrix-cell"/>
<text x="50" y="30" class="cell-num">9</text>
<rect x="0" y="40" width="20" height="20" class="matrix-cell"/>
<text x="10" y="50" class="cell-num">2</text>
<rect x="20" y="40" width="20" height="20" class="matrix-cell"/>
<text x="30" y="50" class="cell-num">6</text>
<rect x="40" y="40" width="20" height="20" class="matrix-cell"/>
<text x="50" y="50" class="cell-num">10</text>
<rect x="0" y="60" width="20" height="20" class="matrix-cell"/>
<text x="10" y="70" class="cell-num">3</text>
<rect x="20" y="60" width="20" height="20" class="matrix-cell"/>
<text x="30" y="70" class="cell-num">7</text>
<rect x="40" y="60" width="20" height="20" class="matrix-cell"/>
<text x="50" y="70" class="cell-num">11</text>
</g>
<!-- Separator -->
<line x1="350" y1="20" x2="350" y2="200" stroke="#bdbdbd" stroke-width="1" stroke-dasharray="4,4"/>
<!-- BLACKWELL SECTION (right) -->
<text x="525" y="25" class="title">FP8 tensor on Blackwell</text>
<!-- Single tensor for both usages -->
<g id="blackwell-both" transform="translate(455, 50)">
<text x="70" y="-8" class="label">rowwise and columnwise</text>
<rect x="30" y="0" width="20" height="20" class="matrix-cell"/>
<text x="40" y="10" class="cell-num">0</text>
<rect x="50" y="0" width="20" height="20" class="matrix-cell"/>
<text x="60" y="10" class="cell-num">1</text>
<rect x="70" y="0" width="20" height="20" class="matrix-cell"/>
<text x="80" y="10" class="cell-num">2</text>
<rect x="90" y="0" width="20" height="20" class="matrix-cell"/>
<text x="100" y="10" class="cell-num">3</text>
<rect x="30" y="20" width="20" height="20" class="matrix-cell"/>
<text x="40" y="30" class="cell-num">4</text>
<rect x="50" y="20" width="20" height="20" class="matrix-cell"/>
<text x="60" y="30" class="cell-num">5</text>
<rect x="70" y="20" width="20" height="20" class="matrix-cell"/>
<text x="80" y="30" class="cell-num">6</text>
<rect x="90" y="20" width="20" height="20" class="matrix-cell"/>
<text x="100" y="30" class="cell-num">7</text>
<rect x="30" y="40" width="20" height="20" class="matrix-cell"/>
<text x="40" y="50" class="cell-num">8</text>
<rect x="50" y="40" width="20" height="20" class="matrix-cell"/>
<text x="60" y="50" class="cell-num">9</text>
<rect x="70" y="40" width="20" height="20" class="matrix-cell"/>
<text x="80" y="50" class="cell-num">10</text>
<rect x="90" y="40" width="20" height="20" class="matrix-cell"/>
<text x="100" y="50" class="cell-num">11</text>
</g>
</svg>
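The cell numbering in the diagram above can be reproduced in a few lines: number a 3×4 tensor in row-major order, then transpose it to obtain the separate columnwise copy drawn for Hopper. This is a pure-Python sketch of the index layout only. The variable names are illustrative, and it does not show how Transformer Engine actually stores FP8 data.

```python
rows, cols = 3, 4

# Rowwise layout: elements numbered 0..11 in row-major order.
rowwise = [[r * cols + c for c in range(cols)] for r in range(rows)]

# Columnwise layout as drawn for Hopper: a second, transposed copy.
columnwise = [list(col) for col in zip(*rowwise)]

print("rowwise:")
for row in rowwise:
    print(row)

print("columnwise (transposed copy):")
for row in columnwise:
    print(row)
```

In the Blackwell panel of the diagram a single tensor serves both usages, so only the `rowwise` layout would be kept in that case.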
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 850 660" width="850" height="660">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
/* Arrows */
.arrow { stroke: #616161; stroke-width: 2; fill: none; marker-end: url(#arrowhead); }
/* Phase labels */
.phase-forward { fill: #1565c0; font-weight: 600; }
.phase-backward { fill: #c62828; font-weight: 600; }
/* All-gather operation */
.allgather {
fill: #e1f5fe;
stroke: #039be5;
stroke-width: 2;
}
</style>
<marker id="arrowhead" markerWidth="6" markerHeight="6" refX="5" refY="2" orient="auto">
<polygon points="0 0, 6 2, 0 4" fill="#616161" />
</marker>
</defs>
<!-- Main Title -->
<text x="425" y="30" class="title" style="font-size: 18px; font-weight: 700; fill: #1a1a1a;">All-Gather of Quantized Tensors (one scenario)</text>
<!-- Section 1: Input Tensor -->
<text x="425" y="70" class="section-title" style="fill: #212121; font-weight: 600;">Input Tensor quantized all-gather</text>
<!-- Forward phase label -->
<text x="20" y="128" class="text phase-forward" style="text-anchor: start;">FWD:</text>
<!-- High Precision Input -->
<rect x="90" y="100" width="130" height="55" class="hp" rx="6"/>
<text x="155" y="123" class="text">High Precision</text>
<text x="155" y="140" class="text">Tensor</text>
<!-- Arrow -->
<path d="M 220 127 L 250 127" class="arrow"/>
<!-- Quantize -->
<rect x="250" y="100" width="100" height="55" class="quantize" rx="6"/>
<text x="300" y="132" class="text">Quantize</text>
<!-- Arrow -->
<path d="M 350 127 L 380 127" class="arrow"/>
<!-- Rowwise Quantized Tensor -->
<rect x="380" y="100" width="130" height="55" class="fp8" rx="6"/>
<text x="445" y="123" class="text">Rowwise</text>
<text x="445" y="140" class="text">Quantized</text>
<!-- Arrow -->
<path d="M 510 127 L 540 127" class="arrow"/>
<!-- All-Gather -->
<rect x="540" y="100" width="110" height="55" class="allgather" rx="6"/>
<text x="595" y="132" class="text">All-Gather</text>
<!-- Arrow -->
<path d="M 650 127 L 680 127" class="arrow"/>
<!-- Continue indicator -->
<text x="705" y="132" class="text">...</text>
<!-- Backward phase label -->
<text x="20" y="228" class="text phase-backward" style="text-anchor: start;">BWD:</text>
<!-- Arrow from Quantize to Columnwise (showing separate quantization) -->
<path d="M 300 155 L 300 180 L 380 227" class="arrow"/>
<!-- Columnwise Quantized Tensor -->
<rect x="380" y="200" width="130" height="55" class="fp8" rx="6"/>
<text x="445" y="223" class="text">Columnwise</text>
<text x="445" y="240" class="text">Quantized</text>
<!-- Arrow -->
<path d="M 510 227 L 540 227" class="arrow"/>
<!-- All-Gather -->
<rect x="540" y="200" width="110" height="55" class="allgather" rx="6"/>
<text x="595" y="232" class="text">All-Gather</text>
<!-- Arrow -->
<path d="M 650 227 L 680 227" class="arrow"/>
<!-- Continue indicator -->
<text x="705" y="232" class="text">...</text>
<!-- Divider Line -->
<line x1="30" y1="305" x2="820" y2="305" stroke="#bdbdbd" stroke-width="3"/>
<!-- Section 2: Gradient Tensor -->
<text x="425" y="340" class="section-title" style="fill: #212121; font-weight: 600;">Gradient Tensor quantized all-gather</text>
<!-- Backward phase label -->
<text x="20" y="413" class="text phase-backward" style="text-anchor: start;">BWD:</text>
<!-- High Precision Gradient -->
<rect x="90" y="385" width="130" height="55" class="hp" rx="6"/>
<text x="155" y="408" class="text">High Precision</text>
<text x="155" y="425" class="text">Tensor</text>
<!-- Arrow -->
<path d="M 220 412 L 250 412" class="arrow"/>
<!-- Quantize -->
<rect x="250" y="385" width="100" height="55" class="quantize" rx="6"/>
<text x="300" y="417" class="text">Quantize</text>
<!-- Arrow to Columnwise -->
<path d="M 350 397 L 380 377" class="arrow"/>
<!-- Columnwise Quantized Tensor -->
<rect x="380" y="355" width="130" height="45" class="fp8" rx="6"/>
<text x="445" y="382" class="text">Col. Quantized</text>
<!-- Arrow to Rowwise -->
<path d="M 350 427 L 380 447" class="arrow"/>
<!-- Rowwise Quantized Tensor -->
<rect x="380" y="425" width="130" height="45" class="fp8" rx="6"/>
<text x="445" y="452" class="text">Row. Quantized</text>
<!-- Arrow from Columnwise to All-Gather -->
<path d="M 510 377 L 540 412" class="arrow"/>
<!-- Arrow from Rowwise to All-Gather -->
<path d="M 510 447 L 540 412" class="arrow"/>
<!-- All-Gather (single, shared) -->
<rect x="540" y="390" width="110" height="45" class="allgather" rx="6"/>
<text x="595" y="417" class="text">All-Gather</text>
<!-- Arrow -->
<path d="M 650 412 L 680 412" class="arrow"/>
<!-- Continue indicator -->
<text x="705" y="417" class="text">...</text>
<!-- Legend -->
<g transform="translate(100, 515)">
<rect x="0" y="0" width="80" height="40" rx="5" class="hp"/>
<text x="95" y="23" class="text" style="text-anchor: start;">High Precision (FP32/BF16/FP16)</text>
<rect x="380" y="0" width="80" height="40" rx="5" class="fp8"/>
<text x="475" y="23" class="text" style="text-anchor: start;">Lower Precision (FP8, etc.)</text>
<rect x="0" y="55" width="80" height="40" rx="5" class="quantize"/>
<text x="95" y="78" class="text" style="text-anchor: start;">Quantization</text>
<rect x="380" y="55" width="80" height="40" rx="5" class="allgather"/>
<text x="475" y="78" class="text" style="text-anchor: start;">All-Gather</text>
</g>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 850 750" width="850" height="750">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
/* Arrows */
.arrow { stroke: #616161; stroke-width: 2; fill: none; marker-end: url(#arrowhead); }
.arrow-dashed { stroke: #616161; stroke-width: 2; fill: none; stroke-dasharray: 5,5; marker-end: url(#arrowhead); }
/* Phase labels */
.phase-forward { fill: #1565c0; font-weight: 600; }
.phase-backward { fill: #c62828; font-weight: 600; }
</style>
<marker id="arrowhead" markerWidth="6" markerHeight="6" refX="5" refY="2" orient="auto">
<polygon points="0 0, 6 2, 0 4" fill="#616161" />
</marker>
</defs>
<!-- OPTION 1: Both usages in forward -->
<text x="425" y="35" class="section-title" style="fill: #212121; font-weight: 600;">Option 1: Quantize both usages in forward</text>
<!-- Forward phase label -->
<text x="20" y="113" class="text phase-forward" style="text-anchor: start;">FORWARD:</text>
<!-- High Precision Tensor -->
<rect x="130" y="85" width="130" height="55" class="hp" rx="6"/>
<text x="195" y="108" class="text">High Precision</text>
<text x="195" y="125" class="text">Tensor</text>
<!-- Arrow -->
<path d="M 260 112 L 290 112" class="arrow"/>
<!-- Quantize -->
<rect x="290" y="85" width="160" height="55" class="quantize" rx="6"/>
<text x="370" y="117" class="text">Quantize</text>
<!-- Arrow to FP8 -->
<path d="M 450 112 L 480 112" class="arrow"/>
<!-- Quantized Tensor (for forward) -->
<rect x="480" y="85" width="130" height="55" class="fp8" rx="6"/>
<text x="545" y="108" class="text">Quantized</text>
<text x="545" y="125" class="text">Rowwise</text>
<!-- Backward phase label -->
<text x="20" y="163" class="text phase-backward" style="text-anchor: start;">BACKWARD:</text>
<!-- Arrow to FP8 columnwise -->
<path d="M 450 130 L 480 152" class="arrow"/>
<!-- Quantized columnwise (saved for backward) -->
<rect x="480" y="150" width="130" height="55" class="fp8" rx="6"/>
<text x="545" y="173" class="text">Quantized</text>
<text x="545" y="190" class="text">Columnwise</text>
<!-- Divider Line -->
<line x1="30" y1="220" x2="820" y2="220" stroke="#bdbdbd" stroke-width="3"/>
<!-- OPTION 2: Separate approach -->
<text x="425" y="250" class="section-title" style="fill: #212121; font-weight: 600;">Option 2: Separate Quantizations (quantize when needed)</text>
<!-- Forward phase -->
<text x="20" y="293" class="text phase-forward" style="text-anchor: start;">FORWARD:</text>
<!-- High Precision Tensor (shared) -->
<rect x="130" y="265" width="130" height="55" class="hp" rx="6"/>
<text x="195" y="288" class="text">High Precision</text>
<text x="195" y="305" class="text">Tensor</text>
<!-- Arrow to Quantize -->
<path d="M 260 292 L 290 292" class="arrow"/>
<!-- Quantize only -->
<rect x="290" y="265" width="120" height="55" class="quantize" rx="6"/>
<text x="350" y="297" class="text">Quantize</text>
<!-- Arrow to FP8 -->
<path d="M 410 292 L 440 292" class="arrow"/>
<!-- Quantized Tensor -->
<rect x="440" y="265" width="130" height="55" class="fp8" rx="6"/>
<text x="505" y="288" class="text">Quantized</text>
<text x="505" y="305" class="text">Rowwise</text>
<!-- Dashed line showing it's the same tensor -->
<path d="M 195 320 L 195 360" class="arrow-dashed"/>
<!-- Backward phase -->
<text x="20" y="393" class="text phase-backward" style="text-anchor: start;">BACKWARD:</text>
<!-- High Precision Tensor (same, reused) -->
<rect x="130" y="360" width="130" height="55" class="hp" rx="6"/>
<text x="195" y="383" class="text">High Precision</text>
<text x="195" y="400" class="text">Tensor</text>
<!-- Arrow to Quantize -->
<path d="M 260 387 L 290 387" class="arrow"/>
<!-- Quantize (in backward) -->
<rect x="290" y="360" width="160" height="55" class="quantize" rx="6"/>
<text x="370" y="392" class="text">Quantize</text>
<!-- Arrow to FP8 columnwise -->
<path d="M 450 387 L 480 387" class="arrow"/>
<!-- Quantized columnwise -->
<rect x="480" y="360" width="130" height="55" class="fp8" rx="6"/>
<text x="545" y="383" class="text">Quantized</text>
<text x="545" y="400" class="text">Columnwise</text>
<!-- Divider Line -->
<line x1="30" y1="430" x2="820" y2="430" stroke="#bdbdbd" stroke-width="3"/>
<!-- OPTION 3: Convert from rowwise to columnwise -->
<text x="425" y="460" class="section-title" style="fill: #212121; font-weight: 600;">Option 3: Convert Rowwise to Columnwise in Backward (reuse saved tensor)</text>
<!-- Forward phase -->
<text x="20" y="503" class="text phase-forward" style="text-anchor: start;">FORWARD:</text>
<!-- High Precision Tensor -->
<rect x="130" y="475" width="130" height="55" class="hp" rx="6"/>
<text x="195" y="498" class="text">High Precision</text>
<text x="195" y="515" class="text">Tensor</text>
<!-- Arrow to Quantize -->
<path d="M 260 502 L 290 502" class="arrow"/>
<!-- Quantize only -->
<rect x="290" y="475" width="120" height="55" class="quantize" rx="6"/>
<text x="350" y="507" class="text">Quantize</text>
<!-- Arrow to FP8 -->
<path d="M 410 502 L 440 502" class="arrow"/>
<!-- Quantized Tensor (saved) -->
<rect x="440" y="475" width="130" height="55" class="fp8" rx="6"/>
<text x="505" y="498" class="text">Quantized</text>
<text x="505" y="515" class="text">Rowwise</text>
<!-- Dashed line showing tensor is saved -->
<path d="M 505 530 L 505 547 L 195 547 L 195 565" class="arrow-dashed"/>
<!-- Backward phase -->
<text x="20" y="598" class="text phase-backward" style="text-anchor: start;">BACKWARD:</text>
<!-- Quantized Tensor (reused) -->
<rect x="130" y="565" width="130" height="55" class="fp8" rx="6"/>
<text x="195" y="588" class="text">Quantized</text>
<text x="195" y="605" class="text">Rowwise</text>
<!-- Arrow to Make Columnwise -->
<path d="M 260 592 L 290 592" class="arrow"/>
<!-- Make Columnwise operation -->
<rect x="290" y="565" width="120" height="55" class="quantize" rx="6"/>
<text x="350" y="585" class="text">Make</text>
<text x="350" y="602" class="text">Columnwise</text>
<!-- Arrow to FP8 columnwise -->
<path d="M 410 592 L 440 592" class="arrow"/>
<!-- Quantized columnwise -->
<rect x="440" y="565" width="130" height="55" class="fp8" rx="6"/>
<text x="505" y="588" class="text">Quantized</text>
<text x="505" y="605" class="text">Columnwise</text>
<!-- Legend -->
<g transform="translate(150, 660)">
<rect x="0" y="0" width="70" height="35" rx="5" class="hp"/>
<text x="85" y="21" class="small-text" style="text-anchor: start;">High Precision (FP32/BF16/FP16)</text>
<rect x="380" y="0" width="70" height="35" rx="5" class="fp8"/>
<text x="465" y="21" class="small-text" style="text-anchor: start;">Lower Precision (FP8, etc.)</text>
<rect x="0" y="45" width="70" height="35" rx="5" class="quantize"/>
<text x="85" y="66" class="small-text" style="text-anchor: start;">Quantization / Make Columnwise</text>
</g>
</svg>
# START_MEMORY_USAGE_1
Tensors in memory:
Shape: (1024, 1024), Dtype: bfloat16, Size: 2048.0 KB
Shape: (1024, 1024), Dtype: bfloat16, Size: 2048.0 KB
Total from all live arrays: 4.00 MB
# END_MEMORY_USAGE_1
Processing events...
Generated:
No reports were generated
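The sizes in the MEMORY_USAGE_1 output above follow directly from shape and dtype: a (1024, 1024) bfloat16 array holds 1024 × 1024 elements at 2 bytes each. Below is a small pure-Python check of that arithmetic. The helper `array_nbytes` is a hypothetical name used only here, and the sketch counts just the two large tensors listed, assuming any smaller live arrays are negligible at two-decimal rounding.

```python
def array_nbytes(shape, itemsize):
    """Return the byte size of a dense array with the given shape and element size."""
    count = 1
    for dim in shape:
        count *= dim
    return count * itemsize


BF16_ITEMSIZE = 2  # bfloat16 uses 2 bytes per element

per_tensor = array_nbytes((1024, 1024), BF16_ITEMSIZE)
total = 2 * per_tensor  # two (1024, 1024) bfloat16 tensors are listed

print(f"Per tensor: {per_tensor / 1024:.1f} KB")
print(f"Total: {total / (1024**2):.2f} MB")
```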
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+
print("# START_MEMORY_USAGE_1")
import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import DenseGeneral
key = jax.random.PRNGKey(0)
jax.clear_caches()
# Initialize layer with BF16 parameters
layer = DenseGeneral(features=1024, dtype=jnp.bfloat16)
x = jax.random.normal(key, (1024, 1024), dtype=jnp.bfloat16)
var_collect = layer.init(key, x)
@jax.jit
def loss_fn(var_collect, x):
    output = layer.apply(var_collect, x)
    return output.sum()
# Trace the backward pass - this allocates saved tensors
_, backward_fn = jax.vjp(loss_fn, var_collect, x)
del x
print("Tensors in memory:")
total_bytes = 0
for arr in jax.live_arrays():
    total_bytes += arr.nbytes
    if arr.nbytes > 200000:  # do not count small tensors
        print(f" Shape: {arr.shape}, Dtype: {arr.dtype}, Size: {arr.nbytes / 1024:.1f} KB")
print(f" Total from all live arrays: {total_bytes / (1024**2):.2f} MB")
print("# END_MEMORY_USAGE_1")
# START_MEMORY_USAGE_1
Memory usage after forward pass: 6.00 MB
# END_MEMORY_USAGE_1