Commit 9df0c4a3 authored by wenjh's avatar wenjh
Browse files

Merge branch 'nv_main'

parents 0d874a4e f122b07d
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 850 480" width="850" height="480" style="display: block; margin: 0 auto;">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.arrow { stroke: #616161; stroke-width: 1.5; fill: none; marker-end: url(#arrowhead); }
.section-label { font-family: 'Segoe UI', Arial, sans-serif; font-size: 16px; font-weight: 600; fill: #424242; text-anchor: start; }
</style>
<marker id="arrowhead" markerWidth="3" markerHeight="3" refX="3" refY="1.5" orient="auto">
<polygon points="0 0, 3 1.5, 0 3" fill="#616161" />
</marker>
</defs>
<!-- Title -->
<text x="425" y="30" class="title" style="text-anchor: middle;">Transformer Layer – default precision of operation in low precision recipe</text>
<!-- Row 1: Input → Layer Norm → QKV Linear → QK^T → Softmax -->
<rect x="20" y="60" width="115" height="50" rx="5" class="hp"/>
<text x="77" y="90" class="text">Input</text>
<path d="M 135 85 L 158 85" class="arrow"/>
<rect x="158" y="60" width="115" height="50" rx="5" class="hp"/>
<text x="215" y="90" class="text">Layer Norm</text>
<path d="M 273 85 L 296 85" class="arrow"/>
<rect x="296" y="60" width="115" height="50" rx="5" class="gemm"/>
<text x="353" y="90" class="text">QKV Linear</text>
<path d="M 411 85 L 434 85" class="arrow"/>
<rect x="434" y="60" width="115" height="50" rx="5" class="hp"/>
<text x="491" y="90" class="text">QK^T</text>
<path d="M 549 85 L 572 85" class="arrow"/>
<rect x="572" y="60" width="115" height="50" rx="5" class="hp"/>
<text x="629" y="90" class="text">Softmax</text>
<!-- Row 2: Attn * V → Output Linear → Dropout + Add -->
<path d="M 629 110 L 629 145" class="arrow"/>
<rect x="572" y="145" width="115" height="50" rx="5" class="hp"/>
<text x="629" y="175" class="text">Scores * V</text>
<path d="M 572 170 L 549 170" class="arrow"/>
<rect x="434" y="145" width="115" height="50" rx="5" class="gemm"/>
<text x="491" y="175" class="text">Output Linear</text>
<path d="M 434 170 L 273 170" class="arrow"/>
<rect x="158" y="145" width="115" height="50" rx="5" class="hp"/>
<text x="215" y="175" class="text">Dropout + Add</text>
<!-- Row 3: Layer Norm → FFN Linear 1 → GELU → FFN Linear 2 → Output -->
<path d="M 215 195 L 215 230" class="arrow"/>
<rect x="158" y="230" width="115" height="50" rx="5" class="hp"/>
<text x="215" y="260" class="text">Layer Norm</text>
<path d="M 273 255 L 296 255" class="arrow"/>
<rect x="296" y="230" width="115" height="50" rx="5" class="gemm"/>
<text x="353" y="260" class="text">FFN Linear 1</text>
<path d="M 411 255 L 434 255" class="arrow"/>
<rect x="434" y="230" width="115" height="50" rx="5" class="hp"/>
<text x="491" y="260" class="text">GELU</text>
<path d="M 549 255 L 572 255" class="arrow"/>
<rect x="572" y="230" width="115" height="50" rx="5" class="gemm"/>
<text x="629" y="260" class="text">FFN Linear 2</text>
<path d="M 687 255 L 710 255" class="arrow"/>
<rect x="710" y="230" width="115" height="50" rx="5" class="hp"/>
<text x="767" y="260" class="text">Output</text>
<!-- Memory State Section -->
<text x="20" y="325" class="section-label">Memory State:</text>
<!-- Parameters -->
<rect x="20" y="340" width="180" height="45" rx="5" class="hp"/>
<text x="110" y="365" class="text">Parameters</text>
<!-- Gradients -->
<rect x="225" y="340" width="140" height="45" rx="5" class="hp"/>
<text x="295" y="365" class="text">Gradients</text>
<!-- Legend -->
<g transform="translate(20, 415)">
<!-- High Precision -->
<rect x="0" y="0" width="80" height="40" rx="5" class="hp"/>
<text x="95" y="23" class="text" style="text-anchor: start;">Higher Precision (FP32/BF16/FP16)</text>
<!-- Low Precision -->
<rect x="400" y="0" width="80" height="40" rx="5" class="gemm"/>
<text x="495" y="23" class="text" style="text-anchor: start;">Lower Precision (FP8, MXFP8 etc.)</text>
</g>
</svg>
..
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
Introduction
===================================
Transformer Engine accelerates deep learning on NVIDIA GPUs in several ways,
with low precision training being one of the most important.
This chapter introduces mixed precision training and FP8 support.
Training in BF16/FP16
---------------------
Deep learning traditionally uses 32-bit floating-point (FP32) numbers.
NVIDIA GPUs support lower precision formats—FP16 since Pascal, BF16 since Ampere—which offer higher throughput and lower memory usage.
Let's compare these formats.
.. raw:: html
:file: img/fp_formats_comparison.svg
*Figure 1: Comparison of FP32, BF16, and FP16 floating-point formats showing bit allocation for sign, exponent, and mantissa.*
The key differences between these formats are:
* **FP32** (32 bits total): 1 sign bit + 8 exponent bits + 23 mantissa bits – standard single-precision format
* **BF16** (16 bits total): 1 sign bit + 8 exponent bits + 7 mantissa bits – maintains FP32's exponent range but has reduced precision
* **FP16** (16 bits total): 1 sign bit + 5 exponent bits + 10 mantissa bits – reduced range but higher precision than BF16
BF16's advantage is that it shares the same exponent range as FP32,
making it easier to convert between the two formats without overflow/underflow issues.
FP16 offers better precision for smaller values but has a limited dynamic range,
which results in the need to perform loss scaling to avoid overflow/underflow—see `this paper on loss scaling <https://arxiv.org/pdf/1710.03740>`__ for more details.
**Mixed precision**
Not all operations should be run in reduced precision to preserve accuracy.
Modern deep learning frameworks use *mixed precision training*,
where different operations use different precisions based on their numerical properties:
* Matrix multiplications are compute-heavy and remain numerically stable at lower precision, making them ideal candidates for acceleration.
* Operations like layer normalization and softmax can work with low precision inputs and outputs, but may use high precision internally or for their weights.
* Operations like loss computation and exponentiation need high precision throughout.
**Master weights**
Another consideration in mixed precision training is how to store the model weights.
Lower precision formats like FP16 and BF16 have limited representational granularity,
which becomes problematic during gradient updates.
When a small gradient is added to a not so small weight stored in low precision,
the result may round back to the original value if the update falls below the format's precision threshold.
Moreover, some elements of the gradient itself can be too small to be represented in low precision,
especially after the accumulation from multiple GPUs in the data parallel training setting.
The solution is to maintain *master weights* in FP32.
During training, weights are cast to lower precision for forward and backward passes,
but the gradient updates are applied to the full-precision master copy.
This ensures that even small gradients accumulate correctly over time.
There are two common software approaches to storing master weights:
* *In the optimizer*:
The model holds low-precision weights,
while the optimizer maintains FP32 copies alongside momentum and other state.
During each step,
the optimizer updates its FP32 copy and casts the result back to the model's low-precision weights.
This approach makes it easier to shard master weights together with other optimizer state, for example in ZeRO optimizer.
Since the casting happens only during the optimizer step, this approach is also faster when optimizer runs less frequently than the model, e.g. when performing gradient accumulation or pipeline parallel training.
* *In the model*:
The model stores weights directly in FP32,
and they are cast to lower precision on-the-fly during forward and backward passes.
This approach works seamlessly with any standard optimizer, requiring no special support.
.. raw:: html
:file: img/master_weights_approaches.svg
*Figure 2: Three approaches to weight storage—low precision only (no master weights), master weights stored in the model, and master weights stored in the optimizer.*
.. tabs::
.. tab:: PyTorch
The PyTorch API of Transformer Engine provides several mechanisms to control precision:
* **Weight precision**: Use the ``params_dtype`` argument in any TE layer constructor.
* **Computation precision**: Use the ``torch.autocast`` context manager. When enabled, inputs are cast to the autocast dtype before computation.
* **Input dtype**: When ``torch.autocast`` is not used, the input tensor's dtype determines the computation precision. In this case, inputs and parameters must have matching dtypes.
.. literalinclude:: bf16_fp16_training_pytorch.py
:language: python
:start-after: # START_BF16_FP16_TRAINING
:end-before: # END_BF16_FP16_TRAINING
.. tab:: JAX
The JAX API of Transformer Engine provides two mechanisms to control precision:
* **Weight precision**: Use the ``dtype`` argument in any TE layer constructor.
* **Computation precision**: Determined by the dtype of the input tensor.
For training with master weights in FP32 and computation in BF16,
cast the input tensor to BF16 before passing it to the layer.
.. literalinclude:: bf16_fp16_training_jax.py
:language: python
:start-after: # START_BF16_FP16_TRAINING
:end-before: # END_BF16_FP16_TRAINING
Lower precisions
----------------
Transformer Engine's primary feature is supporting even lower precision than BF16/FP16, such as FP8, MXFP8, NVFP4, etc.
The logic of these precisions is more complicated than the logic of BF16/FP16 – they require scaling factors to
properly represent the full range of values in the tensor. Sometimes it is one scaling factor per tensor,
sometimes it is one scaling factor per block of values. A precision format combined with the logic for training
is called **a recipe**.
In this section we present common logic for all the recipes. Each one of them is described in more detail in a separate section later.
Let's now see how we can train in lower precisions in supported frameworks.
.. tabs::
.. tab:: PyTorch
The PyTorch API of Transformer Engine provides an ``autocast`` context manager to control precision.
It's similar to the ``torch.autocast`` context manager, but tailored for low precision training.
The most important argument is the ``recipe`` argument, which accepts objects inheriting from
:class:`~transformer_engine.common.recipe.Recipe`.
Forward computations need to be performed inside the ``autocast`` context manager,
while the ``.backward()`` call should be outside of it (it inherits the setting from the
corresponding forward pass).
Here is a basic example:
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Needs to be run on SM89+ (Ada or newer)
</div>
.. literalinclude:: autocast_pytorch.py
:language: python
:start-after: # START_AUTOCAST_BASIC
:end-before: # END_AUTOCAST_BASIC
You can use multiple recipes in the same model in the following ways:
**Sequential contexts** – apply different recipes to different parts of your model:
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Needs to be run on SM89+ (Ada or newer)
</div>
.. literalinclude:: autocast_pytorch.py
:language: python
:start-after: # START_AUTOCAST_SEQUENTIAL
:end-before: # END_AUTOCAST_SEQUENTIAL
**Nested contexts** – the inner context overrides the outer one for its scope:
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Needs to be run on SM89+ (Ada or newer)
</div>
.. literalinclude:: autocast_pytorch.py
:language: python
:start-after: # START_AUTOCAST_NESTED
:end-before: # END_AUTOCAST_NESTED
.. tab:: JAX
The JAX API of Transformer Engine provides an ``autocast`` context manager similar to PyTorch.
The key difference is that in JAX, model initialization must happen inside the ``autocast`` context
to properly capture quantization metadata in the parameter tree.
Here is a basic example:
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Needs to be run on SM89+ (Ada or newer)
</div>
.. literalinclude:: autocast_jax.py
:language: python
:start-after: # START_AUTOCAST_BASIC
:end-before: # END_AUTOCAST_BASIC
You can use multiple recipes in the same model in the following ways:
**Sequential contexts** – apply different recipes to different parts of your model:
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Needs to be run on SM89+ (Ada or newer)
</div>
.. literalinclude:: autocast_jax.py
:language: python
:start-after: # START_AUTOCAST_SEQUENTIAL
:end-before: # END_AUTOCAST_SEQUENTIAL
**Nested contexts** – the inner context overrides the outer one for its scope:
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Needs to be run on SM89+ (Ada or newer)
</div>
.. literalinclude:: autocast_jax.py
:language: python
:start-after: # START_AUTOCAST_NESTED
:end-before: # END_AUTOCAST_NESTED
.. note::
Python context managers like ``autocast`` may interact unexpectedly with JAX's JIT compilation.
For finer-grained control, consider passing the recipe directly to TE modules instead.
See the `TE JAX Integration notebook <https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/te_jax_integration.ipynb>`_
for details.
**Mixed precision with 8- or 4-bit precisions**
From now on, we will refer to FP8/MXFP8/NVFP4 etc. as *low precision*
and to FP32/BF16/FP16 as *high precision*. This terminology will be
used throughout the rest of the documentation.
Not all operations run in low precision:
- **Linear operations**: run in low precision.
- **Attention computations**: run in high precision by default (some recipes allow low precision as an option).
- **Other operations** (layer normalization, softmax, etc.): run in high precision.
Within high-precision operations, there are two categories:
- **Configurable precision**: most operations run in parameter precision (FP32/BF16/FP16) or the precision specified by ``torch.autocast``.
- **Fixed FP32 precision**: some operations, or parts of operations—such as the division in layernorm—always run in FP32, regardless of other settings.
.. raw:: html
:file: img/mixed_precision_operations.svg
*Figure 3: Default precision of operations in a TransformerLayer forward pass. Only linear operations are in lower precision. Dot product attention is shown as three separate operations (QK^T, Softmax, Scores * V), though in practice these may be fused into a single kernel.*
**Linear layer data flow**
Let's see how data flow of a linear layer works by default on a single H100 GPU with FP8 precision:
H100 (Hopper) architecture natively supports FP8 Matrix Multiplication only in **TN** layout (Transpose-NoTranspose),
so GEMM with tensors ``A`` and ``B`` returns ``B * A^T``.
*Forward pass*
* Input is quantized to FP8 – both ``input`` and ``input^T`` quantized versions are created.
* Weights are stored in high precision and quantized to low precision before the GEMM – both ``weight`` and ``weight^T`` quantized versions are created.
* FP8 GEMM with layout **TN** is run with ``weight`` and ``input`` tensors,
* Outputs – ``input * weight^T`` tensor – are returned in high precision.
*Backward pass*
* Output gradients are quantized to FP8 – both ``output_grad`` and ``output_grad^T`` quantized versions are created.
* FP8 GEMM with layout **TN** is performed with ``weight^T`` and ``output_grad`` tensors to compute input gradients.
* FP8 GEMM with layout **TN** is performed with ``input^T`` and ``output_grad^T`` tensors to compute weight gradients.
* Input gradients – ``output_grad * weight`` tensor – are returned in high precision.
* Weight gradients – ``output_grad^T * input`` tensor – are returned in high precision.
.. raw:: html
:file: img/fp8_linear_flow.svg
*Figure 4: Forward pass of a Linear layer with low precision data flow.*
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 60 500 380" width="100%" style="max-width: 500px;">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.title { font: bold 18px sans-serif; fill: #333; text-anchor: middle; }
.label { font: 14px sans-serif; fill: #333; text-anchor: middle; }
.small-text { font: 12px sans-serif; fill: #555; }
.dots-text { font: bold 24px sans-serif; fill: #333; text-anchor: middle; }
/* Tensor colors */
.fp8-tensor { fill: #87CEEB; stroke: #444; stroke-width: 2; }
.fp8-block { fill: #87CEEB; stroke: #555; stroke-width: 1.5; }
.fp8-block-alt { fill: #5F9FCC; stroke: #555; stroke-width: 1.5; }
/* Scaling factor colors */
.scale-factor { fill: #FFA500; stroke: #444; stroke-width: 2; }
.grid-line { stroke: #444; stroke-width: 2; }
.boundary-line { stroke: #444; stroke-width: 2; }
</style>
</defs>
<!-- MXFP8 Scaling -->
<g id="mxfp8-scaling">
<text x="250" y="85" class="title">MXFP8</text>
<text x="250" y="108" class="label" style="font-size: 12px;">(One scaling factor per 32 elements)</text>
<!-- FP8 Tensor split into many small blocks (40×10) -->
<g id="tensor-blocks">
<!-- White backgrounds for dots areas - cross pattern -->
<rect x="250.0" y="140.0" width="40" height="120" fill="#FFFFFF" stroke="none"/>
<rect x="130.0" y="190.0" width="240" height="30" fill="#FFFFFF" stroke="none"/>
<!-- Blocks ONLY where they don't overlap with white cross -->
<rect x="130" y="140" width="40" height="10" class="fp8-block"/>
<rect x="170" y="140" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="140" width="40" height="10" class="fp8-block"/>
<rect x="290" y="140" width="40" height="10" class="fp8-block"/>
<rect x="330" y="140" width="40" height="10" class="fp8-block-alt"/>
<rect x="130" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="170" y="150" width="40" height="10" class="fp8-block"/>
<rect x="210" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="290" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="330" y="150" width="40" height="10" class="fp8-block"/>
<rect x="130" y="160" width="40" height="10" class="fp8-block"/>
<rect x="170" y="160" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="160" width="40" height="10" class="fp8-block"/>
<rect x="290" y="160" width="40" height="10" class="fp8-block"/>
<rect x="330" y="160" width="40" height="10" class="fp8-block-alt"/>
<rect x="130" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="170" y="170" width="40" height="10" class="fp8-block"/>
<rect x="210" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="290" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="330" y="170" width="40" height="10" class="fp8-block"/>
<rect x="130" y="180" width="40" height="10" class="fp8-block"/>
<rect x="170" y="180" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="180" width="40" height="10" class="fp8-block"/>
<rect x="290" y="180" width="40" height="10" class="fp8-block"/>
<rect x="330" y="180" width="40" height="10" class="fp8-block-alt"/>
<!-- Three dots in VERTICAL white bar -->
<text x="270" y="167.5" class="dots-text"></text>
<text x="270" y="242.5" class="dots-text"></text>
<!-- Three dots in HORIZONTAL white bar -->
<text x="190" y="205" class="dots-text"></text>
<text x="330" y="205" class="dots-text"></text>
<!-- ONE diagonal dot at intersection -->
<text x="270" y="205" class="dots-text" transform="rotate(45 270 205)"></text>
<!-- Bottom rows (y >= 220 after horizontal white bar) -->
<rect x="130" y="220" width="40" height="10" class="fp8-block"/>
<rect x="170" y="220" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="220" width="40" height="10" class="fp8-block"/>
<rect x="290" y="220" width="40" height="10" class="fp8-block"/>
<rect x="330" y="220" width="40" height="10" class="fp8-block-alt"/>
<rect x="130" y="230" width="40" height="10" class="fp8-block-alt"/>
<rect x="170" y="230" width="40" height="10" class="fp8-block"/>
<rect x="210" y="230" width="40" height="10" class="fp8-block-alt"/>
<rect x="290" y="230" width="40" height="10" class="fp8-block-alt"/>
<rect x="330" y="230" width="40" height="10" class="fp8-block"/>
<rect x="130" y="240" width="40" height="10" class="fp8-block"/>
<rect x="170" y="240" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="240" width="40" height="10" class="fp8-block"/>
<rect x="290" y="240" width="40" height="10" class="fp8-block"/>
<rect x="330" y="240" width="40" height="10" class="fp8-block-alt"/>
<rect x="130" y="250" width="40" height="10" class="fp8-block-alt"/>
<rect x="170" y="250" width="40" height="10" class="fp8-block"/>
<rect x="210" y="250" width="40" height="10" class="fp8-block-alt"/>
<rect x="290" y="250" width="40" height="10" class="fp8-block-alt"/>
<rect x="330" y="250" width="40" height="10" class="fp8-block"/>
<!-- Main outline -->
<rect x="130.0" y="140.0" width="240" height="120" fill="none" stroke="#444" stroke-width="2"/>
</g>
<!-- Scaling factors tensor - 3+2 columns of 10px squares -->
<g id="scale-factors">
<!-- Orange background -->
<rect x="215" y="285" width="70" height="120" fill="#FFA500"/>
<!-- White backgrounds for dots areas - cross pattern -->
<rect x="245" y="285" width="20" height="120" fill="#FFFFFF" stroke="none"/>
<rect x="215" y="335" width="70" height="30" fill="#FFFFFF" stroke="none"/>
<!-- Grid lines showing 10x10 squares (3 left + 2 right columns) -->
<!-- Vertical lines every 10px (skipping white space) -->
<!-- Left 3 columns -->
<line x1="225" y1="285" x2="225" y2="335" class="grid-line" stroke-width="1"/>
<line x1="235" y1="285" x2="235" y2="335" class="grid-line" stroke-width="1"/>
<line x1="245" y1="285" x2="245" y2="335" class="grid-line" stroke-width="1"/>
<!-- Right 2 columns -->
<line x1="265" y1="285" x2="265" y2="335" class="grid-line" stroke-width="1"/>
<line x1="275" y1="285" x2="275" y2="335" class="grid-line" stroke-width="1"/>
<line x1="285" y1="285" x2="285" y2="335" class="grid-line" stroke-width="1"/>
<!-- Bottom sections -->
<line x1="225" y1="365" x2="225" y2="405" class="grid-line" stroke-width="1"/>
<line x1="235" y1="365" x2="235" y2="405" class="grid-line" stroke-width="1"/>
<line x1="245" y1="365" x2="245" y2="405" class="grid-line" stroke-width="1"/>
<line x1="265" y1="365" x2="265" y2="405" class="grid-line" stroke-width="1"/>
<line x1="275" y1="365" x2="275" y2="405" class="grid-line" stroke-width="1"/>
<line x1="285" y1="365" x2="285" y2="405" class="grid-line" stroke-width="1"/>
<!-- Horizontal lines every 10px -->
<line x1="215" y1="295" x2="245" y2="295" class="grid-line" stroke-width="1"/>
<line x1="265" y1="295" x2="285" y2="295" class="grid-line" stroke-width="1"/>
<line x1="215" y1="305" x2="245" y2="305" class="grid-line" stroke-width="1"/>
<line x1="265" y1="305" x2="285" y2="305" class="grid-line" stroke-width="1"/>
<line x1="215" y1="315" x2="245" y2="315" class="grid-line" stroke-width="1"/>
<line x1="265" y1="315" x2="285" y2="315" class="grid-line" stroke-width="1"/>
<line x1="215" y1="325" x2="245" y2="325" class="grid-line" stroke-width="1"/>
<line x1="265" y1="325" x2="285" y2="325" class="grid-line" stroke-width="1"/>
<!-- Top bottom boundaries -->
<line x1="215" y1="335" x2="245" y2="335" class="grid-line" stroke-width="1"/>
<line x1="265" y1="335" x2="285" y2="335" class="grid-line" stroke-width="1"/>
<line x1="215" y1="365" x2="245" y2="365" class="grid-line" stroke-width="1"/>
<line x1="265" y1="365" x2="285" y2="365" class="grid-line" stroke-width="1"/>
<line x1="215" y1="375" x2="245" y2="375" class="grid-line" stroke-width="1"/>
<line x1="265" y1="375" x2="285" y2="375" class="grid-line" stroke-width="1"/>
<line x1="215" y1="385" x2="245" y2="385" class="grid-line" stroke-width="1"/>
<line x1="265" y1="385" x2="285" y2="385" class="grid-line" stroke-width="1"/>
<line x1="215" y1="395" x2="245" y2="395" class="grid-line" stroke-width="1"/>
<line x1="265" y1="395" x2="285" y2="395" class="grid-line" stroke-width="1"/>
<!-- Bottom boundaries -->
<line x1="215" y1="405" x2="245" y2="405" class="grid-line" stroke-width="1"/>
<line x1="265" y1="405" x2="285" y2="405" class="grid-line" stroke-width="1"/>
<!-- Main outline -->
<rect x="215" y="285" width="70" height="120" fill="none" stroke="#444" stroke-width="2"/>
<!-- Three dots -->
<text x="255" y="312.5" class="dots-text" style="font-size: 14px;"></text>
<text x="255" y="387.5" class="dots-text" style="font-size: 14px;"></text>
<text x="230" y="350" class="dots-text" style="font-size: 14px;"></text>
<text x="275" y="350" class="dots-text" style="font-size: 14px;"></text>
<text x="255" y="350" class="dots-text" style="font-size: 14px;" transform="rotate(45 255 350)"></text>
</g>
<text x="250" y="430" class="small-text" text-anchor="middle">E8M0 scaling factors (one per 32 elements)</text>
</g>
</svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 650 450" width="100%" style="max-width: 650px;">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.title { font: bold 18px sans-serif; fill: #333; text-anchor: middle; }
.label { font: 14px sans-serif; fill: #333; text-anchor: middle; }
.small-text { font: 12px sans-serif; fill: #555; text-anchor: middle; }
.dots-text { font: bold 24px sans-serif; fill: #333; text-anchor: middle; }
.fp8-block { fill: #87CEEB; stroke: #555; stroke-width: 1.5; }
.fp8-block-alt { fill: #5F9FCC; stroke: #555; stroke-width: 1.5; }
.scale-block { fill: #FFA500; stroke: #555; stroke-width: 1.5; }
</style>
</defs>
<!-- ROWWISE SECTION -->
<text x="325" y="30" class="title">Rowwise (1x32 blocks)</text>
<!-- Rowwise Data Tensor -->
<g id="rowwise-tensor">
<text x="160" y="55" class="small-text">Data</text>
<rect x="40" y="70" width="40" height="10" class="fp8-block"/>
<rect x="80" y="70" width="40" height="10" class="fp8-block-alt"/>
<rect x="120" y="70" width="40" height="10" class="fp8-block"/>
<rect x="200" y="70" width="40" height="10" class="fp8-block-alt"/>
<rect x="240" y="70" width="40" height="10" class="fp8-block"/>
<rect x="40" y="80" width="40" height="10" class="fp8-block-alt"/>
<rect x="80" y="80" width="40" height="10" class="fp8-block"/>
<rect x="120" y="80" width="40" height="10" class="fp8-block-alt"/>
<rect x="200" y="80" width="40" height="10" class="fp8-block"/>
<rect x="240" y="80" width="40" height="10" class="fp8-block-alt"/>
<rect x="40" y="90" width="40" height="10" class="fp8-block"/>
<rect x="80" y="90" width="40" height="10" class="fp8-block-alt"/>
<rect x="120" y="90" width="40" height="10" class="fp8-block"/>
<rect x="200" y="90" width="40" height="10" class="fp8-block-alt"/>
<rect x="240" y="90" width="40" height="10" class="fp8-block"/>
<rect x="40" y="100" width="40" height="10" class="fp8-block-alt"/>
<rect x="80" y="100" width="40" height="10" class="fp8-block"/>
<rect x="120" y="100" width="40" height="10" class="fp8-block-alt"/>
<rect x="200" y="100" width="40" height="10" class="fp8-block"/>
<rect x="240" y="100" width="40" height="10" class="fp8-block-alt"/>
<rect x="40" y="110" width="40" height="10" class="fp8-block"/>
<rect x="80" y="110" width="40" height="10" class="fp8-block-alt"/>
<rect x="120" y="110" width="40" height="10" class="fp8-block"/>
<rect x="200" y="110" width="40" height="10" class="fp8-block-alt"/>
<rect x="240" y="110" width="40" height="10" class="fp8-block"/>
<rect x="40" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="80" y="150" width="40" height="10" class="fp8-block"/>
<rect x="120" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="200" y="150" width="40" height="10" class="fp8-block"/>
<rect x="240" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="40" y="160" width="40" height="10" class="fp8-block"/>
<rect x="80" y="160" width="40" height="10" class="fp8-block-alt"/>
<rect x="120" y="160" width="40" height="10" class="fp8-block"/>
<rect x="200" y="160" width="40" height="10" class="fp8-block-alt"/>
<rect x="240" y="160" width="40" height="10" class="fp8-block"/>
<rect x="40" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="80" y="170" width="40" height="10" class="fp8-block"/>
<rect x="120" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="200" y="170" width="40" height="10" class="fp8-block"/>
<rect x="240" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="40" y="180" width="40" height="10" class="fp8-block"/>
<rect x="80" y="180" width="40" height="10" class="fp8-block-alt"/>
<rect x="120" y="180" width="40" height="10" class="fp8-block"/>
<rect x="200" y="180" width="40" height="10" class="fp8-block-alt"/>
<rect x="240" y="180" width="40" height="10" class="fp8-block"/>
<rect x="40" y="70" width="240" height="120" fill="none" stroke="#444" stroke-width="2"/>
<text x="180" y="97.5" class="dots-text"></text>
<text x="180" y="172.5" class="dots-text"></text>
<text x="100" y="135" class="dots-text"></text>
<text x="240" y="135" class="dots-text"></text>
<text x="180" y="135" class="dots-text" transform="rotate(45 180 135)"></text>
</g>
<!-- Rowwise Scale Tensor -->
<g id="rowwise-scales">
<text x="485" y="55" class="small-text">Scales</text>
<!-- Rows 1-5 -->
<rect x="380" y="70" width="10" height="10" class="scale-block"/>
<rect x="390" y="70" width="10" height="10" class="scale-block"/>
<rect x="400" y="70" width="10" height="10" class="scale-block"/>
<rect x="450" y="70" width="10" height="10" class="scale-block"/>
<rect x="460" y="70" width="10" height="10" class="scale-block"/>
<rect x="380" y="80" width="10" height="10" class="scale-block"/>
<rect x="390" y="80" width="10" height="10" class="scale-block"/>
<rect x="400" y="80" width="10" height="10" class="scale-block"/>
<rect x="450" y="80" width="10" height="10" class="scale-block"/>
<rect x="460" y="80" width="10" height="10" class="scale-block"/>
<rect x="380" y="90" width="10" height="10" class="scale-block"/>
<rect x="390" y="90" width="10" height="10" class="scale-block"/>
<rect x="400" y="90" width="10" height="10" class="scale-block"/>
<rect x="450" y="90" width="10" height="10" class="scale-block"/>
<rect x="460" y="90" width="10" height="10" class="scale-block"/>
<rect x="380" y="100" width="10" height="10" class="scale-block"/>
<rect x="390" y="100" width="10" height="10" class="scale-block"/>
<rect x="400" y="100" width="10" height="10" class="scale-block"/>
<rect x="450" y="100" width="10" height="10" class="scale-block"/>
<rect x="460" y="100" width="10" height="10" class="scale-block"/>
<rect x="380" y="110" width="10" height="10" class="scale-block"/>
<rect x="390" y="110" width="10" height="10" class="scale-block"/>
<rect x="400" y="110" width="10" height="10" class="scale-block"/>
<rect x="450" y="110" width="10" height="10" class="scale-block"/>
<rect x="460" y="110" width="10" height="10" class="scale-block"/>
<!-- Gap rows -->
<rect x="380" y="150" width="10" height="10" class="scale-block"/>
<rect x="390" y="150" width="10" height="10" class="scale-block"/>
<rect x="400" y="150" width="10" height="10" class="scale-block"/>
<rect x="450" y="150" width="10" height="10" class="scale-block"/>
<rect x="460" y="150" width="10" height="10" class="scale-block"/>
<rect x="380" y="160" width="10" height="10" class="scale-block"/>
<rect x="390" y="160" width="10" height="10" class="scale-block"/>
<rect x="400" y="160" width="10" height="10" class="scale-block"/>
<rect x="450" y="160" width="10" height="10" class="scale-block"/>
<rect x="460" y="160" width="10" height="10" class="scale-block"/>
<rect x="380" y="170" width="10" height="10" class="scale-block"/>
<rect x="390" y="170" width="10" height="10" class="scale-block"/>
<rect x="400" y="170" width="10" height="10" class="scale-block"/>
<rect x="450" y="170" width="10" height="10" class="scale-block"/>
<rect x="460" y="170" width="10" height="10" class="scale-block"/>
<rect x="380" y="180" width="10" height="10" class="scale-block"/>
<rect x="390" y="180" width="10" height="10" class="scale-block"/>
<rect x="400" y="180" width="10" height="10" class="scale-block"/>
<rect x="450" y="180" width="10" height="10" class="scale-block"/>
<rect x="460" y="180" width="10" height="10" class="scale-block"/>
<rect x="380" y="70" width="90" height="120" fill="none" stroke="#444" stroke-width="2"/>
<text x="430" y="97.5" class="dots-text"></text>
<text x="430" y="172.5" class="dots-text"></text>
<text x="400" y="135" class="dots-text"></text>
<text x="460" y="135" class="dots-text"></text>
<text x="430" y="135" class="dots-text" transform="rotate(45 430 135)"></text>
</g>
<!-- COLUMNWISE SECTION -->
<text x="325" y="230" class="title">Columnwise (32x1 blocks)</text>
<!-- Columnwise Data Tensor -->
<g id="colwise-tensor">
<text x="160" y="255" class="small-text">Data</text>
<rect x="40" y="270" width="10" height="40" class="fp8-block"/>
<rect x="50" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="60" y="270" width="10" height="40" class="fp8-block"/>
<rect x="70" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="80" y="270" width="10" height="40" class="fp8-block"/>
<rect x="90" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="100" y="270" width="10" height="40" class="fp8-block"/>
<rect x="110" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="120" y="270" width="10" height="40" class="fp8-block"/>
<rect x="130" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="180" y="270" width="10" height="40" class="fp8-block"/>
<rect x="190" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="200" y="270" width="10" height="40" class="fp8-block"/>
<rect x="210" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="220" y="270" width="10" height="40" class="fp8-block"/>
<rect x="230" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="240" y="270" width="10" height="40" class="fp8-block"/>
<rect x="250" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="260" y="270" width="10" height="40" class="fp8-block"/>
<rect x="270" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="40" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="50" y="350" width="10" height="40" class="fp8-block"/>
<rect x="60" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="70" y="350" width="10" height="40" class="fp8-block"/>
<rect x="80" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="90" y="350" width="10" height="40" class="fp8-block"/>
<rect x="100" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="110" y="350" width="10" height="40" class="fp8-block"/>
<rect x="120" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="130" y="350" width="10" height="40" class="fp8-block"/>
<rect x="180" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="190" y="350" width="10" height="40" class="fp8-block"/>
<rect x="200" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="210" y="350" width="10" height="40" class="fp8-block"/>
<rect x="220" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="230" y="350" width="10" height="40" class="fp8-block"/>
<rect x="240" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="250" y="350" width="10" height="40" class="fp8-block"/>
<rect x="260" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="270" y="350" width="10" height="40" class="fp8-block"/>
<rect x="40" y="270" width="240" height="120" fill="none" stroke="#444" stroke-width="2"/>
<text x="160" y="296" class="dots-text"></text>
<text x="160" y="376" class="dots-text"></text>
<text x="90" y="336" class="dots-text" transform="rotate(90 90 336)"></text>
<text x="230" y="336" class="dots-text" transform="rotate(90 230 336)"></text>
<text x="160" y="336" class="dots-text" transform="rotate(45 160 336)"></text>
</g>
<!-- Columnwise Scale Tensor - TRANSPOSED -->
<g id="colwise-scales">
<text x="485" y="255" class="small-text">Scales</text>
<!-- Row 1 -->
<rect x="370" y="300" width="10" height="10" class="scale-block"/>
<rect x="380" y="300" width="10" height="10" class="scale-block"/>
<rect x="390" y="300" width="10" height="10" class="scale-block"/>
<rect x="400" y="300" width="10" height="10" class="scale-block"/>
<rect x="410" y="300" width="10" height="10" class="scale-block"/>
<rect x="420" y="300" width="10" height="10" class="scale-block"/>
<rect x="430" y="300" width="10" height="10" class="scale-block"/>
<rect x="440" y="300" width="10" height="10" class="scale-block"/>
<rect x="450" y="300" width="10" height="10" class="scale-block"/>
<rect x="460" y="300" width="10" height="10" class="scale-block"/>
<rect x="510" y="300" width="10" height="10" class="scale-block"/>
<rect x="520" y="300" width="10" height="10" class="scale-block"/>
<rect x="530" y="300" width="10" height="10" class="scale-block"/>
<rect x="540" y="300" width="10" height="10" class="scale-block"/>
<rect x="550" y="300" width="10" height="10" class="scale-block"/>
<rect x="560" y="300" width="10" height="10" class="scale-block"/>
<rect x="570" y="300" width="10" height="10" class="scale-block"/>
<rect x="580" y="300" width="10" height="10" class="scale-block"/>
<rect x="590" y="300" width="10" height="10" class="scale-block"/>
<rect x="600" y="300" width="10" height="10" class="scale-block"/>
<!-- Row 2 (gap) -->
<rect x="370" y="330" width="10" height="10" class="scale-block"/>
<rect x="380" y="330" width="10" height="10" class="scale-block"/>
<rect x="390" y="330" width="10" height="10" class="scale-block"/>
<rect x="400" y="330" width="10" height="10" class="scale-block"/>
<rect x="410" y="330" width="10" height="10" class="scale-block"/>
<rect x="420" y="330" width="10" height="10" class="scale-block"/>
<rect x="430" y="330" width="10" height="10" class="scale-block"/>
<rect x="440" y="330" width="10" height="10" class="scale-block"/>
<rect x="450" y="330" width="10" height="10" class="scale-block"/>
<rect x="460" y="330" width="10" height="10" class="scale-block"/>
<rect x="510" y="330" width="10" height="10" class="scale-block"/>
<rect x="520" y="330" width="10" height="10" class="scale-block"/>
<rect x="530" y="330" width="10" height="10" class="scale-block"/>
<rect x="540" y="330" width="10" height="10" class="scale-block"/>
<rect x="550" y="330" width="10" height="10" class="scale-block"/>
<rect x="560" y="330" width="10" height="10" class="scale-block"/>
<rect x="570" y="330" width="10" height="10" class="scale-block"/>
<rect x="580" y="330" width="10" height="10" class="scale-block"/>
<rect x="590" y="330" width="10" height="10" class="scale-block"/>
<rect x="600" y="330" width="10" height="10" class="scale-block"/>
<rect x="370" y="300" width="240" height="40" fill="none" stroke="#444" stroke-width="2"/>
<text x="490" y="320" class="dots-text"></text>
<text x="430" y="320" class="dots-text"></text>
<text x="560" y="320" class="dots-text"></text>
</g>
</svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 -15 900 445" width="100%" style="max-width: 900px;">
<style>
.scale-fill { fill: #FFA500; stroke: #444; stroke-width: 2; }
.scale-fill-nostroke { fill: #FFA500; stroke: none; }
.grid-line { stroke: #444; stroke-width: 2; fill: none; }
.dots-text { font: bold 24px sans-serif; fill: #333; text-anchor: middle; }
.label { font: 12px sans-serif; fill: #333; text-anchor: middle; }
.num-text { font: bold 12px sans-serif; fill: #333; text-anchor: middle; dominant-baseline: middle; }
.arrow-line { stroke: #444; stroke-width: 2; fill: none; }
.outer-border { fill: none; stroke: #444; stroke-width: 1.5; }
.inner-line { stroke: #444; stroke-width: 1; stroke-dasharray: 3,2; }
.num-text-small { font: bold 11px sans-serif; fill: #333; text-anchor: middle; dominant-baseline: middle; }
.dots-text-small { font: bold 14px sans-serif; fill: #333; text-anchor: middle; dominant-baseline: middle; }
</style>
<defs>
<marker id="arrowhead" markerWidth="10" markerHeight="10" refX="8" refY="5" orient="auto">
<polygon points="0 0, 10 5, 0 10" fill="#444" />
</marker>
</defs>
<!-- ======== PART 1: Linearization (from mxfp8_scale_linearize.svg) ======== -->
<g id="linearization">
<!-- Left: Scaling factors grid -->
<!-- Main rectangle with white background -->
<rect x="40" y="40" width="120" height="240" fill="#FFFFFF" stroke="#444" stroke-width="2"/>
<!-- Orange grid area (upper-left) -->
<rect x="40" y="40" width="60" height="180" fill="#FFA500" stroke="#444" stroke-width="2"/>
<!-- Grid lines for narrow blocks (3 columns x 3 rows) -->
<line x1="60" y1="40" x2="60" y2="220" class="grid-line"/>
<line x1="80" y1="40" x2="80" y2="220" class="grid-line"/>
<line x1="40" y1="100" x2="100" y2="100" class="grid-line"/>
<line x1="40" y1="160" x2="100" y2="160" class="grid-line"/>
<!-- Numbers in orange cells -->
<!-- Row 1 -->
<text x="50" y="70" class="num-text">1</text>
<text x="70" y="70" class="num-text">2</text>
<text x="90" y="70" class="num-text">3</text>
<!-- Row 2 -->
<text x="50" y="125" class="num-text">K</text>
<text x="50" y="137" class="num-text">+</text>
<text x="50" y="149" class="num-text">1</text>
<text x="70" y="125" class="num-text">K</text>
<text x="70" y="137" class="num-text">+</text>
<text x="70" y="149" class="num-text">2</text>
<text x="90" y="125" class="num-text">K</text>
<text x="90" y="137" class="num-text">+</text>
<text x="90" y="149" class="num-text">3</text>
<!-- Row 3 -->
<text x="50" y="180" class="num-text">2K</text>
<text x="50" y="192" class="num-text">+</text>
<text x="50" y="204" class="num-text">1</text>
<text x="70" y="180" class="num-text">2K</text>
<text x="70" y="192" class="num-text">+</text>
<text x="70" y="204" class="num-text">1</text>
<text x="90" y="180" class="num-text">2K</text>
<text x="90" y="192" class="num-text">+</text>
<text x="90" y="204" class="num-text">3</text>
<!-- Dots in white area (right side) -->
<text x="125" y="90" class="dots-text"></text>
<text x="125" y="150" class="dots-text"></text>
<text x="125" y="210" class="dots-text"></text>
<!-- Dots in white area (bottom) -->
<text x="60" y="260" class="dots-text"></text>
<text x="90" y="260" class="dots-text"></text>
<text x="125" y="260" class="dots-text"></text>
<!-- Arrow pointing to first block with label 128x4 -->
<text x="50" y="0" class="label" text-anchor="middle">128x4</text>
<path d="M 50 5 L 50 38" class="arrow-line" marker-end="url(#arrowhead)"/>
<!-- Arrow -->
<path d="M 200 150 L 300 150" class="arrow-line" marker-end="url(#arrowhead)"/>
<!-- Right: Linearized 1D array -->
<!-- Main rectangle with white background -->
<rect x="340" y="140" width="520" height="20" fill="#FFFFFF" stroke="#444" stroke-width="2"/>
<!-- Orange cells -->
<rect x="340" y="140" width="70" height="20" fill="#FFA500" stroke="#444" stroke-width="2"/>
<text x="375" y="150" class="num-text">1</text>
<rect x="410" y="140" width="70" height="20" fill="#FFA500" stroke="#444" stroke-width="2"/>
<text x="445" y="150" class="num-text">2</text>
<rect x="480" y="140" width="70" height="20" fill="#FFA500" stroke="#444" stroke-width="2"/>
<text x="515" y="150" class="dots-text"></text>
<rect x="550" y="140" width="70" height="20" fill="#FFA500" stroke="#444" stroke-width="2"/>
<text x="585" y="150" class="num-text">K + 1</text>
<rect x="620" y="140" width="70" height="20" fill="#FFA500" stroke="#444" stroke-width="2"/>
<text x="655" y="150" class="num-text">K + 2</text>
<!-- White area with dots -->
<text x="725" y="150" class="dots-text"></text>
<!-- Arrow pointing to first linearized block with label 1x512 -->
<text x="375" y="75" class="label" text-anchor="middle">1x512</text>
<path d="M 375 80 L 375 138" class="arrow-line" marker-end="url(#arrowhead)"/>
</g>
<!-- ======== Connection: Arrow from "1" to bottom left block with brace ======== -->
<g id="connection">
<!-- Label above brace -->
<text x="250" y="335" class="label">128 4-bit elements</text>
<!-- Curly brace on top of the bottom left block -->
<path d="M 115 355
Q 115 345, 125 345
L 245 345
Q 250 345, 250 340
Q 250 345, 255 345
L 375 345
Q 385 345, 385 355"
fill="none" stroke="#444" stroke-width="2"/>
<!-- Arrow from "1" cell down to the center of the brace -->
<path d="M 375 175 Q 375 260, 250 315" class="arrow-line" marker-end="url(#arrowhead)"/>
</g>
<!-- ======== PART 2: Swizzling (from mxfp8_swizzle_indices.svg) ======== -->
<!-- Offset by 330 (300 + 30px gap) -->
<g id="swizzling" transform="translate(100, 330)">
<!-- Left: Sequential indices -->
<g id="sequential">
<!-- Background -->
<rect x="15" y="35" width="270" height="30" class="scale-fill-nostroke"/>
<rect x="15" y="35" width="270" height="30" class="outer-border"/>
<!-- Dashed internal lines -->
<line x1="45" y1="35" x2="45" y2="65" class="inner-line"/>
<line x1="75" y1="35" x2="75" y2="65" class="inner-line"/>
<line x1="105" y1="35" x2="105" y2="65" class="inner-line"/>
<line x1="135" y1="35" x2="135" y2="65" class="inner-line"/>
<line x1="165" y1="35" x2="165" y2="65" class="inner-line"/>
<line x1="195" y1="35" x2="195" y2="65" class="inner-line"/>
<line x1="225" y1="35" x2="225" y2="65" class="inner-line"/>
<line x1="255" y1="35" x2="255" y2="65" class="inner-line"/>
<!-- Numbers -->
<text x="30" y="50" class="num-text-small">0</text>
<text x="60" y="50" class="num-text-small">1</text>
<text x="90" y="50" class="num-text-small">2</text>
<text x="120" y="50" class="num-text-small">3</text>
<text x="150" y="50" class="num-text-small">4</text>
<text x="180" y="50" class="num-text-small">5</text>
<text x="210" y="50" class="num-text-small">6</text>
<text x="240" y="50" class="num-text-small">7</text>
<text x="270" y="50" class="dots-text-small">...</text>
</g>
<!-- Arrow -->
<path d="M 300 50 L 340 50" class="arrow-line" marker-end="url(#arrowhead)"/>
<!-- Right: Swizzled indices -->
<g id="swizzled">
<!-- Background -->
<rect x="360" y="35" width="270" height="30" class="scale-fill-nostroke"/>
<rect x="360" y="35" width="270" height="30" class="outer-border"/>
<!-- Dashed internal lines -->
<line x1="390" y1="35" x2="390" y2="65" class="inner-line"/>
<line x1="420" y1="35" x2="420" y2="65" class="inner-line"/>
<line x1="450" y1="35" x2="450" y2="65" class="inner-line"/>
<line x1="480" y1="35" x2="480" y2="65" class="inner-line"/>
<line x1="510" y1="35" x2="510" y2="65" class="inner-line"/>
<line x1="540" y1="35" x2="540" y2="65" class="inner-line"/>
<line x1="570" y1="35" x2="570" y2="65" class="inner-line"/>
<line x1="600" y1="35" x2="600" y2="65" class="inner-line"/>
<!-- Numbers -->
<text x="375" y="50" class="num-text-small">0</text>
<text x="405" y="50" class="num-text-small">32</text>
<text x="435" y="50" class="num-text-small">64</text>
<text x="465" y="50" class="num-text-small">96</text>
<text x="495" y="50" class="num-text-small">1</text>
<text x="525" y="50" class="num-text-small">33</text>
<text x="555" y="50" class="num-text-small">65</text>
<text x="585" y="50" class="num-text-small">97</text>
<text x="615" y="50" class="dots-text-small">...</text>
</g>
</g>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1230 220" width="100%" style="max-width: 900px;">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
/* Diagram-specific styles */
.input-box { fill: #f3e5f5; stroke: #7b1fa2; stroke-width: 2.5; }
.mxfp8-box { fill: #e3f2fd; stroke: #1976d2; stroke-width: 2.5; }
.fp8-tile { fill: #bbdefb; stroke: #1565c0; stroke-width: 1.5; }
.scale-tile { fill: #a5d6a7; stroke: #388e3c; stroke-width: 1.5; }
.scale-swizzled { fill: #ffb74d; stroke: #e65100; stroke-width: 1.5; }
.swizzle-box { fill: #fff3e0; stroke: #f57c00; stroke-width: 2; }
.quantize-box { fill: #ede7f6; stroke: #5e35b1; stroke-width: 2; }
.comm-box { fill: #fff9c4; stroke: #f57f17; stroke-width: 2; }
.gemm-box { fill: #c8e6c9; stroke: #388e3c; stroke-width: 2; }
/* Arrow override */
.arrow { marker-end: url(#arrowhead); }
</style>
<!-- Arrow marker -->
<marker id="arrowhead" markerWidth="6" markerHeight="6" refX="5" refY="2" orient="auto">
<polygon points="0 0, 6 2, 0 4" fill="#616161" />
</marker>
</defs>
<!-- MXFP8 Complete Flow -->
<g id="complete-flow">
<!-- Step 0: Input Tensor -->
<g id="input-fp32-tensor">
<text x="80" y="25" class="text" text-anchor="middle" font-weight="600">Input Tensor</text>
<rect x="20" y="40" width="120" height="150" rx="6" class="input-box"/>
<text x="80" y="120" class="text" text-anchor="middle" fill="#fff" font-weight="600">FP32/BF16</text>
</g>
<!-- Arrow 0 -->
<path d="M 140 115 L 180 115" class="arrow"/>
<!-- Step 1: Quantize -->
<rect x="180" y="75" width="90" height="80" rx="6" class="quantize-box"/>
<text x="225" y="120" class="text" font-weight="600">Quantize</text>
<!-- Arrow 1 -->
<path d="M 270 115 L 310 115" class="arrow"/>
<!-- Step 2: MXFP8 Tensor with sub-tiles stacked vertically -->
<g id="mxfp8-tensor">
<text x="410" y="25" class="text" text-anchor="middle" font-weight="600">MXFP8 Tensor</text>
<rect x="310" y="40" width="200" height="150" rx="6" class="mxfp8-box"/>
<!-- Scales sub-tile (green) - on top -->
<rect x="330" y="55" width="160" height="40" rx="3" class="scale-tile"/>
<text x="410" y="80" class="text" text-anchor="middle" fill="#fff" font-weight="600">Scales</text>
<!-- FP8 Data sub-tile - on bottom -->
<rect x="330" y="105" width="160" height="70" rx="3" class="fp8-tile"/>
<text x="410" y="145" class="text" fill="#fff" font-weight="600">FP8 Data</text>
</g>
<!-- Arrow 2 -->
<path d="M 510 115 L 560 115" class="arrow"/>
<!-- Step 3: Communication -->
<rect x="560" y="75" width="130" height="80" rx="6" class="comm-box"/>
<text x="625" y="110" class="text" font-weight="600">Communication</text>
<text x="625" y="125" class="text" font-size="12">(All-Gather)</text>
<text x="625" y="140" class="text" font-size="12" font-style="italic">(Optional)</text>
<!-- Arrow 3 -->
<path d="M 690 115 L 740 115" class="arrow"/>
<!-- Step 4: Swizzle -->
<rect x="740" y="75" width="110" height="80" rx="6" class="swizzle-box"/>
<text x="795" y="120" class="text" font-weight="600">Swizzle</text>
<!-- Arrow 4 -->
<path d="M 850 115 L 900 115" class="arrow"/>
<!-- Step 5: MXFP8 Tensor with swizzled scales -->
<g id="swizzled-tensor">
<text x="980" y="25" class="text" text-anchor="middle" font-weight="600">MXFP8 Tensor</text>
<rect x="900" y="40" width="160" height="150" rx="6" class="mxfp8-box"/>
<!-- Swizzled Scales sub-tile (orange) - on top -->
<rect x="915" y="55" width="130" height="40" rx="3" class="scale-swizzled"/>
<text x="980" y="80" class="text" text-anchor="middle" fill="#fff" font-weight="600">Swizzle Scales</text>
<!-- FP8 Data sub-tile (unchanged) - on bottom -->
<rect x="915" y="105" width="130" height="70" rx="3" class="fp8-tile"/>
<text x="980" y="145" class="text" fill="#fff" font-weight="600">FP8 Data</text>
</g>
<!-- Arrow 5 -->
<path d="M 1060 115 L 1110 115" class="arrow"/>
<!-- Step 6: GEMM -->
<rect x="1110" y="75" width="110" height="80" rx="6" class="gemm-box"/>
<text x="1165" y="120" class="text" font-weight="600">GEMM</text>
</g>
</svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 700 300" width="100%" style="max-width: 700px;">
<style>
.tensor-fill { fill: #87CEEB; stroke: #444; stroke-width: 2; }
.scale-fill { fill: #FFA500; stroke: #444; stroke-width: 2; }
.grid-line { stroke: #444; stroke-width: 2; fill: none; }
.dots-text { font: bold 24px sans-serif; fill: #333; text-anchor: middle; }
.label { font: 12px sans-serif; fill: #333; text-anchor: middle; }
</style>
<!-- Left tensor (128x128 blocks) - FP8 tensor -->
<!-- Main rectangle with white background -->
<rect x="60" y="40" width="260" height="240" fill="#FFFFFF" stroke="#444" stroke-width="2"/>
<!-- Blue grid area (upper-left) -->
<rect x="60" y="40" width="180" height="180" fill="#87CEEB" stroke="#444" stroke-width="2"/>
<!-- Grid lines for 3x3 blocks -->
<line x1="120" y1="40" x2="120" y2="220" class="grid-line"/>
<line x1="180" y1="40" x2="180" y2="220" class="grid-line"/>
<line x1="60" y1="100" x2="240" y2="100" class="grid-line"/>
<line x1="60" y1="160" x2="240" y2="160" class="grid-line"/>
<!-- Dots in white area (right side) -->
<text x="280" y="90" class="dots-text"></text>
<text x="280" y="150" class="dots-text"></text>
<text x="280" y="210" class="dots-text"></text>
<!-- Dots in white area (bottom) -->
<text x="90" y="260" class="dots-text"></text>
<text x="150" y="260" class="dots-text"></text>
<text x="210" y="260" class="dots-text"></text>
<text x="280" y="260" class="dots-text"></text>
<!-- Label -->
<text x="190" y="20" class="label">FP8 Tensor (128×128 blocks)</text>
<!-- Right tensor (128x4 blocks) - Scaling factors (orange) -->
<!-- Main rectangle with white background -->
<rect x="480" y="40" width="120" height="240" fill="#FFFFFF" stroke="#444" stroke-width="2"/>
<!-- Orange grid area (upper-left) -->
<rect x="480" y="40" width="60" height="180" fill="#FFA500" stroke="#444" stroke-width="2"/>
<!-- Grid lines for narrow blocks (3 columns x 3 rows) -->
<line x1="500" y1="40" x2="500" y2="220" class="grid-line"/>
<line x1="520" y1="40" x2="520" y2="220" class="grid-line"/>
<line x1="480" y1="100" x2="540" y2="100" class="grid-line"/>
<line x1="480" y1="160" x2="540" y2="160" class="grid-line"/>
<!-- Dots in white area (right side) -->
<text x="565" y="90" class="dots-text"></text>
<text x="565" y="150" class="dots-text"></text>
<text x="565" y="210" class="dots-text"></text>
<!-- Dots in white area (bottom) -->
<text x="500" y="260" class="dots-text"></text>
<text x="530" y="260" class="dots-text"></text>
<text x="565" y="260" class="dots-text"></text>
<!-- Label -->
<text x="540" y="20" class="label">Scaling Factors (128×4 blocks)</text>
</svg>
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Check for Blackwell or newer GPU
from transformer_engine.jax.quantize import get_device_compute_capability
assert (
get_device_compute_capability() >= 100
), f"MXFP8 requires SM100 (Blackwell) or later, got SM{get_device_compute_capability()}"
# START_MXFP8_EXAMPLE
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
from transformer_engine.jax.flax import DenseGeneral
from transformer_engine.common.recipe import MXFP8BlockScaling, Format
# Create MXFP8 recipe
recipe = MXFP8BlockScaling(
fp8_format=Format.E4M3, # FP8 format (default: E4M3, E5M2 not supported)
)
with te.autocast(enabled=True, recipe=recipe):
# Initialize layer and data
layer = DenseGeneral(features=1024)
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16)
var_collect = layer.init(key, x)
# Forward and backward pass
def loss_fn(var_collect):
output = layer.apply(var_collect, x)
return output.sum()
loss, grads = jax.value_and_grad(loss_fn)(var_collect)
# END_MXFP8_EXAMPLE
This diff is collapsed.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
# Check for Blackwell or newer GPU
major, minor = torch.cuda.get_device_capability()
assert major >= 10, f"MXFP8 requires SM100 (Blackwell) or later, got SM{major}{minor}"
# START_MXFP8_EXAMPLE
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import MXFP8BlockScaling, Format
# Create MXFP8 recipe
recipe = MXFP8BlockScaling(
fp8_format=Format.E4M3, # E4M3 (default) or HYBRID; pure E5M2 not supported
)
# Create a linear layer with bfloat16 parameters
layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
# Forward and backward pass
inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda")
with te.autocast(enabled=True, recipe=recipe):
output = layer(inp)
loss = output.sum()
loss.backward()
# END_MXFP8_EXAMPLE
This diff is collapsed.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Check for Blackwell or newer GPU
from transformer_engine.jax.quantize import get_device_compute_capability
assert (
get_device_compute_capability() >= 100
), f"NVFP4 requires SM100 (Blackwell) or later, got SM{get_device_compute_capability()}"
# START_NVFP4_EXAMPLE
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
from transformer_engine.jax.flax import DenseGeneral
from transformer_engine.common.recipe import NVFP4BlockScaling
# Define NVFP4 recipe
# 2D weight quantization and RHT are enabled by default
recipe = NVFP4BlockScaling()
# To disable features, use:
# recipe = NVFP4BlockScaling(disable_rht=True, disable_2d_quantization=True)
with te.autocast(enabled=True, recipe=recipe):
# Initialize layer and data
layer = DenseGeneral(features=1024)
key, sr_key = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16)
# NVFP4 requires sr_rng for stochastic rounding
rngs = {"sr_rng": sr_key}
var_collect = layer.init({"params": key, "sr_rng": sr_key}, x)
# Forward and backward pass
def loss_fn(var_collect):
output = layer.apply(var_collect, x, rngs=rngs)
return output.sum()
loss, grads = jax.value_and_grad(loss_fn)(var_collect)
# END_NVFP4_EXAMPLE
This diff is collapsed.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
# Check for Blackwell or newer GPU
major, minor = torch.cuda.get_device_capability()
assert major >= 10, f"NVFP4 requires SM100 (Blackwell) or later, got SM{major}{minor}"
# START_NVFP4_EXAMPLE
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import NVFP4BlockScaling
# Define NVFP4 recipe
# 2D weight quantization and RHT are enabled by default
recipe = NVFP4BlockScaling()
# To disable features, use:
# recipe = NVFP4BlockScaling(disable_rht=True, disable_2d_quantization=True)
# Create a linear layer with bfloat16 parameters
layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
# Forward and backward pass
inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda")
with te.autocast(enabled=True, recipe=recipe):
output = layer(inp)
loss = output.sum()
loss.backward()
# END_NVFP4_EXAMPLE
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment