Commit 9df0c4a3 authored by wenjh's avatar wenjh
Browse files

Merge branch 'nv_main'

parents 0d874a4e f122b07d
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch

# This example targets Ada (SM89) or Hopper (SM90); Blackwell+ produces different results.
major, minor = torch.cuda.get_device_capability()
assert (major == 8 and minor >= 9) or major == 9, "This example requires SM89 (Ada) or SM90 (Hopper)"
# START_FUSED_LAYERS
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling
# Example 1: Separate LayerNorm and Linear layers
# NOTE(review): these two layers use the default params dtype while the input
# below is bfloat16 — presumably TE handles the mix; confirm against te.LayerNorm docs.
layer_norm = te.LayerNorm(1024)
linear = te.Linear(1024, 1024)
inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda")
# Two separate operations: LayerNorm produces FP32, then Linear quantizes it
# NOTE(review): these calls run outside the te.autocast region opened below, so
# no FP8 quantization is actually triggered here — confirm this is intended for
# the "separate layers" comparison.
normalized = layer_norm(inp)
output_separate = linear(normalized)
# Example 2: Fused LayerNormLinear layer
fused_layer = te.LayerNormLinear(1024, 1024, params_dtype=torch.bfloat16)
# Single operation: LayerNorm output is directly quantized
# DelayedScaling is the FP8 recipe; the autocast context enables quantized execution.
recipe = DelayedScaling()
with te.autocast(enabled=True, recipe=recipe):
    output_fused = fused_layer(inp)
# The fused layer is more efficient as it avoids redundant quantization
# END_FUSED_LAYERS
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 850 640" width="850" height="640">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
/* Diagram-specific overrides */
.hp { fill: #f3e5f5; stroke: #7b1fa2; stroke-width: 2.5; }
.gemm { fill: #ffe8cc; stroke: #f57c00; stroke-width: 3; }
.quantize { fill: #c8e6c9; stroke: #43a047; stroke-width: 2.5; }
.text { font-weight: 500; }
.small-text { fill: #616161; }
.title { font-weight: 700; fill: #1a1a1a; }
.arrow { stroke: #424242; stroke-width: 2.5; fill: none; marker-end: url(#arrowhead); }
</style>
<marker id="arrowhead" markerWidth="6" markerHeight="6" refX="5" refY="2" orient="auto">
<polygon points="0 0, 6 2, 0 4" fill="#424242" />
</marker>
</defs>
<!-- Title -->
<text x="425" y="30" class="title">LayerNorm + Linear: Separate vs Fused</text>
<!-- Divider Line -->
<line x1="425" y1="50" x2="425" y2="620" stroke="#bdbdbd" stroke-width="2.5" stroke-dasharray="5,5"/>
<!-- LEFT: Separate Layers -->
<text x="212" y="65" class="section-title">Scenario 1: Separate Layers</text>
<!-- Input -->
<rect x="137" y="95" width="150" height="55" class="hp" rx="8"/>
<text x="212" y="128" class="text">Input</text>
<!-- Arrow -->
<path d="M 212 150 L 212 180" class="arrow"/>
<!-- LayerNorm -->
<rect x="137" y="180" width="150" height="55" class="layernorm" rx="8"/>
<text x="212" y="213" class="text" style="font-weight: 600;">LayerNorm</text>
<!-- Arrow -->
<path d="M 212 235 L 212 260" class="arrow"/>
<!-- Output FP32 -->
<rect x="147" y="260" width="130" height="50" class="hp" rx="8"/>
<text x="212" y="290" class="text">Output</text>
<!-- Arrow -->
<path d="M 212 310 L 212 330" class="arrow"/>
<!-- Linear wrapper -->
<rect x="122" y="330" width="180" height="190" class="gemm" rx="10"/>
<text x="212" y="355" class="text" style="font-weight: 600;">Linear</text>
<!-- Quantize inside -->
<rect x="142" y="368" width="140" height="30" class="quantize" rx="8"/>
<text x="212" y="388" class="text">Quantize</text>
<!-- Arrow -->
<path d="M 212 398 L 212 418" class="arrow"/>
<!-- FP8 tensor -->
<rect x="147" y="418" width="130" height="35" class="fp8" rx="8"/>
<text x="212" y="441" class="text">FP8 tensor</text>
<!-- Arrow -->
<path d="M 212 453 L 212 473" class="arrow"/>
<!-- Simplified computation inside -->
<rect x="142" y="473" width="140" height="30" class="computation" rx="8"/>
<text x="212" y="493" class="title" style="font-weight: 600;">...</text>
<!-- Arrow out of Linear -->
<path d="M 212 520 L 212 545" class="arrow"/>
<!-- Output -->
<rect x="137" y="545" width="150" height="50" class="hp" rx="8"/>
<text x="212" y="575" class="text">Output</text>
<!-- RIGHT: Fused Layer -->
<text x="637" y="65" class="section-title">Scenario 2: Fused Layer</text>
<!-- Input -->
<rect x="562" y="95" width="150" height="55" class="hp" rx="8"/>
<text x="637" y="128" class="text">Input</text>
<!-- Arrow -->
<path d="M 637 150 L 637 185" class="arrow"/>
<!-- LayerNormLinear Fused -->
<rect x="517" y="185" width="240" height="220" class="fused" rx="10"/>
<text x="637" y="212" class="text" style="font-weight: 600;">LayerNormLinear</text>
<!-- Inside fused block -->
<!-- LayerNorm -->
<rect x="537" y="235" width="200" height="40" class="layernorm" rx="8"/>
<text x="637" y="260" class="text">LayerNorm + Quantize</text>
<!-- Arrow -->
<path d="M 637 275 L 637 295" class="arrow"/>
<!-- FP8 tensor -->
<rect x="562" y="295" width="150" height="35" class="fp8" rx="8"/>
<text x="637" y="318" class="text">FP8 tensor</text>
<!-- Arrow -->
<path d="M 637 330 L 637 350" class="arrow"/>
<!-- Simplified computation -->
<rect x="567" y="350" width="140" height="35" class="computation" rx="8"/>
<text x="637" y="373" class="title" style="font-weight: 600;">...</text>
<!-- Arrow out of fused block -->
<path d="M 637 405 L 637 545" class="arrow"/>
<!-- Output -->
<rect x="562" y="545" width="150" height="50" class="hp" rx="8"/>
<text x="637" y="575" class="text">Output</text>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 340 280" width="100%" style="max-width: 340px;">
<defs>
<style>
.title { font: bold 12px sans-serif; fill: #333; text-anchor: middle; }
.label { font: 10px sans-serif; fill: #333; text-anchor: middle; }
.small-text { font: 9px sans-serif; fill: #555; }
.operator { font: bold 14px sans-serif; fill: #333; }
.equals { font: bold 14px sans-serif; fill: #333; }
/* Matrix colors */
.matrix-cell { fill: #bbdefb; stroke: #1565c0; stroke-width: 1; }
.highlight-row { fill: #c8e6c9; stroke: #1565c0; stroke-width: 1; }
.highlight-col { fill: #fff3e0; stroke: #1565c0; stroke-width: 1; }
.highlight-result { fill: #ffcc80; stroke: #1565c0; stroke-width: 1; }
</style>
</defs>
<!-- FIRST SCENARIO: A × B -->
<!-- Title -->
<text x="170" y="16" class="title">NN GEMM</text>
<!-- Matrix A (left) - centered -->
<g id="matrix-a1" transform="translate(40, 52)">
<text x="28" y="-14" class="label">A</text>
<!-- Row 0 -->
<rect x="0" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="0" width="14" height="14" class="matrix-cell"/>
<!-- Row 1 (highlighted - row i) -->
<rect x="0" y="14" width="14" height="14" class="highlight-row"/>
<rect x="14" y="14" width="14" height="14" class="highlight-row"/>
<rect x="28" y="14" width="14" height="14" class="highlight-row"/>
<rect x="42" y="14" width="14" height="14" class="highlight-row"/>
<!-- Row 2 -->
<rect x="0" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="28" width="14" height="14" class="matrix-cell"/>
<!-- Row 3 -->
<rect x="0" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="42" width="14" height="14" class="matrix-cell"/>
<!-- Label -->
<text x="28" y="70" class="small-text" style="fill: #388e3c; font-weight: bold; text-anchor: middle;">rowwise</text>
</g>
<!-- Matrix B (middle) - 4x4, centered -->
<g id="matrix-b1" transform="translate(136, 52)">
<text x="28" y="-14" class="label">B</text>
<!-- Row 0 -->
<rect x="0" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="0" width="14" height="14" class="highlight-col"/>
<rect x="42" y="0" width="14" height="14" class="matrix-cell"/>
<!-- Row 1 -->
<rect x="0" y="14" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="14" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="14" width="14" height="14" class="highlight-col"/>
<rect x="42" y="14" width="14" height="14" class="matrix-cell"/>
<!-- Row 2 -->
<rect x="0" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="28" width="14" height="14" class="highlight-col"/>
<rect x="42" y="28" width="14" height="14" class="matrix-cell"/>
<!-- Row 3 -->
<rect x="0" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="42" width="14" height="14" class="highlight-col"/>
<rect x="42" y="42" width="14" height="14" class="matrix-cell"/>
<!-- Label -->
<text x="28" y="70" class="small-text" style="fill: #f57c00; font-weight: bold; text-anchor: middle;">columnwise</text>
</g>
<!-- Result matrix 4x4 (right), centered -->
<g id="matrix-c1" transform="translate(232, 52)">
<text x="28" y="-14" class="label" style="font-weight: bold;">A×B</text>
<!-- Row 0 -->
<rect x="0" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="0" width="14" height="14" class="matrix-cell"/>
<!-- Row 1 -->
<rect x="0" y="14" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="14" width="14" height="14" class="matrix-cell"/>
<!-- Highlighted result cell -->
<rect x="28" y="14" width="14" height="14" class="highlight-result"/>
<rect x="42" y="14" width="14" height="14" class="matrix-cell"/>
<!-- Row 2 -->
<rect x="0" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="28" width="14" height="14" class="matrix-cell"/>
<!-- Row 3 -->
<rect x="0" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="42" width="14" height="14" class="matrix-cell"/>
</g>
<!-- SECOND SCENARIO: A × B^T -->
<!-- Title -->
<text x="170" y="152" class="title">TN GEMM</text>
<!-- Matrix A (left), centered -->
<g id="matrix-a2" transform="translate(40, 188)">
<text x="28" y="-14" class="label">A</text>
<!-- Row 0 -->
<rect x="0" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="0" width="14" height="14" class="matrix-cell"/>
<!-- Row 1 (highlighted - row i) -->
<rect x="0" y="14" width="14" height="14" class="highlight-row"/>
<rect x="14" y="14" width="14" height="14" class="highlight-row"/>
<rect x="28" y="14" width="14" height="14" class="highlight-row"/>
<rect x="42" y="14" width="14" height="14" class="highlight-row"/>
<!-- Row 2 -->
<rect x="0" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="28" width="14" height="14" class="matrix-cell"/>
<!-- Row 3 -->
<rect x="0" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="42" width="14" height="14" class="matrix-cell"/>
<!-- Label -->
<text x="28" y="70" class="small-text" style="fill: #388e3c; font-weight: bold; text-anchor: middle;">rowwise</text>
</g>
<!-- Matrix B (middle, shown with rowwise access for B^T) - 4x4, centered -->
<g id="matrix-b2" transform="translate(136, 188)">
<text x="28" y="-14" class="label">B</text>
<!-- Row 0 -->
<rect x="0" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="0" width="14" height="14" class="matrix-cell"/>
<!-- Row 1 (highlighted - rowwise access for transposed) -->
<rect x="0" y="14" width="14" height="14" class="highlight-row"/>
<rect x="14" y="14" width="14" height="14" class="highlight-row"/>
<rect x="28" y="14" width="14" height="14" class="highlight-row"/>
<rect x="42" y="14" width="14" height="14" class="highlight-row"/>
<!-- Row 2 -->
<rect x="0" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="28" width="14" height="14" class="matrix-cell"/>
<!-- Row 3 -->
<rect x="0" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="42" width="14" height="14" class="matrix-cell"/>
<!-- Label -->
<text x="28" y="70" class="small-text" style="fill: #388e3c; font-weight: bold; text-anchor: middle;">rowwise</text>
</g>
<!-- Result matrix 4x4 (right), centered -->
<g id="matrix-c2" transform="translate(232, 188)">
<text x="28" y="-14" class="label" style="font-weight: bold;">A×B<tspan baseline-shift="super" font-size="7">T</tspan></text>
<!-- Row 0 -->
<rect x="0" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="0" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="0" width="14" height="14" class="matrix-cell"/>
<!-- Row 1 -->
<rect x="0" y="14" width="14" height="14" class="matrix-cell"/>
<!-- Highlighted result cell -->
<rect x="14" y="14" width="14" height="14" class="highlight-result"/>
<rect x="28" y="14" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="14" width="14" height="14" class="matrix-cell"/>
<!-- Row 2 -->
<rect x="0" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="28" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="28" width="14" height="14" class="matrix-cell"/>
<!-- Row 3 -->
<rect x="0" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="14" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="28" y="42" width="14" height="14" class="matrix-cell"/>
<rect x="42" y="42" width="14" height="14" class="matrix-cell"/>
</g>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 700 220" width="100%" style="max-width: 700px;">
<defs>
<style>
.title { font: bold 14px sans-serif; fill: #333; text-anchor: middle; }
.label { font: 11px sans-serif; fill: #333; text-anchor: middle; }
.small-text { font: 10px sans-serif; fill: #555; text-anchor: middle; }
.cell-num { font: 9px sans-serif; fill: #1565c0; text-anchor: middle; dominant-baseline: middle; }
/* Matrix colors */
.matrix-cell { fill: #bbdefb; stroke: #1565c0; stroke-width: 1; }
</style>
</defs>
<!-- HOPPER SECTION (left) -->
<text x="175" y="25" class="title">FP8 tensor on Hopper</text>
<!-- Rowwise tensor -->
<g id="hopper-rowwise" transform="translate(50, 50)">
<text x="40" y="-8" class="label">rowwise</text>
<rect x="0" y="0" width="20" height="20" class="matrix-cell"/>
<text x="10" y="10" class="cell-num">0</text>
<rect x="20" y="0" width="20" height="20" class="matrix-cell"/>
<text x="30" y="10" class="cell-num">1</text>
<rect x="40" y="0" width="20" height="20" class="matrix-cell"/>
<text x="50" y="10" class="cell-num">2</text>
<rect x="60" y="0" width="20" height="20" class="matrix-cell"/>
<text x="70" y="10" class="cell-num">3</text>
<rect x="0" y="20" width="20" height="20" class="matrix-cell"/>
<text x="10" y="30" class="cell-num">4</text>
<rect x="20" y="20" width="20" height="20" class="matrix-cell"/>
<text x="30" y="30" class="cell-num">5</text>
<rect x="40" y="20" width="20" height="20" class="matrix-cell"/>
<text x="50" y="30" class="cell-num">6</text>
<rect x="60" y="20" width="20" height="20" class="matrix-cell"/>
<text x="70" y="30" class="cell-num">7</text>
<rect x="0" y="40" width="20" height="20" class="matrix-cell"/>
<text x="10" y="50" class="cell-num">8</text>
<rect x="20" y="40" width="20" height="20" class="matrix-cell"/>
<text x="30" y="50" class="cell-num">9</text>
<rect x="40" y="40" width="20" height="20" class="matrix-cell"/>
<text x="50" y="50" class="cell-num">10</text>
<rect x="60" y="40" width="20" height="20" class="matrix-cell"/>
<text x="70" y="50" class="cell-num">11</text>
</g>
<!-- Columnwise tensor (transposed) -->
<g id="hopper-colwise" transform="translate(200, 50)">
<text x="32" y="-8" class="label">columnwise</text>
<rect x="0" y="0" width="20" height="20" class="matrix-cell"/>
<text x="10" y="10" class="cell-num">0</text>
<rect x="20" y="0" width="20" height="20" class="matrix-cell"/>
<text x="30" y="10" class="cell-num">4</text>
<rect x="40" y="0" width="20" height="20" class="matrix-cell"/>
<text x="50" y="10" class="cell-num">8</text>
<rect x="0" y="20" width="20" height="20" class="matrix-cell"/>
<text x="10" y="30" class="cell-num">1</text>
<rect x="20" y="20" width="20" height="20" class="matrix-cell"/>
<text x="30" y="30" class="cell-num">5</text>
<rect x="40" y="20" width="20" height="20" class="matrix-cell"/>
<text x="50" y="30" class="cell-num">9</text>
<rect x="0" y="40" width="20" height="20" class="matrix-cell"/>
<text x="10" y="50" class="cell-num">2</text>
<rect x="20" y="40" width="20" height="20" class="matrix-cell"/>
<text x="30" y="50" class="cell-num">6</text>
<rect x="40" y="40" width="20" height="20" class="matrix-cell"/>
<text x="50" y="50" class="cell-num">10</text>
<rect x="0" y="60" width="20" height="20" class="matrix-cell"/>
<text x="10" y="70" class="cell-num">3</text>
<rect x="20" y="60" width="20" height="20" class="matrix-cell"/>
<text x="30" y="70" class="cell-num">7</text>
<rect x="40" y="60" width="20" height="20" class="matrix-cell"/>
<text x="50" y="70" class="cell-num">11</text>
</g>
<!-- Separator -->
<line x1="350" y1="20" x2="350" y2="200" stroke="#bdbdbd" stroke-width="1" stroke-dasharray="4,4"/>
<!-- BLACKWELL SECTION (right) -->
<text x="525" y="25" class="title">FP8 tensor on Blackwell</text>
<!-- Single tensor for both usages -->
<g id="blackwell-both" transform="translate(455, 50)">
<text x="70" y="-8" class="label">rowwise and columnwise</text>
<rect x="30" y="0" width="20" height="20" class="matrix-cell"/>
<text x="40" y="10" class="cell-num">0</text>
<rect x="50" y="0" width="20" height="20" class="matrix-cell"/>
<text x="60" y="10" class="cell-num">1</text>
<rect x="70" y="0" width="20" height="20" class="matrix-cell"/>
<text x="80" y="10" class="cell-num">2</text>
<rect x="90" y="0" width="20" height="20" class="matrix-cell"/>
<text x="100" y="10" class="cell-num">3</text>
<rect x="30" y="20" width="20" height="20" class="matrix-cell"/>
<text x="40" y="30" class="cell-num">4</text>
<rect x="50" y="20" width="20" height="20" class="matrix-cell"/>
<text x="60" y="30" class="cell-num">5</text>
<rect x="70" y="20" width="20" height="20" class="matrix-cell"/>
<text x="80" y="30" class="cell-num">6</text>
<rect x="90" y="20" width="20" height="20" class="matrix-cell"/>
<text x="100" y="30" class="cell-num">7</text>
<rect x="30" y="40" width="20" height="20" class="matrix-cell"/>
<text x="40" y="50" class="cell-num">8</text>
<rect x="50" y="40" width="20" height="20" class="matrix-cell"/>
<text x="60" y="50" class="cell-num">9</text>
<rect x="70" y="40" width="20" height="20" class="matrix-cell"/>
<text x="80" y="50" class="cell-num">10</text>
<rect x="90" y="40" width="20" height="20" class="matrix-cell"/>
<text x="100" y="50" class="cell-num">11</text>
</g>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 850 660" width="850" height="660">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
/* Arrows */
.arrow { stroke: #616161; stroke-width: 2; fill: none; marker-end: url(#arrowhead); }
/* Phase labels */
.phase-forward { fill: #1565c0; font-weight: 600; }
.phase-backward { fill: #c62828; font-weight: 600; }
/* All-gather operation */
.allgather {
fill: #e1f5fe;
stroke: #039be5;
stroke-width: 2;
}
</style>
<marker id="arrowhead" markerWidth="6" markerHeight="6" refX="5" refY="2" orient="auto">
<polygon points="0 0, 6 2, 0 4" fill="#616161" />
</marker>
</defs>
<!-- Main Title -->
<text x="425" y="30" class="title" style="font-size: 18px; font-weight: 700; fill: #1a1a1a;">All-Gather of Quantized Tensors (one scenario)</text>
<!-- Section 1: Input Tensor -->
<text x="425" y="70" class="section-title" style="fill: #212121; font-weight: 600;">Input Tensor quantized all-gather</text>
<!-- Forward phase label -->
<text x="20" y="128" class="text phase-forward" style="text-anchor: start;">FWD:</text>
<!-- High Precision Input -->
<rect x="90" y="100" width="130" height="55" class="hp" rx="6"/>
<text x="155" y="123" class="text">High Precision</text>
<text x="155" y="140" class="text">Tensor</text>
<!-- Arrow -->
<path d="M 220 127 L 250 127" class="arrow"/>
<!-- Quantize -->
<rect x="250" y="100" width="100" height="55" class="quantize" rx="6"/>
<text x="300" y="132" class="text">Quantize</text>
<!-- Arrow -->
<path d="M 350 127 L 380 127" class="arrow"/>
<!-- Rowwise Quantized Tensor -->
<rect x="380" y="100" width="130" height="55" class="fp8" rx="6"/>
<text x="445" y="123" class="text">Rowwise</text>
<text x="445" y="140" class="text">Quantized</text>
<!-- Arrow -->
<path d="M 510 127 L 540 127" class="arrow"/>
<!-- All-Gather -->
<rect x="540" y="100" width="110" height="55" class="allgather" rx="6"/>
<text x="595" y="132" class="text">All-Gather</text>
<!-- Arrow -->
<path d="M 650 127 L 680 127" class="arrow"/>
<!-- Continue indicator -->
<text x="705" y="132" class="text">...</text>
<!-- Backward phase label -->
<text x="20" y="228" class="text phase-backward" style="text-anchor: start;">BWD:</text>
<!-- Arrow from Quantize to Columnwise (showing separate quantization) -->
<path d="M 300 155 L 300 180 L 380 227" class="arrow"/>
<!-- Columnwise Quantized Tensor -->
<rect x="380" y="200" width="130" height="55" class="fp8" rx="6"/>
<text x="445" y="223" class="text">Columnwise</text>
<text x="445" y="240" class="text">Quantized</text>
<!-- Arrow -->
<path d="M 510 227 L 540 227" class="arrow"/>
<!-- All-Gather -->
<rect x="540" y="200" width="110" height="55" class="allgather" rx="6"/>
<text x="595" y="232" class="text">All-Gather</text>
<!-- Arrow -->
<path d="M 650 227 L 680 227" class="arrow"/>
<!-- Continue indicator -->
<text x="705" y="232" class="text">...</text>
<!-- Divider Line -->
<line x1="30" y1="305" x2="820" y2="305" stroke="#bdbdbd" stroke-width="3"/>
<!-- Section 2: Gradient Tensor -->
<text x="425" y="340" class="section-title" style="fill: #212121; font-weight: 600;">Gradient Tensor quantized all-gather</text>
<!-- Backward phase label -->
<text x="20" y="413" class="text phase-backward" style="text-anchor: start;">BWD:</text>
<!-- High Precision Gradient -->
<rect x="90" y="385" width="130" height="55" class="hp" rx="6"/>
<text x="155" y="408" class="text">High Precision</text>
<text x="155" y="425" class="text">Tensor</text>
<!-- Arrow -->
<path d="M 220 412 L 250 412" class="arrow"/>
<!-- Quantize -->
<rect x="250" y="385" width="100" height="55" class="quantize" rx="6"/>
<text x="300" y="417" class="text">Quantize</text>
<!-- Arrow to Columnwise -->
<path d="M 350 397 L 380 377" class="arrow"/>
<!-- Columnwise Quantized Tensor -->
<rect x="380" y="355" width="130" height="45" class="fp8" rx="6"/>
<text x="445" y="382" class="text">Col. Quantized</text>
<!-- Arrow to Rowwise -->
<path d="M 350 427 L 380 447" class="arrow"/>
<!-- Rowwise Quantized Tensor -->
<rect x="380" y="425" width="130" height="45" class="fp8" rx="6"/>
<text x="445" y="452" class="text">Row. Quantized</text>
<!-- Arrow from Columnwise to All-Gather -->
<path d="M 510 377 L 540 412" class="arrow"/>
<!-- Arrow from Rowwise to All-Gather -->
<path d="M 510 447 L 540 412" class="arrow"/>
<!-- All-Gather (single, shared) -->
<rect x="540" y="390" width="110" height="45" class="allgather" rx="6"/>
<text x="595" y="417" class="text">All-Gather</text>
<!-- Arrow -->
<path d="M 650 412 L 680 412" class="arrow"/>
<!-- Continue indicator -->
<text x="705" y="417" class="text">...</text>
<!-- Legend -->
<g transform="translate(100, 515)">
<rect x="0" y="0" width="80" height="40" rx="5" class="hp"/>
<text x="95" y="23" class="text" style="text-anchor: start;">High Precision (FP32/BF16/FP16)</text>
<rect x="380" y="0" width="80" height="40" rx="5" class="fp8"/>
<text x="475" y="23" class="text" style="text-anchor: start;">Lower Precision (FP8, etc.)</text>
<rect x="0" y="55" width="80" height="40" rx="5" class="quantize"/>
<text x="95" y="78" class="text" style="text-anchor: start;">Quantization</text>
<rect x="380" y="55" width="80" height="40" rx="5" class="allgather"/>
<text x="475" y="78" class="text" style="text-anchor: start;">All-Gather</text>
</g>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 850 750" width="850" height="750">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
/* Arrows */
.arrow { stroke: #616161; stroke-width: 2; fill: none; marker-end: url(#arrowhead); }
.arrow-dashed { stroke: #616161; stroke-width: 2; fill: none; stroke-dasharray: 5,5; marker-end: url(#arrowhead); }
/* Phase labels */
.phase-forward { fill: #1565c0; font-weight: 600; }
.phase-backward { fill: #c62828; font-weight: 600; }
</style>
<marker id="arrowhead" markerWidth="6" markerHeight="6" refX="5" refY="2" orient="auto">
<polygon points="0 0, 6 2, 0 4" fill="#616161" />
</marker>
</defs>
<!-- OPTION 1: Both usages in forward -->
<text x="425" y="35" class="section-title" style="fill: #212121; font-weight: 600;">Option 1: Quantize both usages in forward</text>
<!-- Forward phase label -->
<text x="20" y="113" class="text phase-forward" style="text-anchor: start;">FORWARD:</text>
<!-- High Precision Tensor -->
<rect x="130" y="85" width="130" height="55" class="hp" rx="6"/>
<text x="195" y="108" class="text">High Precision</text>
<text x="195" y="125" class="text">Tensor</text>
<!-- Arrow -->
<path d="M 260 112 L 290 112" class="arrow"/>
<!-- Quantize -->
<rect x="290" y="85" width="160" height="55" class="quantize" rx="6"/>
<text x="370" y="117" class="text">Quantize</text>
<!-- Arrow to FP8 -->
<path d="M 450 112 L 480 112" class="arrow"/>
<!-- Quantized Tensor (for forward) -->
<rect x="480" y="85" width="130" height="55" class="fp8" rx="6"/>
<text x="545" y="108" class="text">Quantized</text>
<text x="545" y="125" class="text">Rowwise</text>
<!-- Backward phase label -->
<text x="20" y="163" class="text phase-backward" style="text-anchor: start;">BACKWARD:</text>
<!-- Arrow to FP8 columnwise -->
<path d="M 450 130 L 480 152" class="arrow"/>
<!-- Quantized columnwise (saved for backward) -->
<rect x="480" y="150" width="130" height="55" class="fp8" rx="6"/>
<text x="545" y="173" class="text">Quantized</text>
<text x="545" y="190" class="text">Columnwise</text>
<!-- Divider Line -->
<line x1="30" y1="220" x2="820" y2="220" stroke="#bdbdbd" stroke-width="3"/>
<!-- OPTION 2: Separate approach -->
<text x="425" y="250" class="section-title" style="fill: #212121; font-weight: 600;">Option 2: Separate Quantizations (quantize when needed)</text>
<!-- Forward phase -->
<text x="20" y="293" class="text phase-forward" style="text-anchor: start;">FORWARD:</text>
<!-- High Precision Tensor (shared) -->
<rect x="130" y="265" width="130" height="55" class="hp" rx="6"/>
<text x="195" y="288" class="text">High Precision</text>
<text x="195" y="305" class="text">Tensor</text>
<!-- Arrow to Quantize -->
<path d="M 260 292 L 290 292" class="arrow"/>
<!-- Quantize only -->
<rect x="290" y="265" width="120" height="55" class="quantize" rx="6"/>
<text x="350" y="297" class="text">Quantize</text>
<!-- Arrow to FP8 -->
<path d="M 410 292 L 440 292" class="arrow"/>
<!-- Quantized Tensor -->
<rect x="440" y="265" width="130" height="55" class="fp8" rx="6"/>
<text x="505" y="288" class="text">Quantized</text>
<text x="505" y="305" class="text">Rowwise</text>
<!-- Dashed line showing it's the same tensor -->
<path d="M 195 320 L 195 360" class="arrow-dashed"/>
<!-- Backward phase -->
<text x="20" y="393" class="text phase-backward" style="text-anchor: start;">BACKWARD:</text>
<!-- High Precision Tensor (same, reused) -->
<rect x="130" y="360" width="130" height="55" class="hp" rx="6"/>
<text x="195" y="383" class="text">High Precision</text>
<text x="195" y="400" class="text">Tensor</text>
<!-- Arrow to Quantize -->
<path d="M 260 387 L 290 387" class="arrow"/>
<!-- Quantize (in backward) -->
<rect x="290" y="360" width="160" height="55" class="quantize" rx="6"/>
<text x="370" y="392" class="text">Quantize</text>
<!-- Arrow to FP8 columnwise -->
<path d="M 450 387 L 480 387" class="arrow"/>
<!-- Quantized columnwise -->
<rect x="480" y="360" width="130" height="55" class="fp8" rx="6"/>
<text x="545" y="383" class="text">Quantized</text>
<text x="545" y="400" class="text">Columnwise</text>
<!-- Divider Line -->
<line x1="30" y1="430" x2="820" y2="430" stroke="#bdbdbd" stroke-width="3"/>
<!-- OPTION 3: Convert from rowwise to columnwise -->
<text x="425" y="460" class="section-title" style="fill: #212121; font-weight: 600;">Option 3: Convert Rowwise to Columnwise in Backward (reuse saved tensor)</text>
<!-- Forward phase -->
<text x="20" y="503" class="text phase-forward" style="text-anchor: start;">FORWARD:</text>
<!-- High Precision Tensor -->
<rect x="130" y="475" width="130" height="55" class="hp" rx="6"/>
<text x="195" y="498" class="text">High Precision</text>
<text x="195" y="515" class="text">Tensor</text>
<!-- Arrow to Quantize -->
<path d="M 260 502 L 290 502" class="arrow"/>
<!-- Quantize only -->
<rect x="290" y="475" width="120" height="55" class="quantize" rx="6"/>
<text x="350" y="507" class="text">Quantize</text>
<!-- Arrow to FP8 -->
<path d="M 410 502 L 440 502" class="arrow"/>
<!-- Quantized Tensor (saved) -->
<rect x="440" y="475" width="130" height="55" class="fp8" rx="6"/>
<text x="505" y="498" class="text">Quantized</text>
<text x="505" y="515" class="text">Rowwise</text>
<!-- Dashed line showing tensor is saved -->
<path d="M 505 530 L 505 547 L 195 547 L 195 565" class="arrow-dashed"/>
<!-- Backward phase -->
<text x="20" y="598" class="text phase-backward" style="text-anchor: start;">BACKWARD:</text>
<!-- Quantized Tensor (reused) -->
<rect x="130" y="565" width="130" height="55" class="fp8" rx="6"/>
<text x="195" y="588" class="text">Quantized</text>
<text x="195" y="605" class="text">Rowwise</text>
<!-- Arrow to Make Columnwise -->
<path d="M 260 592 L 290 592" class="arrow"/>
<!-- Make Columnwise operation -->
<rect x="290" y="565" width="120" height="55" class="quantize" rx="6"/>
<text x="350" y="585" class="text">Make</text>
<text x="350" y="602" class="text">Columnwise</text>
<!-- Arrow to FP8 columnwise -->
<path d="M 410 592 L 440 592" class="arrow"/>
<!-- Quantized columnwise -->
<rect x="440" y="565" width="130" height="55" class="fp8" rx="6"/>
<text x="505" y="588" class="text">Quantized</text>
<text x="505" y="605" class="text">Columnwise</text>
<!-- Legend -->
<g transform="translate(150, 660)">
<rect x="0" y="0" width="70" height="35" rx="5" class="hp"/>
<text x="85" y="21" class="small-text" style="text-anchor: start;">High Precision (FP32/BF16/FP16)</text>
<rect x="380" y="0" width="70" height="35" rx="5" class="fp8"/>
<text x="465" y="21" class="small-text" style="text-anchor: start;">Lower Precision (FP8, etc.)</text>
<rect x="0" y="45" width="70" height="35" rx="5" class="quantize"/>
<text x="85" y="66" class="small-text" style="text-anchor: start;">Quantization / Make Columnwise</text>
</g>
</svg>
# START_MEMORY_USAGE_1
Tensors in memory:
Shape: (1024, 1024), Dtype: bfloat16, Size: 2048.0 KB
Shape: (1024, 1024), Dtype: bfloat16, Size: 2048.0 KB
Total from all live arrays: 4.00 MB
# END_MEMORY_USAGE_1
Processing events...
Generated:
No reports were generated
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+
print("# START_MEMORY_USAGE_1")
import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import DenseGeneral
key = jax.random.PRNGKey(0)
# Clear compilation caches so earlier runs do not contribute live buffers below.
jax.clear_caches()
# Initialize layer with BF16 parameters
layer = DenseGeneral(features=1024, dtype=jnp.bfloat16)
x = jax.random.normal(key, (1024, 1024), dtype=jnp.bfloat16)
var_collect = layer.init(key, x)
@jax.jit
def loss_fn(var_collect, x):
    # Scalar loss over the layer output; `layer` is captured from the enclosing scope.
    output = layer.apply(var_collect, x)
    return output.sum()
# Trace the backward pass - this allocates saved tensors
_, backward_fn = jax.vjp(loss_fn, var_collect, x)
# Drop the user-held reference to the input so that only arrays kept alive for
# the backward pass (and the parameters) are counted by jax.live_arrays().
del x
print("Tensors in memory:")
total_bytes = 0
for arr in jax.live_arrays():
    # Every live array contributes to the total; the size filter below only
    # suppresses printing of small tensors.
    total_bytes += arr.nbytes
    if arr.nbytes > 200000:  # only print large tensors; the total still counts all
        print(f"  Shape: {arr.shape}, Dtype: {arr.dtype}, Size: {arr.nbytes / 1024:.1f} KB")
print(f"  Total from all live arrays: {total_bytes / (1024**2):.2f} MB")
print("# END_MEMORY_USAGE_1")
# START_MEMORY_USAGE_1
Memory usage after forward pass: 6.00 MB
# END_MEMORY_USAGE_1
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+
cc = torch.cuda.get_device_capability()
assert cc[0] == 8 and cc[1] >= 9 or cc[0] == 9, "This example requires SM89 (Ada) or SM90 (Hopper)"
# Doc example: baseline memory footprint of a high precision (BF16) forward pass.
# The printed START/END markers delimit the program output captured into the docs.
print("# START_MEMORY_USAGE_1")
import torch
import transformer_engine.pytorch as te
def measure_memory() -> int:
    """Run one BF16 forward pass and return bytes allocated relative to the start."""
    # Begin from a clean allocator state so the delta reflects only this run.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    init_memory = torch.cuda.memory_allocated()
    layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
    inp = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
    out = layer(inp)
    del inp # Input is saved by model for backward, not by user script
    mem_after_forward = torch.cuda.memory_allocated() - init_memory
    return mem_after_forward
# Warmup run (first call may trigger one-time allocations, e.g. kernel workspaces)
measure_memory()
# Actual measurement
mem_after_forward = measure_memory()
print(f"Memory usage after forward pass: {mem_after_forward/1024**2:.2f} MB")
# END_MEMORY_USAGE_1
print("# END_MEMORY_USAGE_1")
# START_MEMORY_USAGE_2
Tensors in memory:
Shape: (1024, 1024), Dtype: float8_e4m3fn, Size: 1024.0 KB
Shape: (1024, 1024), Dtype: float8_e4m3fn, Size: 1024.0 KB
Shape: (1024, 1024), Dtype: bfloat16, Size: 2048.0 KB
Total from all live arrays: 4.02 MB
# END_MEMORY_USAGE_2
Processing events...
Generated:
No reports were generated
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+
# Doc example: memory footprint of FP8 training with BF16 model weights in JAX.
# The printed START/END markers delimit the program output captured into the docs.
print("# START_MEMORY_USAGE_2")
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
from transformer_engine.jax.flax import DenseGeneral
from transformer_engine.common.recipe import DelayedScaling
key = jax.random.PRNGKey(0)
recipe = DelayedScaling()
# Drop cached compiled executables so live-array accounting reflects this run only.
jax.clear_caches()
# Initialize layer with BF16 parameters (outside autocast)
layer = DenseGeneral(features=1024, dtype=jnp.bfloat16)
x = jax.random.normal(key, (1024, 1024), dtype=jnp.bfloat16)
# Forward and backward pass with FP8 compute
with te.autocast(enabled=True, recipe=recipe):
    var_collect = layer.init(key, x)
    @jax.jit
    def loss_fn(var_collect, x):
        # Scalar loss so jax.vjp produces a complete backward trace.
        output = layer.apply(var_collect, x)
        return output.sum()
    # Trace the backward pass - this allocates saved tensors
    _, backward_fn = jax.vjp(loss_fn, var_collect, x)
# Release the user-held input; whatever copy remains alive is what was
# saved for the backward pass.
del x
print("Tensors in memory:")
total_bytes = 0
for arr in jax.live_arrays():
    total_bytes += arr.nbytes
    if arr.nbytes > 200000: # do not count small tensors
        print(f" Shape: {arr.shape}, Dtype: {arr.dtype}, Size: {arr.nbytes / 1024:.1f} KB")
print(f" Total from all live arrays: {total_bytes / (1024**2):.2f} MB")
print("# END_MEMORY_USAGE_2")
# START_MEMORY_USAGE_2
Memory after forward pass: 6.02 MB
# END_MEMORY_USAGE_2
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+
cc = torch.cuda.get_device_capability()
assert cc[0] == 8 and cc[1] >= 9 or cc[0] == 9, "This example requires SM89 (Ada) or SM90 (Hopper)"
# Doc example: memory footprint of FP8 training with BF16 model weights.
# The printed START/END markers delimit the program output captured into the docs.
print("# START_MEMORY_USAGE_2")
import torch
import transformer_engine.pytorch as te
def measure_memory() -> int:
    """Run one FP8 forward pass and return bytes allocated relative to the start."""
    # Begin from a clean allocator state so the delta reflects only this run.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    init_memory = torch.cuda.memory_allocated()
    layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
    inp = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
    # FP8 compute: the layer quantizes input and weight before the GEMM.
    with te.autocast(enabled=True):
        out = layer(inp)
    del inp # Input is saved by model for backward, not by user script
    mem_after_forward = torch.cuda.memory_allocated() - init_memory
    return mem_after_forward
# Warmup run (first call may trigger one-time allocations, e.g. kernel workspaces)
measure_memory()
# Actual measurement
mem_after_forward = measure_memory()
print(f"Memory after forward pass: {mem_after_forward/1024**2:.2f} MB")
# END_MEMORY_USAGE_2
print("# END_MEMORY_USAGE_2")
# START_MEMORY_USAGE_3
Memory after forward pass: 3.02 MB
# END_MEMORY_USAGE_3
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+
cc = torch.cuda.get_device_capability()
assert cc[0] == 8 and cc[1] >= 9 or cc[0] == 9, "This example requires SM89 (Ada) or SM90 (Hopper)"
# Doc example: memory footprint of FP8 inference with weights stored directly
# in low precision (no high precision master copy).
# The printed START/END markers delimit the program output captured into the docs.
print("# START_MEMORY_USAGE_3")
import torch
import transformer_engine.pytorch as te
def measure_memory() -> int:
    """Run one FP8 inference forward pass and return bytes allocated relative to the start."""
    # Begin from a clean allocator state so the delta reflects only this run.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    init_memory = torch.cuda.memory_allocated()
    # FP8 inference with FP8 weights
    # quantized_model_init creates the parameters directly in low precision,
    # so no high precision weight copy is allocated.
    with te.quantized_model_init(enabled=True), torch.no_grad():
        layer_fp8 = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
    with torch.no_grad():
        inp = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
        with te.autocast(enabled=True):
            out = layer_fp8(inp)
    del inp # Input is not saved by model for backward in inference
    mem_after_forward = torch.cuda.memory_allocated() - init_memory
    return mem_after_forward
# Warmup run (first call may trigger one-time allocations, e.g. kernel workspaces)
measure_memory()
# Actual measurement
mem_after_forward = measure_memory()
print(f"Memory after forward pass: {mem_after_forward/1024**2:.2f} MB")
# END_MEMORY_USAGE_3
print("# END_MEMORY_USAGE_3")
..
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
Performance Considerations
===================================
.. _handling_transposes:
Handling transposes
-------------------
In the last chapter we demonstrated that for FP8 on Hopper architecture,
some tensors need to be physically transposed in memory to perform needed GEMMs.
Dealing with transposes in Transformer low precision training is a bit tricky.
Let's start by introducing the concept of *tensor usages*.
**Tensor usages**
Each quantized tensor may have two usages:
- *rowwise usage* -- used for matrix multiplication when the consecutive elements in a row are accessed,
- *columnwise usage* -- used for matrix multiplication when the consecutive elements in a column are accessed.
To understand what accessing consecutive elements means, let's consider two matrices ``A`` and ``B``
and analyze how their elements are accessed during multiplication.
For NN (non-transposed, non-transposed) multiplication ``C = A * B``, the formula is ``C_ij = sum_k(A_ik * B_kj)``.
To compute element ``C_ij``, we iterate over the i-th row of ``A`` (elements ``A_i0, A_i1, ...``)
and the j-th column of ``B`` (elements ``B_0j, B_1j, ...``). Thus, ``A`` is accessed rowwise
and ``B`` is accessed columnwise.
For NT (non-transposed, transposed) multiplication ``C = A * B^T``, the formula changes to ``C_ij = sum_k(A_ik * B_jk)``.
Now we iterate over the i-th row of ``A`` and the j-th row of ``B`` (elements ``B_j0, B_j1, ...``).
Both tensors are accessed rowwise.
The figure below illustrates these access patterns:
.. figure:: img/gemm_access_pattern.svg
:align: center
:width: 60%
:alt: Matrix multiplication access pattern showing rowwise access for first tensor and columnwise access for second tensor
Figure 1: Access patterns in matrix multiplication for matrices in ``A * B`` and ``A * B^T`` operations.
Based on the visualization above, we can derive general rules for when each matrix
is accessed in rowwise or columnwise fashion. The key insight is that:
- The **first tensor** in a matrix multiplication is accessed along its rows (rowwise) when non-transposed,
or along its columns (columnwise) when transposed.
- The **second tensor** follows the opposite pattern: columnwise when non-transposed, rowwise when transposed.
.. table:: Table 1: Summary of tensor access patterns based on transpose state.
:align: center
+------------------+--------------+---------------+
| | First tensor | Second tensor |
+------------------+--------------+---------------+
| Non-transposed | rowwise | columnwise |
+------------------+--------------+---------------+
| Transposed | columnwise | rowwise |
+------------------+--------------+---------------+
**Input, weight and output gradient usages**
Now let's apply these rules to a Linear layer. During training, a Linear layer performs
three GEMM operations: one in the forward pass and two in the backward pass.
.. table:: Table 2: Tensor access patterns for GEMM operations in a Linear layer during training.
:align: center
+-------------------+-------------------------------------+---------------------------+---------------------------+
| GEMM | Formula | First tensor usage | Second tensor usage |
+===================+=====================================+===========================+===========================+
| Forward | ``output = input * weight^T`` | input: rowwise | weight: rowwise |
+-------------------+-------------------------------------+---------------------------+---------------------------+
| Weight gradient | ``wgrad = output_grad^T * input`` | output_grad: columnwise | input: columnwise |
+-------------------+-------------------------------------+---------------------------+---------------------------+
| Input gradient | ``dgrad = output_grad * weight`` | output_grad: rowwise | weight: columnwise |
+-------------------+-------------------------------------+---------------------------+---------------------------+
An important observation is that the **forward pass uses only rowwise tensors** - both input
and weight are accessed rowwise.
The backward pass introduces columnwise access. For weight gradient, both output gradient and input
are accessed columnwise. For input gradient, output gradient is rowwise while weight is columnwise.
As a result, each tensor (input, weight, output gradient) needs both rowwise and columnwise
usages during training. This has implications for memory layout and transpose operations.
**Architecture differences**
The physical memory layout requirements for rowwise and columnwise usages differ between architectures
and recipes. For FP8 tensors:
- *Hopper*: cannot efficiently access elements in columnwise fashion, so columnwise tensors need to be physically transposed in memory. Note that higher precision formats (BF16/FP16) do not have this limitation.
- *Blackwell*: supports columnwise access natively, so no transpose is needed.
We will see that for most of the recipes and devices, rowwise usage and columnwise usage need different tensors.
Thus by *rowwise tensor* and *columnwise tensor* we mean tensors that are used in rowwise and columnwise usages respectively.
.. figure:: img/hopper_vs_blackwell_layout.svg
:align: center
:alt: Comparison of rowwise and columnwise tensor layouts on Blackwell vs Hopper
Figure 2: On Blackwell, rowwise and columnwise usages share the same memory layout.
On Hopper, columnwise usage requires a physical transpose.
**Quantization fusions**
This section is relevant only for recipes for which columnwise tensors
are different from rowwise tensors.
Note that performing rowwise and columnwise quantization at the same time
enables some fusions, which usually lead to better performance.
We showcase 3 example scenarios of producing quantized tensors in rowwise and columnwise usages.
TE will use the best possible fusion for a given recipe and TE module configuration:
1. *Computation of quantized tensor in both rowwise and columnwise usages in a single kernel in forward pass*.
This is the fastest one,
but since the columnwise usage is saved for backward pass, it may lead to increased memory usage,
if the high precision tensor also needs to be saved for backward - for example if it is the attention output which is saved anyway.
2. *Computation of quantized tensor in rowwise usage in forward pass and fused quantization to produce columnwise usage in backward pass*.
This is usually slower than the previous one, since high precision tensor needs to be read twice.
It is used for example when high precision tensor is gathered both in forward and in backward
and quantized tensor gather is not implemented for such recipe.
3. *Computation of quantized tensor in rowwise usage in forward pass and transpose to columnwise usage in backward pass*.
It is more memory efficient than Option 1, but not all recipes can utilize it (otherwise
the quantization accuracy would drop due to double quantization errors).
Transformer Engine chooses the best possible fusion internally taking the recipe and the operation into account.
.. raw:: html
:file: img/transpose_fusion.svg
*Figure 3: Three scenarios of producing quantized tensors in rowwise and columnwise usages.*
Memory usage
------------
This section discusses memory usage in low precision training.
Contrary to intuition, FP8 training does not always reduce memory compared to BF16/FP16.
*Master weights*
Transformer Engine by default stores weights in high precision and quantizes them to low precision before each GEMM.
Moreover, one can specify which high precision should be used to store the weights in the
model (FP32/BF16/FP16) -- or choose not to store high precision weights in the model at all.
There are multiple scenarios to consider; three of them are listed below:
1. model weights are in FP32, quantized to low precision before each GEMM;
2. model weights are in BF16/FP16, quantized to low precision before each GEMM, master weights in optimizer are in FP32;
3. model weights are stored directly in low precision, and master weights in optimizer are in FP32.
Note that each of these scenarios may have different memory footprint.
*Activations saved for backward*
Unlike weights, activations do not require a high precision copy for optimizer updates.
As shown in Table 2, the input needs rowwise usage in forward and columnwise usage
for weight gradient computation in backward — so it must be saved between passes.
The memory impact depends on which scenario from Figure 3 is used.
Additionally, on architectures where rowwise and columnwise usage tensors share the same memory layout
(e.g., FP8 on Blackwell, as shown in Figure 2), a single quantized tensor serves both usages,
reducing memory overhead compared to architectures requiring separate tensors.
Output gradients, on the other hand, are computed during backward and do not need to be saved —
both rowwise and columnwise usages are produced on the fly as needed.
The FP8 examples below are analyzed on Hopper (SM90) or Ada (SM89) architecture, where rowwise
and columnwise tensors require separate memory layouts.
.. tabs::
.. tab:: PyTorch
**1. Baseline: high precision forward pass**
Let's start with a forward pass in higher precision to establish a baseline.
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Needs to be run on SM89 (Ada) or SM90 (Hopper)
</div>
.. literalinclude:: memory_usage_1_pytorch.py
:language: python
:start-after: # START_MEMORY_USAGE_1
:end-before: # END_MEMORY_USAGE_1
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: memory_usage_1_pytorch.out
:language: text
:start-after: # START_MEMORY_USAGE_1
:end-before: # END_MEMORY_USAGE_1
Layer size is ``1024 * 1024 * 2 (2 bytes per parameter) = 2MB``.
Memory after forward pass is ``2 MB (weight) + 2 MB (input saved for backward) + 2 MB (output) = 6 MB``.
**2. FP8 training with model weights in BF16**
Now let's see the memory usage in FP8 training with high precision weights.
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Needs to be run on SM89 (Ada) or SM90 (Hopper)
</div>
.. literalinclude:: memory_usage_2_pytorch.py
:language: python
:start-after: # START_MEMORY_USAGE_2
:end-before: # END_MEMORY_USAGE_2
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: memory_usage_2_pytorch.out
:language: text
:start-after: # START_MEMORY_USAGE_2
:end-before: # END_MEMORY_USAGE_2
Total memory usage is ``2 MB (weight) + 1 MB (weight in FP8) + 1 MB (input in FP8 saved for backward) + 2 MB (output) = 6 MB``.
**3. FP8 inference with model weights stored directly in low precision**
For inference scenarios, model weights can be stored directly in low precision. Since we are only
performing forward passes without gradient updates, master weights in high precision are not needed.
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Needs to be run on SM89 (Ada) or SM90 (Hopper)
</div>
.. literalinclude:: memory_usage_3_pytorch.py
:language: python
:start-after: # START_MEMORY_USAGE_3
:end-before: # END_MEMORY_USAGE_3
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: memory_usage_3_pytorch.out
:language: text
:start-after: # START_MEMORY_USAGE_3
:end-before: # END_MEMORY_USAGE_3
Total memory usage is ``1 MB (weight in FP8) + 2 MB (output) = 3 MB``.
This is lower than the BF16 baseline (6 MB) since no copies are saved for backward in inference mode.
**4. Saving original input instead of quantized**
By default, TE saves the columnwise quantized input for the backward pass (needed for weight gradient).
However, when the high precision input is already being saved (e.g., for a residual connection),
keeping an additional quantized copy wastes memory.
The ``save_original_input=True`` option tells the layer to reference the original high precision input
instead of caching a separate quantized copy. The input is re-quantized during backward when needed.
Below is an example with a residual block where input is kept for the addition:
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Needs to be run on SM89 (Ada) or SM90 (Hopper)
</div>
.. literalinclude:: save_original_input_pytorch.py
:language: python
:start-after: # START_SAVE_ORIGINAL_INPUT
:end-before: # END_SAVE_ORIGINAL_INPUT
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: save_original_input_pytorch.out
:language: text
:start-after: # START_SAVE_ORIGINAL_INPUT
:end-before: # END_SAVE_ORIGINAL_INPUT
.. tab:: JAX
**1. Baseline: high precision forward pass**
Let's start with a forward pass in higher precision to establish a baseline.
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Needs to be run on SM89 (Ada) or SM90 (Hopper)
</div>
.. literalinclude:: memory_usage_1_jax.py
:language: python
:start-after: # START_MEMORY_USAGE_1
:end-before: # END_MEMORY_USAGE_1
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: memory_usage_1_jax.out
:language: text
:start-after: # START_MEMORY_USAGE_1
:end-before: # END_MEMORY_USAGE_1
Layer size is ``1024 * 1024 * 2 (2 bytes per parameter) = 2MB``.
Memory after forward pass is ``2 MB (weight) + 2 MB (input saved for backward) = 4 MB``.
**2. FP8 training with model weights in BF16**
Now let's see the memory usage in FP8 training with high precision weights.
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Needs to be run on SM89 (Ada) or SM90 (Hopper)
</div>
.. literalinclude:: memory_usage_2_jax.py
:language: python
:start-after: # START_MEMORY_USAGE_2
:end-before: # END_MEMORY_USAGE_2
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: memory_usage_2_jax.out
:language: text
:start-after: # START_MEMORY_USAGE_2
:end-before: # END_MEMORY_USAGE_2
Memory after forward pass is ``2 MB (weight in BF16) + 1 MB (input in FP8) + 1 MB (weight in FP8) = 4 MB``.
Fused layers
------------
Transformer Engine provides fused layers such as ``LayerNormLinear`` (``LayerNormDenseGeneral`` in JAX) and ``LayerNormMLP``
that enable kernel fusion optimizations. One key optimization is fusing layer normalization
with quantization.
In a typical Transformer architecture, LayerNorm precedes a Linear layer. Without fusion,
the LayerNorm outputs in high precision, and the Linear layer must then quantize this input before
performing the GEMM — adding overhead. With ``LayerNormLinear``, these operations are fused
into a single kernel: the LayerNorm output is quantized directly, eliminating the separate
quantization step and reducing memory movement.
.. raw:: html
:file: img/fused_layers.svg
*Figure 4: Comparison of separate LayerNorm and Linear layers versus fused LayerNormLinear layer, showing reduced quantization overhead.*
Let's see how we can use fused layers in different frameworks.
.. tabs::
.. tab:: PyTorch
In PyTorch, Transformer Engine provides fused layers like ``LayerNormLinear`` and ``LayerNormMLP``.
These layers combine normalization and linear operations with optimized quantization.
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Needs to be run on SM89+ (Ada, Hopper, Blackwell, or newer)
</div>
.. literalinclude:: fused_layers_pytorch.py
:language: python
:start-after: # START_FUSED_LAYERS
:end-before: # END_FUSED_LAYERS
The fused ``LayerNormLinear`` layer is particularly efficient in FP8 training because
it avoids an intermediate quantization step. The LayerNorm output is directly quantized
for the GEMM operation, reducing memory movement and improving performance.
.. tab:: JAX
In JAX, Transformer Engine provides fused layers like ``LayerNormDenseGeneral`` and ``LayerNormMLP``.
These layers combine normalization and dense operations with optimized quantization.
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Needs to be run on SM89+ (Ada, Hopper, Blackwell, or newer)
</div>
.. literalinclude:: fused_layers_jax.py
:language: python
:start-after: # START_FUSED_LAYERS
:end-before: # END_FUSED_LAYERS
The fused ``LayerNormDenseGeneral`` layer is particularly efficient in FP8 training because
it avoids an intermediate quantization step. The LayerNorm output is directly quantized
for the GEMM operation, reducing memory movement and improving performance.
Distributed training
--------------------
Transformer Engine handles collective operations internally, so users typically don't need to manage
the interaction between communication and low precision computation.
Recall that each Linear layer involves six tensors: weight, input, output, and their gradients.
Of these, output and gradients are returned in high precision, and weights are generally not
communicated (except in FSDP, which is outside the scope of this section). This leaves two
tensors where low precision communication matters: **input** and **output gradient**.
For sequence parallelism, TE supports all-gather of quantized tensors. This provides several benefits:
1. *Reduced memory usage* — no need to store high precision tensors for backward pass.
2. *Reduced communication* — smaller tensors mean less data to transfer.
3. *Parallelized quantization* — quantization work is distributed across GPUs.
Support varies by recipe — for example, columnwise quantized all-gather is not available
for all configurations.
The figure below illustrates one possible all-gather scenario for input and output gradient tensors.
Actual behavior depends on the recipe and module configuration.
.. raw:: html
:file: img/sequence_parallel_quantization.svg
*Figure 5: All-gather of quantized tensors for input and gradient tensors.
This is one possible scenario — actual behavior varies depending on the recipe and module configuration.*
# START_SAVE_ORIGINAL_INPUT
save_original_input=False: 25.0 MB
save_original_input=True: 24.0 MB
# END_SAVE_ORIGINAL_INPUT
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+
cc = torch.cuda.get_device_capability()
assert cc[0] == 8 and cc[1] >= 9 or cc[0] == 9, "This example requires SM89 (Ada) or SM90 (Hopper)"
# Doc example: effect of ``save_original_input`` on peak memory when the high
# precision input is already kept alive by a residual connection.
print("# START_SAVE_ORIGINAL_INPUT")
# START_SAVE_ORIGINAL_INPUT
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling
recipe = Float8CurrentScaling()
def residual_block(layer, inp):
    """Residual connection: input is saved for addition after linear."""
    out = layer(inp)
    return out + inp # inp must be kept for this addition
def measure_memory(use_save_original: bool) -> float:
    """Run a forward+backward residual block and return peak memory in MB.

    With ``save_original_input=True`` the layer references the original high
    precision input for backward instead of caching a quantized copy.
    """
    # Begin from a clean allocator state so the peak reflects only this run.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    layer = te.Linear(
        1024, 1024, params_dtype=torch.bfloat16, save_original_input=use_save_original
    )
    inp = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    with te.autocast(enabled=True, recipe=recipe):
        out = residual_block(layer, inp)
    # Scalar loss triggers the full backward pass (and its re-quantization, if any).
    out.sum().backward()
    return torch.cuda.max_memory_allocated() / 1024**2
# Warmup runs
measure_memory(False)
measure_memory(True)
# Actual measurements
for use_save_original in [False, True]:
    peak = measure_memory(use_save_original)
    print(f"save_original_input={use_save_original}: {peak:.1f} MB")
# END_SAVE_ORIGINAL_INPUT
print("# END_SAVE_ORIGINAL_INPUT")
...@@ -39,6 +39,14 @@ Transformer Engine documentation ...@@ -39,6 +39,14 @@ Transformer Engine documentation
api/common api/common
api/framework api/framework
.. toctree::
:hidden:
:caption: Features
features/low_precision_training/index.rst
.. toctree:: .. toctree::
:hidden: :hidden:
:caption: Examples and Tutorials :caption: Examples and Tutorials
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment