Commit 3ceb248e authored by Paweł Gadziński, committed by GitHub

More detailed documentation for recipes (#2343)



* Code drop: Update recipes documentation and remove custom recipes from low precision training
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Fix SVG css import path for diagrams
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Refactor low_precision_training docs: remove optimizers, fix imports, add GPU checks

Changes:
- Remove optimizer code from all recipe examples (keep only forward/backward)
- Fix Format imports (use Format.E4M3 instead of string 'E4M3')
- Fix params_dtype for PyTorch examples (add params_dtype=torch.bfloat16)
- Add GPU capability assertions before START blocks for blockwise/mxfp8/nvfp4
- Fix JAX imports (Float8CurrentScaling from common.recipe, NVFP4BlockScaling)
- Add global_shard_guard for TransformerLayer examples in JAX
- Fix fused_layers_jax.py return tuple unpacking
- Update memory_usage JAX examples with dynamic GPU measurement
- Remove memory_usage_3_jax (JAX doesn't support FP8 weight storage)
- Update performance_considerations.rst for JAX differences
- Delete unused .out files and fp8_autocast_jax.py
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix JAX memory usage .out files with correct output
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* responded to comments
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* applied suggestions from greptile
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* year change
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* jax compute capability fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent c3769cb7
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
# Check for Blackwell or newer GPU
major, minor = torch.cuda.get_device_capability()
assert major >= 10, f"MXFP8 requires SM100 (Blackwell) or later, got SM{major}{minor}"
# START_MXFP8_EXAMPLE
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import MXFP8BlockScaling, Format
# Create MXFP8 recipe
recipe = MXFP8BlockScaling(
    fp8_format=Format.E4M3,  # E4M3 (default) or HYBRID; pure E5M2 not supported
)
# Create a linear layer with bfloat16 parameters
layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
# Forward and backward pass
inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda")
with te.autocast(enabled=True, recipe=recipe):
    output = layer(inp)
    loss = output.sum()
loss.backward()
# END_MXFP8_EXAMPLE
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1100 140" width="1100" height="140">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.title { font: bold 16px sans-serif; fill: #333; text-anchor: middle; }
.text { font: 13px sans-serif; fill: #333; text-anchor: middle; }
/* Arrows */
.arrow { stroke: #616161; stroke-width: 2; fill: none; marker-end: url(#arrowhead); }
/* High precision tensor */
.hp {
fill: #e8f5e9;
stroke: #43a047;
stroke-width: 2;
}
/* Amax operations */
.amax {
fill: #fff3e0;
stroke: #ff9800;
stroke-width: 2;
}
/* Quantize operations */
.quantize {
fill: #fce4ec;
stroke: #e91e63;
stroke-width: 2;
}
/* NVFP4 tensor */
.nvfp4 {
fill: #87CEEB;
stroke: #444;
stroke-width: 2;
}
/* All-gather operations */
.allgather {
fill: #e1f5fe;
stroke: #039be5;
stroke-width: 2;
}
</style>
<marker id="arrowhead" markerWidth="6" markerHeight="6" refX="5" refY="2" orient="auto">
<polygon points="0 0, 6 2, 0 4" fill="#616161" />
</marker>
</defs>
<!-- Title -->
<text x="550" y="30" class="title">Quantization + All-Gather for NVFP4</text>
<!-- High Precision Tensor -->
<rect x="20" y="70" width="100" height="55" class="hp" rx="6"/>
<text x="70" y="93" class="text">High Precision</text>
<text x="70" y="110" class="text">Tensor</text>
<!-- Arrow -->
<path d="M 120 97 L 145 97" class="arrow"/>
<!-- Compute Amax -->
<rect x="145" y="70" width="100" height="55" class="amax" rx="6"/>
<text x="195" y="93" class="text">Compute</text>
<text x="195" y="110" class="text">Amax</text>
<!-- Arrow -->
<path d="M 245 97 L 270 97" class="arrow"/>
<!-- Synchronize Amax -->
<rect x="270" y="70" width="100" height="55" class="amax" rx="6"/>
<text x="320" y="93" class="text">Synchronize</text>
<text x="320" y="110" class="text">Amax</text>
<!-- Arrow -->
<path d="M 370 97 L 395 97" class="arrow"/>
<!-- Compute s_global -->
<rect x="395" y="70" width="100" height="55" class="amax" rx="6"/>
<text x="445" y="93" class="text">Compute</text>
<text x="445" y="110" class="text">s_global</text>
<!-- Arrow -->
<path d="M 495 97 L 520 97" class="arrow"/>
<!-- Scale + Cast -->
<rect x="520" y="70" width="100" height="55" class="quantize" rx="6"/>
<text x="570" y="86" class="text">Scale + Cast</text>
<text x="570" y="103" class="text">(s_block,</text>
<text x="570" y="118" class="text">s_global)</text>
<!-- Arrow -->
<path d="M 620 97 L 645 97" class="arrow"/>
<!-- NVFP4 Tensor (intermediate) -->
<rect x="645" y="70" width="100" height="55" class="nvfp4" rx="6"/>
<text x="695" y="93" class="text">NVFP4</text>
<text x="695" y="110" class="text">Tensor</text>
<!-- Arrow -->
<path d="M 745 97 L 770 97" class="arrow"/>
<!-- All-Gather -->
<rect x="770" y="70" width="100" height="55" class="allgather" rx="6"/>
<text x="820" y="102" class="text">All-Gather</text>
<!-- Arrow -->
<path d="M 870 97 L 895 97" class="arrow"/>
<!-- NVFP4 Gathered Tensor -->
<rect x="895" y="70" width="130" height="55" class="nvfp4" rx="6"/>
<text x="960" y="93" class="text">NVFP4 Gathered</text>
<text x="960" y="110" class="text">Tensor</text>
</svg>
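The pipeline in the diagram above can be sketched in plain Python. This is a minimal model, not Transformer Engine code: the helper names, the demo block size of 4, the formula `s_global = amax / (6 * 448)`, and the use of unrounded float block scales are all assumptions made for illustration. It shows why the amax is synchronized before quantization: once every rank shares one `s_global`, quantizing each shard and then gathering gives the same result as quantizing the full tensor.

```python
# Illustrative sketch only; names and formulas are assumptions, not the library's API.
FP4_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # E2M1 magnitudes
FP4_MAX, E4M3_MAX = 6.0, 448.0
BLOCK = 4  # demo block size (hypothetical; chosen to keep the example small)


def round_to_fp4(x):
    """Round x to the nearest E2M1 magnitude, keeping the sign."""
    mag = min(FP4_VALUES, key=lambda v: abs(v - abs(x)))
    return -mag if x < 0 else mag


def compute_s_global(amax):
    """Per-tensor scale derived from the (synchronized) amax."""
    return amax / (FP4_MAX * E4M3_MAX) if amax > 0 else 1.0


def scale_and_cast(values, s_global):
    """Quantize block by block; returns a list of (codes, s_block) pairs."""
    out = []
    for i in range(0, len(values), BLOCK):
        block = values[i:i + BLOCK]
        block_amax = max(abs(v) for v in block)
        s_block = (block_amax / FP4_MAX) / s_global if block_amax > 0 else 1.0
        codes = [round_to_fp4(v / (s_block * s_global)) for v in block]
        out.append((codes, s_block))
    return out


shards = [[0.5, -2.0, 1.0, 0.25], [8.0, -0.125, 3.0, 4.0]]  # one list per rank

local_amaxes = [max(abs(v) for v in s) for s in shards]  # Compute Amax
global_amax = max(local_amaxes)                          # Synchronize Amax (max-reduce)
s_global = compute_s_global(global_amax)                 # Compute s_global
quantized = [scale_and_cast(s, s_global) for s in shards]  # Scale + Cast
gathered = [blk for q in quantized for blk in q]         # All-Gather (concatenate)

# Reference path: gather the high-precision tensor first, then quantize.
full = [v for s in shards for v in s]
reference = scale_and_cast(full, compute_s_global(max(abs(v) for v in full)))
```

Without the synchronization step, each rank would derive a different `s_global` from its local amax, and the gathered blocks would no longer share a single per-tensor scale.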
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 60 500 450" width="100%" style="max-width: 500px;">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.title { font: bold 18px sans-serif; fill: #333; text-anchor: middle; }
.label { font: 14px sans-serif; fill: #333; text-anchor: middle; }
.small-text { font: 12px sans-serif; fill: #555; }
.dots-text { font: bold 24px sans-serif; fill: #333; text-anchor: middle; }
/* Tensor colors */
.fp8-tensor { fill: #87CEEB; stroke: #444; stroke-width: 2; }
.fp8-block { fill: #87CEEB; stroke: #555; stroke-width: 1.5; }
.fp8-block-alt { fill: #5F9FCC; stroke: #555; stroke-width: 1.5; }
/* Scaling factor colors */
.scale-factor { fill: #FFA500; stroke: #444; stroke-width: 2; }
.global-scale { fill: #FF6B6B; stroke: #444; stroke-width: 2; }
.grid-line { stroke: #444; stroke-width: 2; }
.boundary-line { stroke: #444; stroke-width: 2; }
</style>
</defs>
<!-- NVFP4 Scaling -->
<g id="nvfp4-scaling">
<text x="250" y="85" class="title">NVFP4 Hierarchical Scaling</text>
<text x="250" y="108" class="label" style="font-size: 12px;">(Block scaling + Global scaling)</text>
<!-- NVFP4 Tensor split into many small blocks (40×10) -->
<g id="tensor-blocks">
<!-- White backgrounds for dots areas - cross pattern -->
<rect x="250.0" y="140.0" width="40" height="120" fill="#FFFFFF" stroke="none"/>
<rect x="130.0" y="190.0" width="240" height="30" fill="#FFFFFF" stroke="none"/>
<!-- Blocks ONLY where they don't overlap with white cross -->
<rect x="130" y="140" width="40" height="10" class="fp8-block"/>
<rect x="170" y="140" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="140" width="40" height="10" class="fp8-block"/>
<rect x="290" y="140" width="40" height="10" class="fp8-block"/>
<rect x="330" y="140" width="40" height="10" class="fp8-block-alt"/>
<rect x="130" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="170" y="150" width="40" height="10" class="fp8-block"/>
<rect x="210" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="290" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="330" y="150" width="40" height="10" class="fp8-block"/>
<rect x="130" y="160" width="40" height="10" class="fp8-block"/>
<rect x="170" y="160" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="160" width="40" height="10" class="fp8-block"/>
<rect x="290" y="160" width="40" height="10" class="fp8-block"/>
<rect x="330" y="160" width="40" height="10" class="fp8-block-alt"/>
<rect x="130" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="170" y="170" width="40" height="10" class="fp8-block"/>
<rect x="210" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="290" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="330" y="170" width="40" height="10" class="fp8-block"/>
<rect x="130" y="180" width="40" height="10" class="fp8-block"/>
<rect x="170" y="180" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="180" width="40" height="10" class="fp8-block"/>
<rect x="290" y="180" width="40" height="10" class="fp8-block"/>
<rect x="330" y="180" width="40" height="10" class="fp8-block-alt"/>
<!-- Three dots in VERTICAL white bar -->
<text x="270" y="167.5" class="dots-text">⋯</text>
<text x="270" y="242.5" class="dots-text">⋯</text>
<!-- Three dots in HORIZONTAL white bar -->
<text x="190" y="205" class="dots-text">⋮</text>
<text x="330" y="205" class="dots-text">⋮</text>
<!-- ONE diagonal dot at intersection -->
<text x="270" y="205" class="dots-text" transform="rotate(45 270 205)">⋯</text>
<!-- Bottom rows (y >= 220 after horizontal white bar) -->
<rect x="130" y="220" width="40" height="10" class="fp8-block"/>
<rect x="170" y="220" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="220" width="40" height="10" class="fp8-block"/>
<rect x="290" y="220" width="40" height="10" class="fp8-block"/>
<rect x="330" y="220" width="40" height="10" class="fp8-block-alt"/>
<rect x="130" y="230" width="40" height="10" class="fp8-block-alt"/>
<rect x="170" y="230" width="40" height="10" class="fp8-block"/>
<rect x="210" y="230" width="40" height="10" class="fp8-block-alt"/>
<rect x="290" y="230" width="40" height="10" class="fp8-block-alt"/>
<rect x="330" y="230" width="40" height="10" class="fp8-block"/>
<rect x="130" y="240" width="40" height="10" class="fp8-block"/>
<rect x="170" y="240" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="240" width="40" height="10" class="fp8-block"/>
<rect x="290" y="240" width="40" height="10" class="fp8-block"/>
<rect x="330" y="240" width="40" height="10" class="fp8-block-alt"/>
<rect x="130" y="250" width="40" height="10" class="fp8-block-alt"/>
<rect x="170" y="250" width="40" height="10" class="fp8-block"/>
<rect x="210" y="250" width="40" height="10" class="fp8-block-alt"/>
<rect x="290" y="250" width="40" height="10" class="fp8-block-alt"/>
<rect x="330" y="250" width="40" height="10" class="fp8-block"/>
<!-- Main outline -->
<rect x="130.0" y="140.0" width="240" height="120" fill="none" stroke="#444" stroke-width="2"/>
</g>
<!-- Scaling factors tensor - 3+2 columns of 10px squares -->
<g id="scale-factors">
<!-- Orange background -->
<rect x="215" y="285" width="70" height="120" fill="#FFA500"/>
<!-- White backgrounds for dots areas - cross pattern -->
<rect x="245" y="285" width="20" height="120" fill="#FFFFFF" stroke="none"/>
<rect x="215" y="335" width="70" height="30" fill="#FFFFFF" stroke="none"/>
<!-- Grid lines showing 10x10 squares (3 left + 2 right columns) -->
<!-- Vertical lines every 10px (skipping white space) -->
<!-- Left 3 columns -->
<line x1="225" y1="285" x2="225" y2="335" class="grid-line" stroke-width="1"/>
<line x1="235" y1="285" x2="235" y2="335" class="grid-line" stroke-width="1"/>
<line x1="245" y1="285" x2="245" y2="335" class="grid-line" stroke-width="1"/>
<!-- Right 2 columns -->
<line x1="265" y1="285" x2="265" y2="335" class="grid-line" stroke-width="1"/>
<line x1="275" y1="285" x2="275" y2="335" class="grid-line" stroke-width="1"/>
<line x1="285" y1="285" x2="285" y2="335" class="grid-line" stroke-width="1"/>
<!-- Bottom sections -->
<line x1="225" y1="365" x2="225" y2="405" class="grid-line" stroke-width="1"/>
<line x1="235" y1="365" x2="235" y2="405" class="grid-line" stroke-width="1"/>
<line x1="245" y1="365" x2="245" y2="405" class="grid-line" stroke-width="1"/>
<line x1="265" y1="365" x2="265" y2="405" class="grid-line" stroke-width="1"/>
<line x1="275" y1="365" x2="275" y2="405" class="grid-line" stroke-width="1"/>
<line x1="285" y1="365" x2="285" y2="405" class="grid-line" stroke-width="1"/>
<!-- Horizontal lines every 10px -->
<line x1="215" y1="295" x2="245" y2="295" class="grid-line" stroke-width="1"/>
<line x1="265" y1="295" x2="285" y2="295" class="grid-line" stroke-width="1"/>
<line x1="215" y1="305" x2="245" y2="305" class="grid-line" stroke-width="1"/>
<line x1="265" y1="305" x2="285" y2="305" class="grid-line" stroke-width="1"/>
<line x1="215" y1="315" x2="245" y2="315" class="grid-line" stroke-width="1"/>
<line x1="265" y1="315" x2="285" y2="315" class="grid-line" stroke-width="1"/>
<line x1="215" y1="325" x2="245" y2="325" class="grid-line" stroke-width="1"/>
<line x1="265" y1="325" x2="285" y2="325" class="grid-line" stroke-width="1"/>
<!-- Top bottom boundaries -->
<line x1="215" y1="335" x2="245" y2="335" class="grid-line" stroke-width="1"/>
<line x1="265" y1="335" x2="285" y2="335" class="grid-line" stroke-width="1"/>
<line x1="215" y1="365" x2="245" y2="365" class="grid-line" stroke-width="1"/>
<line x1="265" y1="365" x2="285" y2="365" class="grid-line" stroke-width="1"/>
<line x1="215" y1="375" x2="245" y2="375" class="grid-line" stroke-width="1"/>
<line x1="265" y1="375" x2="285" y2="375" class="grid-line" stroke-width="1"/>
<line x1="215" y1="385" x2="245" y2="385" class="grid-line" stroke-width="1"/>
<line x1="265" y1="385" x2="285" y2="385" class="grid-line" stroke-width="1"/>
<line x1="215" y1="395" x2="245" y2="395" class="grid-line" stroke-width="1"/>
<line x1="265" y1="395" x2="285" y2="395" class="grid-line" stroke-width="1"/>
<!-- Bottom boundaries -->
<line x1="215" y1="405" x2="245" y2="405" class="grid-line" stroke-width="1"/>
<line x1="265" y1="405" x2="285" y2="405" class="grid-line" stroke-width="1"/>
<!-- Main outline -->
<rect x="215" y="285" width="70" height="120" fill="none" stroke="#444" stroke-width="2"/>
<!-- Three dots -->
<text x="255" y="312.5" class="dots-text" style="font-size: 14px;">⋯</text>
<text x="255" y="387.5" class="dots-text" style="font-size: 14px;">⋯</text>
<text x="230" y="350" class="dots-text" style="font-size: 14px;">⋮</text>
<text x="275" y="350" class="dots-text" style="font-size: 14px;">⋮</text>
<text x="255" y="350" class="dots-text" style="font-size: 14px;" transform="rotate(45 255 350)">⋯</text>
</g>
<text x="250" y="430" class="small-text" text-anchor="middle">E4M3 scaling factors (one per 16 elements)</text>
<!-- Global Scaling Factor -->
<g id="global-scale" transform="translate(350, 320)">
<rect x="10" y="10" width="20" height="20" class="global-scale"/>
<text x="20" y="60" class="small-text" text-anchor="middle">Global Scale (FP32)</text>
<text x="20" y="75" class="small-text" text-anchor="middle">(one per tensor)</text>
<text x="-20" y="25" class="dots-text" style="font-size: 20px;">+</text>
</g>
</g>
</svg>
\ No newline at end of file
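The storage layout drawn above (4-bit elements, one E4M3 scaling factor per 16 elements, one FP32 global scale per tensor) implies a simple footprint estimate. The sketch below is back-of-the-envelope arithmetic only; the function name is hypothetical, and a real implementation may pad the scale tensor or keep additional copies, which this ignores.

```python
def nvfp4_bytes(num_elements, block=16):
    """Estimate bytes for an NVFP4 tensor under the layout in the diagram."""
    assert num_elements % block == 0
    data = num_elements // 2               # two 4-bit values per byte
    block_scales = num_elements // block   # one 1-byte E4M3 scale per block
    global_scale = 4                       # one FP32 scale per tensor
    return data + block_scales + global_scale


n = 1024 * 1024
total = nvfp4_bytes(n)
bits_per_element = total * 8 / n   # close to 4.5 bits
bf16_bytes = n * 2                 # same tensor stored in bfloat16
print(total, round(bits_per_element, 3), round(bf16_bytes / total, 2))
```

Under these assumptions the block scaling factors add about half a bit per element on top of the 4-bit payload.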
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Check for Blackwell or newer GPU
from transformer_engine.jax.quantize import get_device_compute_capability
assert (
    get_device_compute_capability() >= 100
), f"NVFP4 requires SM100 (Blackwell) or later, got SM{get_device_compute_capability()}"
# START_NVFP4_EXAMPLE
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
from transformer_engine.jax.flax import DenseGeneral
from transformer_engine.common.recipe import NVFP4BlockScaling
# Define NVFP4 recipe
# 2D weight quantization and RHT are enabled by default
recipe = NVFP4BlockScaling()
# To disable features, use:
# recipe = NVFP4BlockScaling(disable_rht=True, disable_2d_quantization=True)
with te.autocast(enabled=True, recipe=recipe):
    # Initialize layer and data
    layer = DenseGeneral(features=1024)
    key, sr_key = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16)
    # NVFP4 requires sr_rng for stochastic rounding
    rngs = {"sr_rng": sr_key}
    var_collect = layer.init({"params": key, "sr_rng": sr_key}, x)
    # Forward and backward pass
    def loss_fn(var_collect):
        output = layer.apply(var_collect, x, rngs=rngs)
        return output.sum()
    loss, grads = jax.value_and_grad(loss_fn)(var_collect)
# END_NVFP4_EXAMPLE
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
# Check for Blackwell or newer GPU
major, minor = torch.cuda.get_device_capability()
assert major >= 10, f"NVFP4 requires SM100 (Blackwell) or later, got SM{major}{minor}"
# START_NVFP4_EXAMPLE
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import NVFP4BlockScaling
# Define NVFP4 recipe
# 2D weight quantization and RHT are enabled by default
recipe = NVFP4BlockScaling()
# To disable features, use:
# recipe = NVFP4BlockScaling(disable_rht=True, disable_2d_quantization=True)
# Create a linear layer with bfloat16 parameters
layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
# Forward and backward pass
inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda")
with te.autocast(enabled=True, recipe=recipe):
    output = layer(inp)
    loss = output.sum()
loss.backward()
# END_NVFP4_EXAMPLE
# START_MEMORY_USAGE_1
Tensors in memory:
Shape: (1024, 1024), Dtype: bfloat16, Size: 2048.0 KB
Shape: (1024, 1024), Dtype: bfloat16, Size: 2048.0 KB
Total from all live arrays: 4.00 MB
# END_MEMORY_USAGE_1
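The sizes in the output above follow directly from shape and dtype. As a quick check, assuming 2 bytes per bfloat16 element and binary units (1 KB = 1024 bytes), a hypothetical helper reproduces them:

```python
BYTES_PER_ELEMENT = {"bfloat16": 2, "float32": 4}


def tensor_kb(shape, dtype):
    """Size of a dense tensor in KB (1 KB = 1024 bytes)."""
    count = 1
    for dim in shape:
        count *= dim
    return count * BYTES_PER_ELEMENT[dtype] / 1024


sizes_kb = [tensor_kb((1024, 1024), "bfloat16") for _ in range(2)]
total_mb = sum(sizes_kb) / 1024
print(sizes_kb, f"{total_mb:.2f} MB")
```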