Commit 9df0c4a3 authored by wenjh's avatar wenjh
Browse files

Merge branch 'nv_main'

parents 0d874a4e f122b07d
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 850 480" width="850" height="480" style="display: block; margin: 0 auto;">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.arrow { stroke: #616161; stroke-width: 1.5; fill: none; marker-end: url(#arrowhead); }
.section-label { font-family: 'Segoe UI', Arial, sans-serif; font-size: 16px; font-weight: 600; fill: #424242; text-anchor: start; }
</style>
<marker id="arrowhead" markerWidth="3" markerHeight="3" refX="3" refY="1.5" orient="auto">
<polygon points="0 0, 3 1.5, 0 3" fill="#616161" />
</marker>
</defs>
<!-- Title -->
<text x="425" y="30" class="title" style="text-anchor: middle;">Transformer Layer – default precision of operation in low precision recipe</text>
<!-- Row 1: Input → Layer Norm → QKV Linear → QK^T → Softmax -->
<rect x="20" y="60" width="115" height="50" rx="5" class="hp"/>
<text x="77" y="90" class="text">Input</text>
<path d="M 135 85 L 158 85" class="arrow"/>
<rect x="158" y="60" width="115" height="50" rx="5" class="hp"/>
<text x="215" y="90" class="text">Layer Norm</text>
<path d="M 273 85 L 296 85" class="arrow"/>
<rect x="296" y="60" width="115" height="50" rx="5" class="gemm"/>
<text x="353" y="90" class="text">QKV Linear</text>
<path d="M 411 85 L 434 85" class="arrow"/>
<rect x="434" y="60" width="115" height="50" rx="5" class="hp"/>
<text x="491" y="90" class="text">QK^T</text>
<path d="M 549 85 L 572 85" class="arrow"/>
<rect x="572" y="60" width="115" height="50" rx="5" class="hp"/>
<text x="629" y="90" class="text">Softmax</text>
<!-- Row 2: Attn * V → Output Linear → Dropout + Add -->
<path d="M 629 110 L 629 145" class="arrow"/>
<rect x="572" y="145" width="115" height="50" rx="5" class="hp"/>
<text x="629" y="175" class="text">Scores * V</text>
<path d="M 572 170 L 549 170" class="arrow"/>
<rect x="434" y="145" width="115" height="50" rx="5" class="gemm"/>
<text x="491" y="175" class="text">Output Linear</text>
<path d="M 434 170 L 273 170" class="arrow"/>
<rect x="158" y="145" width="115" height="50" rx="5" class="hp"/>
<text x="215" y="175" class="text">Dropout + Add</text>
<!-- Row 3: Layer Norm → FFN Linear 1 → GELU → FFN Linear 2 → Output -->
<path d="M 215 195 L 215 230" class="arrow"/>
<rect x="158" y="230" width="115" height="50" rx="5" class="hp"/>
<text x="215" y="260" class="text">Layer Norm</text>
<path d="M 273 255 L 296 255" class="arrow"/>
<rect x="296" y="230" width="115" height="50" rx="5" class="gemm"/>
<text x="353" y="260" class="text">FFN Linear 1</text>
<path d="M 411 255 L 434 255" class="arrow"/>
<rect x="434" y="230" width="115" height="50" rx="5" class="hp"/>
<text x="491" y="260" class="text">GELU</text>
<path d="M 549 255 L 572 255" class="arrow"/>
<rect x="572" y="230" width="115" height="50" rx="5" class="gemm"/>
<text x="629" y="260" class="text">FFN Linear 2</text>
<path d="M 687 255 L 710 255" class="arrow"/>
<rect x="710" y="230" width="115" height="50" rx="5" class="hp"/>
<text x="767" y="260" class="text">Output</text>
<!-- Memory State Section -->
<text x="20" y="325" class="section-label">Memory State:</text>
<!-- Parameters -->
<rect x="20" y="340" width="180" height="45" rx="5" class="hp"/>
<text x="110" y="365" class="text">Parameters</text>
<!-- Gradients -->
<rect x="225" y="340" width="140" height="45" rx="5" class="hp"/>
<text x="295" y="365" class="text">Gradients</text>
<!-- Legend -->
<g transform="translate(20, 415)">
<!-- High Precision -->
<rect x="0" y="0" width="80" height="40" rx="5" class="hp"/>
<text x="95" y="23" class="text" style="text-anchor: start;">Higher Precision (FP32/BF16/FP16)</text>
<!-- Low Precision -->
<rect x="400" y="0" width="80" height="40" rx="5" class="gemm"/>
<text x="495" y="23" class="text" style="text-anchor: start;">Lower Precision (FP8, MXFP8 etc.)</text>
</g>
</svg>
..
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
Introduction
===================================
Transformer Engine accelerates deep learning on NVIDIA GPUs in several ways,
with low precision training being one of the most important.
This chapter introduces mixed precision training and FP8 support.
Training in BF16/FP16
---------------------
Deep learning traditionally uses 32-bit floating-point (FP32) numbers.
NVIDIA GPUs support lower precision formats—FP16 since Pascal, BF16 since Ampere—which offer higher throughput and lower memory usage.
Let's compare these formats.
.. raw:: html
:file: img/fp_formats_comparison.svg
*Figure 1: Comparison of FP32, BF16, and FP16 floating-point formats showing bit allocation for sign, exponent, and mantissa.*
The key differences between these formats are:
* **FP32** (32 bits total): 1 sign bit + 8 exponent bits + 23 mantissa bits – standard single-precision format
* **BF16** (16 bits total): 1 sign bit + 8 exponent bits + 7 mantissa bits – maintains FP32's exponent range but has reduced precision
* **FP16** (16 bits total): 1 sign bit + 5 exponent bits + 10 mantissa bits – reduced range but higher precision than BF16
BF16's advantage is that it shares the same exponent range as FP32,
making it easier to convert between the two formats without overflow/underflow issues.
FP16 offers better precision for smaller values but has a limited dynamic range,
which results in the need to perform loss scaling to avoid overflow/underflow—see `this paper on loss scaling <https://arxiv.org/pdf/1710.03740>`__ for more details.
**Mixed precision**
Not all operations should be run in reduced precision to preserve accuracy.
Modern deep learning frameworks use *mixed precision training*,
where different operations use different precisions based on their numerical properties:
* Matrix multiplications are compute-heavy and remain numerically stable at lower precision, making them ideal candidates for acceleration.
* Operations like layer normalization and softmax can work with low precision inputs and outputs, but may use high precision internally or for their weights.
* Operations like loss computation and exponentiation need high precision throughout.
**Master weights**
Another consideration in mixed precision training is how to store the model weights.
Lower precision formats like FP16 and BF16 have limited representational granularity,
which becomes problematic during gradient updates.
When a small gradient is added to a not so small weight stored in low precision,
the result may round back to the original value if the update falls below the format's precision threshold.
Moreover, some elements of the gradient itself can be too small to be represented in low precision,
especially after the accumulation from multiple GPUs in the data parallel training setting.
The solution is to maintain *master weights* in FP32.
During training, weights are cast to lower precision for forward and backward passes,
but the gradient updates are applied to the full-precision master copy.
This ensures that even small gradients accumulate correctly over time.
There are two common software approaches to storing master weights:
* *In the optimizer*:
The model holds low-precision weights,
while the optimizer maintains FP32 copies alongside momentum and other state.
During each step,
the optimizer updates its FP32 copy and casts the result back to the model's low-precision weights.
This approach makes it easier to shard master weights together with other optimizer state, for example in ZeRO optimizer.
Since the casting happens only during the optimizer step, this approach is also faster when optimizer runs less frequently than the model, e.g. when performing gradient accumulation or pipeline parallel training.
* *In the model*:
The model stores weights directly in FP32,
and they are cast to lower precision on-the-fly during forward and backward passes.
This approach works seamlessly with any standard optimizer, requiring no special support.
.. raw:: html
:file: img/master_weights_approaches.svg
*Figure 2: Three approaches to weight storage—low precision only (no master weights), master weights stored in the model, and master weights stored in the optimizer.*
.. tabs::
.. tab:: PyTorch
The PyTorch API of Transformer Engine provides several mechanisms to control precision:
* **Weight precision**: Use the ``params_dtype`` argument in any TE layer constructor.
* **Computation precision**: Use the ``torch.autocast`` context manager. When enabled, inputs are cast to the autocast dtype before computation.
* **Input dtype**: When ``torch.autocast`` is not used, the input tensor's dtype determines the computation precision. In this case, inputs and parameters must have matching dtypes.
.. literalinclude:: bf16_fp16_training_pytorch.py
:language: python
:start-after: # START_BF16_FP16_TRAINING
:end-before: # END_BF16_FP16_TRAINING
.. tab:: JAX
The JAX API of Transformer Engine provides two mechanisms to control precision:
* **Weight precision**: Use the ``dtype`` argument in any TE layer constructor.
* **Computation precision**: Determined by the dtype of the input tensor.
For training with master weights in FP32 and computation in BF16,
cast the input tensor to BF16 before passing it to the layer.
.. literalinclude:: bf16_fp16_training_jax.py
:language: python
:start-after: # START_BF16_FP16_TRAINING
:end-before: # END_BF16_FP16_TRAINING
Lower precisions
----------------
Transformer Engine's primary feature is supporting even lower precision than BF16/FP16, such as FP8, MXFP8, NVFP4, etc.
The logic of these precisions is more complicated than the logic of BF16/FP16 – they require scaling factors to
properly represent the full range of values in the tensor. Sometimes it is one scaling factor per tensor,
sometimes it is one scaling factor per block of values. A precision format combined with the logic for training
is called **a recipe**.
In this section we present common logic for all the recipes. Each one of them is described in more detail in a separate section later.
Let's now see how we can train in lower precisions in supported frameworks.
.. tabs::
.. tab:: PyTorch
The PyTorch API of Transformer Engine provides an ``autocast`` context manager to control precision.
It's similar to the ``torch.autocast`` context manager, but tailored for low precision training.
The most important argument is the ``recipe`` argument, which accepts objects inheriting from
:class:`~transformer_engine.common.recipe.Recipe`.
Forward computations need to be performed inside the ``autocast`` context manager,
while the ``.backward()`` call should be outside of it (it inherits the setting from the
corresponding forward pass).
Here is a basic example:
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Needs to be run on SM89+ (Ada or newer)
</div>
.. literalinclude:: autocast_pytorch.py
:language: python
:start-after: # START_AUTOCAST_BASIC
:end-before: # END_AUTOCAST_BASIC
You can use multiple recipes in the same model in the following ways:
**Sequential contexts** – apply different recipes to different parts of your model:
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Needs to be run on SM89+ (Ada or newer)
</div>
.. literalinclude:: autocast_pytorch.py
:language: python
:start-after: # START_AUTOCAST_SEQUENTIAL
:end-before: # END_AUTOCAST_SEQUENTIAL
**Nested contexts** – the inner context overrides the outer one for its scope:
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Needs to be run on SM89+ (Ada or newer)
</div>
.. literalinclude:: autocast_pytorch.py
:language: python
:start-after: # START_AUTOCAST_NESTED
:end-before: # END_AUTOCAST_NESTED
.. tab:: JAX
The JAX API of Transformer Engine provides an ``autocast`` context manager similar to PyTorch.
The key difference is that in JAX, model initialization must happen inside the ``autocast`` context
to properly capture quantization metadata in the parameter tree.
Here is a basic example:
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Needs to be run on SM89+ (Ada or newer)
</div>
.. literalinclude:: autocast_jax.py
:language: python
:start-after: # START_AUTOCAST_BASIC
:end-before: # END_AUTOCAST_BASIC
You can use multiple recipes in the same model in the following ways:
**Sequential contexts** – apply different recipes to different parts of your model:
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Needs to be run on SM89+ (Ada or newer)
</div>
.. literalinclude:: autocast_jax.py
:language: python
:start-after: # START_AUTOCAST_SEQUENTIAL
:end-before: # END_AUTOCAST_SEQUENTIAL
**Nested contexts** – the inner context overrides the outer one for its scope:
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Needs to be run on SM89+ (Ada or newer)
</div>
.. literalinclude:: autocast_jax.py
:language: python
:start-after: # START_AUTOCAST_NESTED
:end-before: # END_AUTOCAST_NESTED
.. note::
Python context managers like ``autocast`` may interact unexpectedly with JAX's JIT compilation.
For finer-grained control, consider passing the recipe directly to TE modules instead.
See the `TE JAX Integration notebook <https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/te_jax_integration.ipynb>`_
for details.
**Mixed precision with 8- or 4-bit precisions**
From now on, we will refer to FP8/MXFP8/NVFP4 etc. as *low precision*
and to FP32/BF16/FP16 as *high precision*. This terminology will be
used throughout the rest of the documentation.
Not all operations run in low precision:
- **Linear operations**: run in low precision.
- **Attention computations**: run in high precision by default (some recipes allow low precision as an option).
- **Other operations** (layer normalization, softmax, etc.): run in high precision.
Within high-precision operations, there are two categories:
- **Configurable precision**: most operations run in parameter precision (FP32/BF16/FP16) or the precision specified by ``torch.autocast``.
- **Fixed FP32 precision**: some operations, or parts of operations—such as the division in layernorm—always run in FP32, regardless of other settings.
.. raw:: html
:file: img/mixed_precision_operations.svg
*Figure 3: Default precision of operations in a TransformerLayer forward pass. Only linear operations are in lower precision. Dot product attention is shown as three separate operations (QK^T, Softmax, Scores * V), though in practice these may be fused into a single kernel.*
**Linear layer data flow**
Let's see how data flow of a linear layer works by default on a single H100 GPU with FP8 precision:
H100 (Hopper) architecture natively supports FP8 Matrix Multiplication only in **TN** layout (Transpose-NoTranspose),
so GEMM with tensors ``A`` and ``B`` returns ``B * A^T``.
*Forward pass*
* Input is quantized to FP8 – both ``input`` and ``input^T`` quantized versions are created.
* Weights are stored in high precision and quantized to low precision before the GEMM – both ``weight`` and ``weight^T`` quantized versions are created.
* FP8 GEMM with layout **TN** is run with ``weight`` and ``input`` tensors,
* Outputs – ``input * weight^T`` tensor – are returned in high precision.
*Backward pass*
* Output gradients are quantized to FP8 – both ``output_grad`` and ``output_grad^T`` quantized versions are created.
* FP8 GEMM with layout **TN** is performed with ``weight^T`` and ``output_grad`` tensors to compute input gradients.
* FP8 GEMM with layout **TN** is performed with ``input^T`` and ``output_grad^T`` tensors to compute weight gradients.
* Input gradients – ``output_grad * weight`` tensor – are returned in high precision.
* Weight gradients – ``output_grad^T * input`` tensor – are returned in high precision.
.. raw:: html
:file: img/fp8_linear_flow.svg
*Figure 4: Forward pass of a Linear layer with low precision data flow.*
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 60 500 380" width="100%" style="max-width: 500px;">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.title { font: bold 18px sans-serif; fill: #333; text-anchor: middle; }
.label { font: 14px sans-serif; fill: #333; text-anchor: middle; }
.small-text { font: 12px sans-serif; fill: #555; }
.dots-text { font: bold 24px sans-serif; fill: #333; text-anchor: middle; }
/* Tensor colors */
.fp8-tensor { fill: #87CEEB; stroke: #444; stroke-width: 2; }
.fp8-block { fill: #87CEEB; stroke: #555; stroke-width: 1.5; }
.fp8-block-alt { fill: #5F9FCC; stroke: #555; stroke-width: 1.5; }
/* Scaling factor colors */
.scale-factor { fill: #FFA500; stroke: #444; stroke-width: 2; }
.grid-line { stroke: #444; stroke-width: 2; }
.boundary-line { stroke: #444; stroke-width: 2; }
</style>
</defs>
<!-- MXFP8 Scaling -->
<g id="mxfp8-scaling">
<text x="250" y="85" class="title">MXFP8</text>
<text x="250" y="108" class="label" style="font-size: 12px;">(One scaling factor per 32 elements)</text>
<!-- FP8 Tensor split into many small blocks (40×10) -->
<g id="tensor-blocks">
<!-- White backgrounds for dots areas - cross pattern -->
<rect x="250.0" y="140.0" width="40" height="120" fill="#FFFFFF" stroke="none"/>
<rect x="130.0" y="190.0" width="240" height="30" fill="#FFFFFF" stroke="none"/>
<!-- Blocks ONLY where they don't overlap with white cross -->
<rect x="130" y="140" width="40" height="10" class="fp8-block"/>
<rect x="170" y="140" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="140" width="40" height="10" class="fp8-block"/>
<rect x="290" y="140" width="40" height="10" class="fp8-block"/>
<rect x="330" y="140" width="40" height="10" class="fp8-block-alt"/>
<rect x="130" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="170" y="150" width="40" height="10" class="fp8-block"/>
<rect x="210" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="290" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="330" y="150" width="40" height="10" class="fp8-block"/>
<rect x="130" y="160" width="40" height="10" class="fp8-block"/>
<rect x="170" y="160" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="160" width="40" height="10" class="fp8-block"/>
<rect x="290" y="160" width="40" height="10" class="fp8-block"/>
<rect x="330" y="160" width="40" height="10" class="fp8-block-alt"/>
<rect x="130" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="170" y="170" width="40" height="10" class="fp8-block"/>
<rect x="210" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="290" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="330" y="170" width="40" height="10" class="fp8-block"/>
<rect x="130" y="180" width="40" height="10" class="fp8-block"/>
<rect x="170" y="180" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="180" width="40" height="10" class="fp8-block"/>
<rect x="290" y="180" width="40" height="10" class="fp8-block"/>
<rect x="330" y="180" width="40" height="10" class="fp8-block-alt"/>
<!-- Three dots in VERTICAL white bar -->
<text x="270" y="167.5" class="dots-text"></text>
<text x="270" y="242.5" class="dots-text"></text>
<!-- Three dots in HORIZONTAL white bar -->
<text x="190" y="205" class="dots-text"></text>
<text x="330" y="205" class="dots-text"></text>
<!-- ONE diagonal dot at intersection -->
<text x="270" y="205" class="dots-text" transform="rotate(45 270 205)"></text>
<!-- Bottom rows (y >= 220 after horizontal white bar) -->
<rect x="130" y="220" width="40" height="10" class="fp8-block"/>
<rect x="170" y="220" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="220" width="40" height="10" class="fp8-block"/>
<rect x="290" y="220" width="40" height="10" class="fp8-block"/>
<rect x="330" y="220" width="40" height="10" class="fp8-block-alt"/>
<rect x="130" y="230" width="40" height="10" class="fp8-block-alt"/>
<rect x="170" y="230" width="40" height="10" class="fp8-block"/>
<rect x="210" y="230" width="40" height="10" class="fp8-block-alt"/>
<rect x="290" y="230" width="40" height="10" class="fp8-block-alt"/>
<rect x="330" y="230" width="40" height="10" class="fp8-block"/>
<rect x="130" y="240" width="40" height="10" class="fp8-block"/>
<rect x="170" y="240" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="240" width="40" height="10" class="fp8-block"/>
<rect x="290" y="240" width="40" height="10" class="fp8-block"/>
<rect x="330" y="240" width="40" height="10" class="fp8-block-alt"/>
<rect x="130" y="250" width="40" height="10" class="fp8-block-alt"/>
<rect x="170" y="250" width="40" height="10" class="fp8-block"/>
<rect x="210" y="250" width="40" height="10" class="fp8-block-alt"/>
<rect x="290" y="250" width="40" height="10" class="fp8-block-alt"/>
<rect x="330" y="250" width="40" height="10" class="fp8-block"/>
<!-- Main outline -->
<rect x="130.0" y="140.0" width="240" height="120" fill="none" stroke="#444" stroke-width="2"/>
</g>
<!-- Scaling factors tensor - 3+2 columns of 10px squares -->
<g id="scale-factors">
<!-- Orange background -->
<rect x="215" y="285" width="70" height="120" fill="#FFA500"/>
<!-- White backgrounds for dots areas - cross pattern -->
<rect x="245" y="285" width="20" height="120" fill="#FFFFFF" stroke="none"/>
<rect x="215" y="335" width="70" height="30" fill="#FFFFFF" stroke="none"/>
<!-- Grid lines showing 10x10 squares (3 left + 2 right columns) -->
<!-- Vertical lines every 10px (skipping white space) -->
<!-- Left 3 columns -->
<line x1="225" y1="285" x2="225" y2="335" class="grid-line" stroke-width="1"/>
<line x1="235" y1="285" x2="235" y2="335" class="grid-line" stroke-width="1"/>
<line x1="245" y1="285" x2="245" y2="335" class="grid-line" stroke-width="1"/>
<!-- Right 2 columns -->
<line x1="265" y1="285" x2="265" y2="335" class="grid-line" stroke-width="1"/>
<line x1="275" y1="285" x2="275" y2="335" class="grid-line" stroke-width="1"/>
<line x1="285" y1="285" x2="285" y2="335" class="grid-line" stroke-width="1"/>
<!-- Bottom sections -->
<line x1="225" y1="365" x2="225" y2="405" class="grid-line" stroke-width="1"/>
<line x1="235" y1="365" x2="235" y2="405" class="grid-line" stroke-width="1"/>
<line x1="245" y1="365" x2="245" y2="405" class="grid-line" stroke-width="1"/>
<line x1="265" y1="365" x2="265" y2="405" class="grid-line" stroke-width="1"/>
<line x1="275" y1="365" x2="275" y2="405" class="grid-line" stroke-width="1"/>
<line x1="285" y1="365" x2="285" y2="405" class="grid-line" stroke-width="1"/>
<!-- Horizontal lines every 10px -->
<line x1="215" y1="295" x2="245" y2="295" class="grid-line" stroke-width="1"/>
<line x1="265" y1="295" x2="285" y2="295" class="grid-line" stroke-width="1"/>
<line x1="215" y1="305" x2="245" y2="305" class="grid-line" stroke-width="1"/>
<line x1="265" y1="305" x2="285" y2="305" class="grid-line" stroke-width="1"/>
<line x1="215" y1="315" x2="245" y2="315" class="grid-line" stroke-width="1"/>
<line x1="265" y1="315" x2="285" y2="315" class="grid-line" stroke-width="1"/>
<line x1="215" y1="325" x2="245" y2="325" class="grid-line" stroke-width="1"/>
<line x1="265" y1="325" x2="285" y2="325" class="grid-line" stroke-width="1"/>
<!-- Top bottom boundaries -->
<line x1="215" y1="335" x2="245" y2="335" class="grid-line" stroke-width="1"/>
<line x1="265" y1="335" x2="285" y2="335" class="grid-line" stroke-width="1"/>
<line x1="215" y1="365" x2="245" y2="365" class="grid-line" stroke-width="1"/>
<line x1="265" y1="365" x2="285" y2="365" class="grid-line" stroke-width="1"/>
<line x1="215" y1="375" x2="245" y2="375" class="grid-line" stroke-width="1"/>
<line x1="265" y1="375" x2="285" y2="375" class="grid-line" stroke-width="1"/>
<line x1="215" y1="385" x2="245" y2="385" class="grid-line" stroke-width="1"/>
<line x1="265" y1="385" x2="285" y2="385" class="grid-line" stroke-width="1"/>
<line x1="215" y1="395" x2="245" y2="395" class="grid-line" stroke-width="1"/>
<line x1="265" y1="395" x2="285" y2="395" class="grid-line" stroke-width="1"/>
<!-- Bottom boundaries -->
<line x1="215" y1="405" x2="245" y2="405" class="grid-line" stroke-width="1"/>
<line x1="265" y1="405" x2="285" y2="405" class="grid-line" stroke-width="1"/>
<!-- Main outline -->
<rect x="215" y="285" width="70" height="120" fill="none" stroke="#444" stroke-width="2"/>
<!-- Three dots -->
<text x="255" y="312.5" class="dots-text" style="font-size: 14px;"></text>
<text x="255" y="387.5" class="dots-text" style="font-size: 14px;"></text>
<text x="230" y="350" class="dots-text" style="font-size: 14px;"></text>
<text x="275" y="350" class="dots-text" style="font-size: 14px;"></text>
<text x="255" y="350" class="dots-text" style="font-size: 14px;" transform="rotate(45 255 350)"></text>
</g>
<text x="250" y="430" class="small-text" text-anchor="middle">E8M0 scaling factors (one per 32 elements)</text>
</g>
</svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 650 450" width="100%" style="max-width: 650px;">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.title { font: bold 18px sans-serif; fill: #333; text-anchor: middle; }
.label { font: 14px sans-serif; fill: #333; text-anchor: middle; }
.small-text { font: 12px sans-serif; fill: #555; text-anchor: middle; }
.dots-text { font: bold 24px sans-serif; fill: #333; text-anchor: middle; }
.fp8-block { fill: #87CEEB; stroke: #555; stroke-width: 1.5; }
.fp8-block-alt { fill: #5F9FCC; stroke: #555; stroke-width: 1.5; }
.scale-block { fill: #FFA500; stroke: #555; stroke-width: 1.5; }
</style>
</defs>
<!-- ROWWISE SECTION -->
<text x="325" y="30" class="title">Rowwise (1x32 blocks)</text>
<!-- Rowwise Data Tensor -->
<g id="rowwise-tensor">
<text x="160" y="55" class="small-text">Data</text>
<rect x="40" y="70" width="40" height="10" class="fp8-block"/>
<rect x="80" y="70" width="40" height="10" class="fp8-block-alt"/>
<rect x="120" y="70" width="40" height="10" class="fp8-block"/>
<rect x="200" y="70" width="40" height="10" class="fp8-block-alt"/>
<rect x="240" y="70" width="40" height="10" class="fp8-block"/>
<rect x="40" y="80" width="40" height="10" class="fp8-block-alt"/>
<rect x="80" y="80" width="40" height="10" class="fp8-block"/>
<rect x="120" y="80" width="40" height="10" class="fp8-block-alt"/>
<rect x="200" y="80" width="40" height="10" class="fp8-block"/>
<rect x="240" y="80" width="40" height="10" class="fp8-block-alt"/>
<rect x="40" y="90" width="40" height="10" class="fp8-block"/>
<rect x="80" y="90" width="40" height="10" class="fp8-block-alt"/>
<rect x="120" y="90" width="40" height="10" class="fp8-block"/>
<rect x="200" y="90" width="40" height="10" class="fp8-block-alt"/>
<rect x="240" y="90" width="40" height="10" class="fp8-block"/>
<rect x="40" y="100" width="40" height="10" class="fp8-block-alt"/>
<rect x="80" y="100" width="40" height="10" class="fp8-block"/>
<rect x="120" y="100" width="40" height="10" class="fp8-block-alt"/>
<rect x="200" y="100" width="40" height="10" class="fp8-block"/>
<rect x="240" y="100" width="40" height="10" class="fp8-block-alt"/>
<rect x="40" y="110" width="40" height="10" class="fp8-block"/>
<rect x="80" y="110" width="40" height="10" class="fp8-block-alt"/>
<rect x="120" y="110" width="40" height="10" class="fp8-block"/>
<rect x="200" y="110" width="40" height="10" class="fp8-block-alt"/>
<rect x="240" y="110" width="40" height="10" class="fp8-block"/>
<rect x="40" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="80" y="150" width="40" height="10" class="fp8-block"/>
<rect x="120" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="200" y="150" width="40" height="10" class="fp8-block"/>
<rect x="240" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="40" y="160" width="40" height="10" class="fp8-block"/>
<rect x="80" y="160" width="40" height="10" class="fp8-block-alt"/>
<rect x="120" y="160" width="40" height="10" class="fp8-block"/>
<rect x="200" y="160" width="40" height="10" class="fp8-block-alt"/>
<rect x="240" y="160" width="40" height="10" class="fp8-block"/>
<rect x="40" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="80" y="170" width="40" height="10" class="fp8-block"/>
<rect x="120" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="200" y="170" width="40" height="10" class="fp8-block"/>
<rect x="240" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="40" y="180" width="40" height="10" class="fp8-block"/>
<rect x="80" y="180" width="40" height="10" class="fp8-block-alt"/>
<rect x="120" y="180" width="40" height="10" class="fp8-block"/>
<rect x="200" y="180" width="40" height="10" class="fp8-block-alt"/>
<rect x="240" y="180" width="40" height="10" class="fp8-block"/>
<rect x="40" y="70" width="240" height="120" fill="none" stroke="#444" stroke-width="2"/>
<text x="180" y="97.5" class="dots-text"></text>
<text x="180" y="172.5" class="dots-text"></text>
<text x="100" y="135" class="dots-text"></text>
<text x="240" y="135" class="dots-text"></text>
<text x="180" y="135" class="dots-text" transform="rotate(45 180 135)"></text>
</g>
<!-- Rowwise Scale Tensor -->
<g id="rowwise-scales">
<text x="485" y="55" class="small-text">Scales</text>
<!-- Rows 1-5 -->
<rect x="380" y="70" width="10" height="10" class="scale-block"/>
<rect x="390" y="70" width="10" height="10" class="scale-block"/>
<rect x="400" y="70" width="10" height="10" class="scale-block"/>
<rect x="450" y="70" width="10" height="10" class="scale-block"/>
<rect x="460" y="70" width="10" height="10" class="scale-block"/>
<rect x="380" y="80" width="10" height="10" class="scale-block"/>
<rect x="390" y="80" width="10" height="10" class="scale-block"/>
<rect x="400" y="80" width="10" height="10" class="scale-block"/>
<rect x="450" y="80" width="10" height="10" class="scale-block"/>
<rect x="460" y="80" width="10" height="10" class="scale-block"/>
<rect x="380" y="90" width="10" height="10" class="scale-block"/>
<rect x="390" y="90" width="10" height="10" class="scale-block"/>
<rect x="400" y="90" width="10" height="10" class="scale-block"/>
<rect x="450" y="90" width="10" height="10" class="scale-block"/>
<rect x="460" y="90" width="10" height="10" class="scale-block"/>
<rect x="380" y="100" width="10" height="10" class="scale-block"/>
<rect x="390" y="100" width="10" height="10" class="scale-block"/>
<rect x="400" y="100" width="10" height="10" class="scale-block"/>
<rect x="450" y="100" width="10" height="10" class="scale-block"/>
<rect x="460" y="100" width="10" height="10" class="scale-block"/>
<rect x="380" y="110" width="10" height="10" class="scale-block"/>
<rect x="390" y="110" width="10" height="10" class="scale-block"/>
<rect x="400" y="110" width="10" height="10" class="scale-block"/>
<rect x="450" y="110" width="10" height="10" class="scale-block"/>
<rect x="460" y="110" width="10" height="10" class="scale-block"/>
<!-- Gap rows -->
<rect x="380" y="150" width="10" height="10" class="scale-block"/>
<rect x="390" y="150" width="10" height="10" class="scale-block"/>
<rect x="400" y="150" width="10" height="10" class="scale-block"/>
<rect x="450" y="150" width="10" height="10" class="scale-block"/>
<rect x="460" y="150" width="10" height="10" class="scale-block"/>
<rect x="380" y="160" width="10" height="10" class="scale-block"/>
<rect x="390" y="160" width="10" height="10" class="scale-block"/>
<rect x="400" y="160" width="10" height="10" class="scale-block"/>
<rect x="450" y="160" width="10" height="10" class="scale-block"/>
<rect x="460" y="160" width="10" height="10" class="scale-block"/>
<rect x="380" y="170" width="10" height="10" class="scale-block"/>
<rect x="390" y="170" width="10" height="10" class="scale-block"/>
<rect x="400" y="170" width="10" height="10" class="scale-block"/>
<rect x="450" y="170" width="10" height="10" class="scale-block"/>
<rect x="460" y="170" width="10" height="10" class="scale-block"/>
<rect x="380" y="180" width="10" height="10" class="scale-block"/>
<rect x="390" y="180" width="10" height="10" class="scale-block"/>
<rect x="400" y="180" width="10" height="10" class="scale-block"/>
<rect x="450" y="180" width="10" height="10" class="scale-block"/>
<rect x="460" y="180" width="10" height="10" class="scale-block"/>
<rect x="380" y="70" width="90" height="120" fill="none" stroke="#444" stroke-width="2"/>
<text x="430" y="97.5" class="dots-text"></text>
<text x="430" y="172.5" class="dots-text"></text>
<text x="400" y="135" class="dots-text"></text>
<text x="460" y="135" class="dots-text"></text>
<text x="430" y="135" class="dots-text" transform="rotate(45 430 135)"></text>
</g>
<!-- COLUMNWISE SECTION -->
<text x="325" y="230" class="title">Columnwise (32x1 blocks)</text>
<!-- Columnwise Data Tensor -->
<g id="colwise-tensor">
<text x="160" y="255" class="small-text">Data</text>
<rect x="40" y="270" width="10" height="40" class="fp8-block"/>
<rect x="50" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="60" y="270" width="10" height="40" class="fp8-block"/>
<rect x="70" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="80" y="270" width="10" height="40" class="fp8-block"/>
<rect x="90" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="100" y="270" width="10" height="40" class="fp8-block"/>
<rect x="110" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="120" y="270" width="10" height="40" class="fp8-block"/>
<rect x="130" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="180" y="270" width="10" height="40" class="fp8-block"/>
<rect x="190" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="200" y="270" width="10" height="40" class="fp8-block"/>
<rect x="210" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="220" y="270" width="10" height="40" class="fp8-block"/>
<rect x="230" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="240" y="270" width="10" height="40" class="fp8-block"/>
<rect x="250" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="260" y="270" width="10" height="40" class="fp8-block"/>
<rect x="270" y="270" width="10" height="40" class="fp8-block-alt"/>
<rect x="40" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="50" y="350" width="10" height="40" class="fp8-block"/>
<rect x="60" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="70" y="350" width="10" height="40" class="fp8-block"/>
<rect x="80" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="90" y="350" width="10" height="40" class="fp8-block"/>
<rect x="100" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="110" y="350" width="10" height="40" class="fp8-block"/>
<rect x="120" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="130" y="350" width="10" height="40" class="fp8-block"/>
<rect x="180" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="190" y="350" width="10" height="40" class="fp8-block"/>
<rect x="200" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="210" y="350" width="10" height="40" class="fp8-block"/>
<rect x="220" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="230" y="350" width="10" height="40" class="fp8-block"/>
<rect x="240" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="250" y="350" width="10" height="40" class="fp8-block"/>
<rect x="260" y="350" width="10" height="40" class="fp8-block-alt"/>
<rect x="270" y="350" width="10" height="40" class="fp8-block"/>
<rect x="40" y="270" width="240" height="120" fill="none" stroke="#444" stroke-width="2"/>
<text x="160" y="296" class="dots-text"></text>
<text x="160" y="376" class="dots-text"></text>
<text x="90" y="336" class="dots-text" transform="rotate(90 90 336)"></text>
<text x="230" y="336" class="dots-text" transform="rotate(90 230 336)"></text>
<text x="160" y="336" class="dots-text" transform="rotate(45 160 336)"></text>
</g>
<!-- Columnwise Scale Tensor - TRANSPOSED -->
<g id="colwise-scales">
<text x="485" y="255" class="small-text">Scales</text>
<!-- Row 1 -->
<rect x="370" y="300" width="10" height="10" class="scale-block"/>
<rect x="380" y="300" width="10" height="10" class="scale-block"/>
<rect x="390" y="300" width="10" height="10" class="scale-block"/>
<rect x="400" y="300" width="10" height="10" class="scale-block"/>
<rect x="410" y="300" width="10" height="10" class="scale-block"/>
<rect x="420" y="300" width="10" height="10" class="scale-block"/>
<rect x="430" y="300" width="10" height="10" class="scale-block"/>
<rect x="440" y="300" width="10" height="10" class="scale-block"/>
<rect x="450" y="300" width="10" height="10" class="scale-block"/>
<rect x="460" y="300" width="10" height="10" class="scale-block"/>
<rect x="510" y="300" width="10" height="10" class="scale-block"/>
<rect x="520" y="300" width="10" height="10" class="scale-block"/>
<rect x="530" y="300" width="10" height="10" class="scale-block"/>
<rect x="540" y="300" width="10" height="10" class="scale-block"/>
<rect x="550" y="300" width="10" height="10" class="scale-block"/>
<rect x="560" y="300" width="10" height="10" class="scale-block"/>
<rect x="570" y="300" width="10" height="10" class="scale-block"/>
<rect x="580" y="300" width="10" height="10" class="scale-block"/>
<rect x="590" y="300" width="10" height="10" class="scale-block"/>
<rect x="600" y="300" width="10" height="10" class="scale-block"/>
<!-- Row 2 (gap) -->
<rect x="370" y="330" width="10" height="10" class="scale-block"/>
<rect x="380" y="330" width="10" height="10" class="scale-block"/>
<rect x="390" y="330" width="10" height="10" class="scale-block"/>
<rect x="400" y="330" width="10" height="10" class="scale-block"/>
<rect x="410" y="330" width="10" height="10" class="scale-block"/>
<rect x="420" y="330" width="10" height="10" class="scale-block"/>
<rect x="430" y="330" width="10" height="10" class="scale-block"/>
<rect x="440" y="330" width="10" height="10" class="scale-block"/>
<rect x="450" y="330" width="10" height="10" class="scale-block"/>
<rect x="460" y="330" width="10" height="10" class="scale-block"/>
<rect x="510" y="330" width="10" height="10" class="scale-block"/>
<rect x="520" y="330" width="10" height="10" class="scale-block"/>
<rect x="530" y="330" width="10" height="10" class="scale-block"/>
<rect x="540" y="330" width="10" height="10" class="scale-block"/>
<rect x="550" y="330" width="10" height="10" class="scale-block"/>
<rect x="560" y="330" width="10" height="10" class="scale-block"/>
<rect x="570" y="330" width="10" height="10" class="scale-block"/>
<rect x="580" y="330" width="10" height="10" class="scale-block"/>
<rect x="590" y="330" width="10" height="10" class="scale-block"/>
<rect x="600" y="330" width="10" height="10" class="scale-block"/>
<rect x="370" y="300" width="240" height="40" fill="none" stroke="#444" stroke-width="2"/>
<text x="490" y="320" class="dots-text"></text>
<text x="430" y="320" class="dots-text"></text>
<text x="560" y="320" class="dots-text"></text>
</g>
</svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 -15 900 445" width="100%" style="max-width: 900px;">
<style>
.scale-fill { fill: #FFA500; stroke: #444; stroke-width: 2; }
.scale-fill-nostroke { fill: #FFA500; stroke: none; }
.grid-line { stroke: #444; stroke-width: 2; fill: none; }
.dots-text { font: bold 24px sans-serif; fill: #333; text-anchor: middle; }
.label { font: 12px sans-serif; fill: #333; text-anchor: middle; }
.num-text { font: bold 12px sans-serif; fill: #333; text-anchor: middle; dominant-baseline: middle; }
.arrow-line { stroke: #444; stroke-width: 2; fill: none; }
.outer-border { fill: none; stroke: #444; stroke-width: 1.5; }
.inner-line { stroke: #444; stroke-width: 1; stroke-dasharray: 3,2; }
.num-text-small { font: bold 11px sans-serif; fill: #333; text-anchor: middle; dominant-baseline: middle; }
.dots-text-small { font: bold 14px sans-serif; fill: #333; text-anchor: middle; dominant-baseline: middle; }
</style>
<defs>
<marker id="arrowhead" markerWidth="10" markerHeight="10" refX="8" refY="5" orient="auto">
<polygon points="0 0, 10 5, 0 10" fill="#444" />
</marker>
</defs>
<!-- ======== PART 1: Linearization (from mxfp8_scale_linearize.svg) ======== -->
<g id="linearization">
<!-- Left: Scaling factors grid -->
<!-- Main rectangle with white background -->
<rect x="40" y="40" width="120" height="240" fill="#FFFFFF" stroke="#444" stroke-width="2"/>
<!-- Orange grid area (upper-left) -->
<rect x="40" y="40" width="60" height="180" fill="#FFA500" stroke="#444" stroke-width="2"/>
<!-- Grid lines for narrow blocks (3 columns x 3 rows) -->
<line x1="60" y1="40" x2="60" y2="220" class="grid-line"/>
<line x1="80" y1="40" x2="80" y2="220" class="grid-line"/>
<line x1="40" y1="100" x2="100" y2="100" class="grid-line"/>
<line x1="40" y1="160" x2="100" y2="160" class="grid-line"/>
<!-- Numbers in orange cells -->
<!-- Row 1 -->
<text x="50" y="70" class="num-text">1</text>
<text x="70" y="70" class="num-text">2</text>
<text x="90" y="70" class="num-text">3</text>
<!-- Row 2 -->
<text x="50" y="125" class="num-text">K</text>
<text x="50" y="137" class="num-text">+</text>
<text x="50" y="149" class="num-text">1</text>
<text x="70" y="125" class="num-text">K</text>
<text x="70" y="137" class="num-text">+</text>
<text x="70" y="149" class="num-text">2</text>
<text x="90" y="125" class="num-text">K</text>
<text x="90" y="137" class="num-text">+</text>
<text x="90" y="149" class="num-text">3</text>
<!-- Row 3 -->
<text x="50" y="180" class="num-text">2K</text>
<text x="50" y="192" class="num-text">+</text>
<text x="50" y="204" class="num-text">1</text>
<text x="70" y="180" class="num-text">2K</text>
<text x="70" y="192" class="num-text">+</text>
<text x="70" y="204" class="num-text">1</text>
<text x="90" y="180" class="num-text">2K</text>
<text x="90" y="192" class="num-text">+</text>
<text x="90" y="204" class="num-text">3</text>
<!-- Dots in white area (right side) -->
<text x="125" y="90" class="dots-text"></text>
<text x="125" y="150" class="dots-text"></text>
<text x="125" y="210" class="dots-text"></text>
<!-- Dots in white area (bottom) -->
<text x="60" y="260" class="dots-text"></text>
<text x="90" y="260" class="dots-text"></text>
<text x="125" y="260" class="dots-text"></text>
<!-- Arrow pointing to first block with label 128x4 -->
<text x="50" y="0" class="label" text-anchor="middle">128x4</text>
<path d="M 50 5 L 50 38" class="arrow-line" marker-end="url(#arrowhead)"/>
<!-- Arrow -->
<path d="M 200 150 L 300 150" class="arrow-line" marker-end="url(#arrowhead)"/>
<!-- Right: Linearized 1D array -->
<!-- Main rectangle with white background -->
<rect x="340" y="140" width="520" height="20" fill="#FFFFFF" stroke="#444" stroke-width="2"/>
<!-- Orange cells -->
<rect x="340" y="140" width="70" height="20" fill="#FFA500" stroke="#444" stroke-width="2"/>
<text x="375" y="150" class="num-text">1</text>
<rect x="410" y="140" width="70" height="20" fill="#FFA500" stroke="#444" stroke-width="2"/>
<text x="445" y="150" class="num-text">2</text>
<rect x="480" y="140" width="70" height="20" fill="#FFA500" stroke="#444" stroke-width="2"/>
<text x="515" y="150" class="dots-text"></text>
<rect x="550" y="140" width="70" height="20" fill="#FFA500" stroke="#444" stroke-width="2"/>
<text x="585" y="150" class="num-text">K + 1</text>
<rect x="620" y="140" width="70" height="20" fill="#FFA500" stroke="#444" stroke-width="2"/>
<text x="655" y="150" class="num-text">K + 2</text>
<!-- White area with dots -->
<text x="725" y="150" class="dots-text"></text>
<!-- Arrow pointing to first linearized block with label 1x512 -->
<text x="375" y="75" class="label" text-anchor="middle">1x512</text>
<path d="M 375 80 L 375 138" class="arrow-line" marker-end="url(#arrowhead)"/>
</g>
<!-- ======== Connection: Arrow from "1" to bottom left block with brace ======== -->
<g id="connection">
<!-- Label above brace -->
<text x="250" y="335" class="label">128 4-bit elements</text>
<!-- Curly brace on top of the bottom left block -->
<path d="M 115 355
Q 115 345, 125 345
L 245 345
Q 250 345, 250 340
Q 250 345, 255 345
L 375 345
Q 385 345, 385 355"
fill="none" stroke="#444" stroke-width="2"/>
<!-- Arrow from "1" cell down to the center of the brace -->
<path d="M 375 175 Q 375 260, 250 315" class="arrow-line" marker-end="url(#arrowhead)"/>
</g>
<!-- ======== PART 2: Swizzling (from mxfp8_swizzle_indices.svg) ======== -->
<!-- Offset by 330 (300 + 30px gap) -->
<g id="swizzling" transform="translate(100, 330)">
<!-- Left: Sequential indices -->
<g id="sequential">
<!-- Background -->
<rect x="15" y="35" width="270" height="30" class="scale-fill-nostroke"/>
<rect x="15" y="35" width="270" height="30" class="outer-border"/>
<!-- Dashed internal lines -->
<line x1="45" y1="35" x2="45" y2="65" class="inner-line"/>
<line x1="75" y1="35" x2="75" y2="65" class="inner-line"/>
<line x1="105" y1="35" x2="105" y2="65" class="inner-line"/>
<line x1="135" y1="35" x2="135" y2="65" class="inner-line"/>
<line x1="165" y1="35" x2="165" y2="65" class="inner-line"/>
<line x1="195" y1="35" x2="195" y2="65" class="inner-line"/>
<line x1="225" y1="35" x2="225" y2="65" class="inner-line"/>
<line x1="255" y1="35" x2="255" y2="65" class="inner-line"/>
<!-- Numbers -->
<text x="30" y="50" class="num-text-small">0</text>
<text x="60" y="50" class="num-text-small">1</text>
<text x="90" y="50" class="num-text-small">2</text>
<text x="120" y="50" class="num-text-small">3</text>
<text x="150" y="50" class="num-text-small">4</text>
<text x="180" y="50" class="num-text-small">5</text>
<text x="210" y="50" class="num-text-small">6</text>
<text x="240" y="50" class="num-text-small">7</text>
<text x="270" y="50" class="dots-text-small">...</text>
</g>
<!-- Arrow -->
<path d="M 300 50 L 340 50" class="arrow-line" marker-end="url(#arrowhead)"/>
<!-- Right: Swizzled indices -->
<g id="swizzled">
<!-- Background -->
<rect x="360" y="35" width="270" height="30" class="scale-fill-nostroke"/>
<rect x="360" y="35" width="270" height="30" class="outer-border"/>
<!-- Dashed internal lines -->
<line x1="390" y1="35" x2="390" y2="65" class="inner-line"/>
<line x1="420" y1="35" x2="420" y2="65" class="inner-line"/>
<line x1="450" y1="35" x2="450" y2="65" class="inner-line"/>
<line x1="480" y1="35" x2="480" y2="65" class="inner-line"/>
<line x1="510" y1="35" x2="510" y2="65" class="inner-line"/>
<line x1="540" y1="35" x2="540" y2="65" class="inner-line"/>
<line x1="570" y1="35" x2="570" y2="65" class="inner-line"/>
<line x1="600" y1="35" x2="600" y2="65" class="inner-line"/>
<!-- Numbers -->
<text x="375" y="50" class="num-text-small">0</text>
<text x="405" y="50" class="num-text-small">32</text>
<text x="435" y="50" class="num-text-small">64</text>
<text x="465" y="50" class="num-text-small">96</text>
<text x="495" y="50" class="num-text-small">1</text>
<text x="525" y="50" class="num-text-small">33</text>
<text x="555" y="50" class="num-text-small">65</text>
<text x="585" y="50" class="num-text-small">97</text>
<text x="615" y="50" class="dots-text-small">...</text>
</g>
</g>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1230 220" width="100%" style="max-width: 900px;">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
/* Diagram-specific styles */
.input-box { fill: #f3e5f5; stroke: #7b1fa2; stroke-width: 2.5; }
.mxfp8-box { fill: #e3f2fd; stroke: #1976d2; stroke-width: 2.5; }
.fp8-tile { fill: #bbdefb; stroke: #1565c0; stroke-width: 1.5; }
.scale-tile { fill: #a5d6a7; stroke: #388e3c; stroke-width: 1.5; }
.scale-swizzled { fill: #ffb74d; stroke: #e65100; stroke-width: 1.5; }
.swizzle-box { fill: #fff3e0; stroke: #f57c00; stroke-width: 2; }
.quantize-box { fill: #ede7f6; stroke: #5e35b1; stroke-width: 2; }
.comm-box { fill: #fff9c4; stroke: #f57f17; stroke-width: 2; }
.gemm-box { fill: #c8e6c9; stroke: #388e3c; stroke-width: 2; }
/* Arrow override */
.arrow { marker-end: url(#arrowhead); }
</style>
<!-- Arrow marker -->
<marker id="arrowhead" markerWidth="6" markerHeight="6" refX="5" refY="2" orient="auto">
<polygon points="0 0, 6 2, 0 4" fill="#616161" />
</marker>
</defs>
<!-- MXFP8 Complete Flow -->
<g id="complete-flow">
<!-- Step 0: Input Tensor -->
<g id="input-fp32-tensor">
<text x="80" y="25" class="text" text-anchor="middle" font-weight="600">Input Tensor</text>
<rect x="20" y="40" width="120" height="150" rx="6" class="input-box"/>
<text x="80" y="120" class="text" text-anchor="middle" fill="#fff" font-weight="600">FP32/BF16</text>
</g>
<!-- Arrow 0 -->
<path d="M 140 115 L 180 115" class="arrow"/>
<!-- Step 1: Quantize -->
<rect x="180" y="75" width="90" height="80" rx="6" class="quantize-box"/>
<text x="225" y="120" class="text" font-weight="600">Quantize</text>
<!-- Arrow 1 -->
<path d="M 270 115 L 310 115" class="arrow"/>
<!-- Step 2: MXFP8 Tensor with sub-tiles stacked vertically -->
<g id="mxfp8-tensor">
<text x="410" y="25" class="text" text-anchor="middle" font-weight="600">MXFP8 Tensor</text>
<rect x="310" y="40" width="200" height="150" rx="6" class="mxfp8-box"/>
<!-- Scales sub-tile (green) - on top -->
<rect x="330" y="55" width="160" height="40" rx="3" class="scale-tile"/>
<text x="410" y="80" class="text" text-anchor="middle" fill="#fff" font-weight="600">Scales</text>
<!-- FP8 Data sub-tile - on bottom -->
<rect x="330" y="105" width="160" height="70" rx="3" class="fp8-tile"/>
<text x="410" y="145" class="text" fill="#fff" font-weight="600">FP8 Data</text>
</g>
<!-- Arrow 2 -->
<path d="M 510 115 L 560 115" class="arrow"/>
<!-- Step 3: Communication -->
<rect x="560" y="75" width="130" height="80" rx="6" class="comm-box"/>
<text x="625" y="110" class="text" font-weight="600">Communication</text>
<text x="625" y="125" class="text" font-size="12">(All-Gather)</text>
<text x="625" y="140" class="text" font-size="12" font-style="italic">(Optional)</text>
<!-- Arrow 3 -->
<path d="M 690 115 L 740 115" class="arrow"/>
<!-- Step 4: Swizzle -->
<rect x="740" y="75" width="110" height="80" rx="6" class="swizzle-box"/>
<text x="795" y="120" class="text" font-weight="600">Swizzle</text>
<!-- Arrow 4 -->
<path d="M 850 115 L 900 115" class="arrow"/>
<!-- Step 5: MXFP8 Tensor with swizzled scales -->
<g id="swizzled-tensor">
<text x="980" y="25" class="text" text-anchor="middle" font-weight="600">MXFP8 Tensor</text>
<rect x="900" y="40" width="160" height="150" rx="6" class="mxfp8-box"/>
<!-- Swizzled Scales sub-tile (orange) - on top -->
<rect x="915" y="55" width="130" height="40" rx="3" class="scale-swizzled"/>
<text x="980" y="80" class="text" text-anchor="middle" fill="#fff" font-weight="600">Swizzle Scales</text>
<!-- FP8 Data sub-tile (unchanged) - on bottom -->
<rect x="915" y="105" width="130" height="70" rx="3" class="fp8-tile"/>
<text x="980" y="145" class="text" fill="#fff" font-weight="600">FP8 Data</text>
</g>
<!-- Arrow 5 -->
<path d="M 1060 115 L 1110 115" class="arrow"/>
<!-- Step 6: GEMM -->
<rect x="1110" y="75" width="110" height="80" rx="6" class="gemm-box"/>
<text x="1165" y="120" class="text" font-weight="600">GEMM</text>
</g>
</svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 700 300" width="100%" style="max-width: 700px;">
<style>
.tensor-fill { fill: #87CEEB; stroke: #444; stroke-width: 2; }
.scale-fill { fill: #FFA500; stroke: #444; stroke-width: 2; }
.grid-line { stroke: #444; stroke-width: 2; fill: none; }
.dots-text { font: bold 24px sans-serif; fill: #333; text-anchor: middle; }
.label { font: 12px sans-serif; fill: #333; text-anchor: middle; }
</style>
<!-- Left tensor (128x128 blocks) - FP8 tensor -->
<!-- Main rectangle with white background -->
<rect x="60" y="40" width="260" height="240" fill="#FFFFFF" stroke="#444" stroke-width="2"/>
<!-- Blue grid area (upper-left) -->
<rect x="60" y="40" width="180" height="180" fill="#87CEEB" stroke="#444" stroke-width="2"/>
<!-- Grid lines for 3x3 blocks -->
<line x1="120" y1="40" x2="120" y2="220" class="grid-line"/>
<line x1="180" y1="40" x2="180" y2="220" class="grid-line"/>
<line x1="60" y1="100" x2="240" y2="100" class="grid-line"/>
<line x1="60" y1="160" x2="240" y2="160" class="grid-line"/>
<!-- Dots in white area (right side) -->
<text x="280" y="90" class="dots-text"></text>
<text x="280" y="150" class="dots-text"></text>
<text x="280" y="210" class="dots-text"></text>
<!-- Dots in white area (bottom) -->
<text x="90" y="260" class="dots-text"></text>
<text x="150" y="260" class="dots-text"></text>
<text x="210" y="260" class="dots-text"></text>
<text x="280" y="260" class="dots-text"></text>
<!-- Label -->
<text x="190" y="20" class="label">FP8 Tensor (128×128 blocks)</text>
<!-- Right tensor (128x4 blocks) - Scaling factors (orange) -->
<!-- Main rectangle with white background -->
<rect x="480" y="40" width="120" height="240" fill="#FFFFFF" stroke="#444" stroke-width="2"/>
<!-- Orange grid area (upper-left) -->
<rect x="480" y="40" width="60" height="180" fill="#FFA500" stroke="#444" stroke-width="2"/>
<!-- Grid lines for narrow blocks (3 columns x 3 rows) -->
<line x1="500" y1="40" x2="500" y2="220" class="grid-line"/>
<line x1="520" y1="40" x2="520" y2="220" class="grid-line"/>
<line x1="480" y1="100" x2="540" y2="100" class="grid-line"/>
<line x1="480" y1="160" x2="540" y2="160" class="grid-line"/>
<!-- Dots in white area (right side) -->
<text x="565" y="90" class="dots-text"></text>
<text x="565" y="150" class="dots-text"></text>
<text x="565" y="210" class="dots-text"></text>
<!-- Dots in white area (bottom) -->
<text x="500" y="260" class="dots-text"></text>
<text x="530" y="260" class="dots-text"></text>
<text x="565" y="260" class="dots-text"></text>
<!-- Label -->
<text x="540" y="20" class="label">Scaling Factors (128×4 blocks)</text>
</svg>
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Check for Blackwell or newer GPU
from transformer_engine.jax.quantize import get_device_compute_capability
assert (
get_device_compute_capability() >= 100
), f"MXFP8 requires SM100 (Blackwell) or later, got SM{get_device_compute_capability()}"
# START_MXFP8_EXAMPLE
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
from transformer_engine.jax.flax import DenseGeneral
from transformer_engine.common.recipe import MXFP8BlockScaling, Format
# Create MXFP8 recipe
recipe = MXFP8BlockScaling(
fp8_format=Format.E4M3, # FP8 format (default: E4M3, E5M2 not supported)
)
with te.autocast(enabled=True, recipe=recipe):
# Initialize layer and data
layer = DenseGeneral(features=1024)
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16)
var_collect = layer.init(key, x)
# Forward and backward pass
def loss_fn(var_collect):
output = layer.apply(var_collect, x)
return output.sum()
loss, grads = jax.value_and_grad(loss_fn)(var_collect)
# END_MXFP8_EXAMPLE
..
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
MXFP8
=====
MXFP8 (Microscaling FP8) is an enhanced FP8 blockwise scaling recipe that leverages native hardware
acceleration on Blackwell GPUs (SM 10.0+). By using one scaling factor per 32 consecutive values
(rather than 128), MXFP8 delivers finer-grained quantization with improved numerical precision.
Data Format
-----------
The representation of an FP8 tensor element ``x`` in MXFP8 precision is given by:
.. code-block:: python
x = x_fp8 * s_block
where
* ``x_fp8`` is the FP8 value in E4M3 format,
* ``s_block`` is a local **E8M0** scaling factor shared by a block of 32 elements.
E8M0 is an 8-bit format with 8 exponent bits and 0 mantissa bits, representing only powers of 2.
**FP8 format**
Like FP8 Blockwise Scaling, E4M3 is used by default for both forward and backward passes.
The finer-grained scaling provides sufficient dynamic range without requiring the E5M2 format.
The ``fp8_format`` parameter also supports ``HYBRID`` mode (E4M3 for forward, E5M2 for backward).
Pure E5M2 training is not supported.
**Block size**
Block size is 32.
Blocks are one-dimensional, containing 32 consecutive values. No 2D scaling is performed.
There are some assumptions on the dimensions of the tensor:
* the tensor must have at least 2 dimensions,
* the last dimension must be divisible by 32,
* the product of all dimensions except the last must be divisible by 32.
**Scaling factors**
Scaling factors are stored as E8M0 (8 exponent bits, 0 mantissa bits), which inherently represents
powers of 2. This differs from FP8 Blockwise Scaling, which uses 32-bit floating point numbers
optionally constrained to powers of 2. Note that FP32 also has 8 exponent bits, so the representable
ranges are the same when the power-of-2 constraint is enabled.
Each block's scaling factor is computed through the following steps:
1. Find the maximum absolute value (``amax_block``) across all 32 elements in the block.
2. Compute the E8M0 biased exponent: ``e = float_to_e8m0(amax_block / max_fp8)``, where ``max_fp8 = 448``
(the maximum representable value in E4M3 format).
Since E8M0 and FP32 share the same exponent bias (127), ``float_to_e8m0`` simply extracts
the 8-bit exponent from the FP32 representation, rounding up if the mantissa is non-zero.
3. The scaling factor is ``s_block = 2^(e - 127)``.
This ensures that the largest value in each block fits within the FP8 representable range without overflow.
.. raw:: html
:file: img/fp8_1d_scaling.svg
*Figure 1. MXFP8 uses one E8M0 scaling factor per 32 consecutive elements, providing fine-grained
quantization and compact scaling factor representation.*
Handling transposes
-------------------
Blackwell architecture supports multiple FP8 GEMM layouts (TN, NT, NN), so columnwise usage
does not require explicit transposition. However, rowwise and columnwise quantizations are different:
- *Rowwise* - 1 scaling factor per 32 consecutive elements along a row (1×32 blocks).
- *Columnwise* - 1 scaling factor per 32 consecutive elements along a column (32×1 blocks).
Since the scaling factor blocks have different orientations, rowwise and columnwise MXFP8 tensors
are numerically different — one cannot derive one from the other. Both must be quantized
independently from the full-precision data.
.. raw:: html
:file: img/mxfp8_row_col.svg
*Figure 2. MXFP8 rowwise vs columnwise quantization layout.*
Distributed training
--------------------
**Scale synchronization**
The blockwise scaled tensor does not need any scale synchronization among the nodes.
This is because each scaling factor is local to its 32-element block,
unlike :doc:`FP8 Current <../fp8_current_scaling/fp8_current_scaling>`/:doc:`Delayed Scaling <../fp8_delayed_scaling/fp8_delayed_scaling>` where a single global scale applies to the entire tensor, even when sharded.
**Quantized all-gather**
MXFP8 all-gather is supported.
Examples
--------
Here's how to use MXFP8 recipe in PyTorch and JAX:
.. tabs::
.. tab:: PyTorch
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Requires SM100 (Blackwell) or later
</div>
.. literalinclude:: pytorch_mxfp8_example.py
:language: python
:start-after: # START_MXFP8_EXAMPLE
:end-before: # END_MXFP8_EXAMPLE
.. tab:: JAX
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Requires SM100 (Blackwell) or later
</div>
.. literalinclude:: jax_mxfp8_example.py
:language: python
:start-after: # START_MXFP8_EXAMPLE
:end-before: # END_MXFP8_EXAMPLE
Supported devices
-----------------
SM 10.0, SM 10.3
----
Developer Notes
---------------
This section contains implementation details that may be useful for developers
but are not required for using MXFP8 in practice.
Swizzling scaling factors
^^^^^^^^^^^^^^^^^^^^^^^^^
Like :doc:`FP8 Blockwise Scaling <../fp8_blockwise_scaling/fp8_blockwise_scaling>`, MXFP8 uses different data layouts for communication and computation.
MXFP8 GEMMs require scaling factors in a specific hardware layout
(see `cuBLAS documentation <https://docs.nvidia.com/cuda/cublas/index.html#block-scaling-factors-layout>`__).
The conversion to this GEMM-ready layout is called *swizzling*. When no communication is needed,
swizzling can be fused with quantization. When communication is required, swizzled scaling factors
cannot be communicated across devices, so Transformer Engine performs swizzling after communication,
just before each GEMM operation.
.. raw:: html
:file: img/mxfp8_swizzle_both_tensors.svg
*Figure 3. MXFP8 swizzling process: standard scaling factors are rearranged into the hardware-required layout.*
Blackwell Tensor Cores compute matrix multiplications using ``128x128`` tiles.
Scaling factors are stored in row-major order, but to process a tile, we need a ``128x4`` vertical
slice of scaling factors. In row-major storage, these vertical slices are scattered in memory
with gaps between each row. The hardware requires them to be stored contiguously.
.. raw:: html
:file: img/mxfp8_tensor_scaling_layout.svg
*Figure 4. FP8 tensor (left) is divided into 128x128 tiles. Each tile requires a 128x4 block of scaling factors (right). These vertical blocks are not contiguous in memory.*
Swizzling transforms the layout to meet hardware requirements by:
1. **Linearizing** the ``128x4`` blocks so they are stored contiguously one after another.
2. **Permuting** the 4-byte elements within each block.
Specifically, if we index the 128 4-byte elements in a scaling factor block as :math:`0, 1, \dots, 127`, the hardware expects them in the following interleaved order:
.. code-block:: text
0, 32, 64, 96, 1, 33, 65, 97, ..., k, 32 + k, 64 + k, 96 + k, ..., 31, 63, 95, 127
.. raw:: html
:file: img/mxfp8_scale_linearize_and_swizzle.svg
*Figure 5. Linearization and swizzling of scaling factors. The 2D grid of scaling factors is first flattened into a contiguous sequence of blocks (top), then the rows within each block are interleaved to match the hardware access pattern (bottom).*
For columnwise scaling factors, the process is analogous but with ``4x128`` horizontal blocks instead of ``128x4`` vertical blocks.
All-gather of columnwise tensors
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
All-gather of columnwise tensors is supported and necessary because:
- columnwise quantized tensors cannot be computed from rowwise quantized ones,
- gathering high-precision tensors is avoided in most cases for performance reasons.
\ No newline at end of file
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
# Check for Blackwell or newer GPU
major, minor = torch.cuda.get_device_capability()
assert major >= 10, f"MXFP8 requires SM100 (Blackwell) or later, got SM{major}{minor}"
# START_MXFP8_EXAMPLE
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import MXFP8BlockScaling, Format
# Create MXFP8 recipe
recipe = MXFP8BlockScaling(
fp8_format=Format.E4M3, # E4M3 (default) or HYBRID; pure E5M2 not supported
)
# Create a linear layer with bfloat16 parameters
layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
# Forward and backward pass
inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda")
with te.autocast(enabled=True, recipe=recipe):
output = layer(inp)
loss = output.sum()
loss.backward()
# END_MXFP8_EXAMPLE
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1100 140" width="1100" height="140">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.title { font: bold 16px sans-serif; fill: #333; text-anchor: middle; }
.text { font: 13px sans-serif; fill: #333; text-anchor: middle; }
/* Arrows */
.arrow { stroke: #616161; stroke-width: 2; fill: none; marker-end: url(#arrowhead); }
/* High precision tensor */
.hp {
fill: #e8f5e9;
stroke: #43a047;
stroke-width: 2;
}
/* Amax operations */
.amax {
fill: #fff3e0;
stroke: #ff9800;
stroke-width: 2;
}
/* Quantize operations */
.quantize {
fill: #fce4ec;
stroke: #e91e63;
stroke-width: 2;
}
/* NVFP4 tensor */
.nvfp4 {
fill: #87CEEB;
stroke: #444;
stroke-width: 2;
}
/* All-gather operations */
.allgather {
fill: #e1f5fe;
stroke: #039be5;
stroke-width: 2;
}
</style>
<marker id="arrowhead" markerWidth="6" markerHeight="6" refX="5" refY="2" orient="auto">
<polygon points="0 0, 6 2, 0 4" fill="#616161" />
</marker>
</defs>
<!-- Title -->
<text x="550" y="30" class="title">Quantization + All-Gather for NVFP4</text>
<!-- High Precision Tensor -->
<rect x="20" y="70" width="100" height="55" class="hp" rx="6"/>
<text x="70" y="93" class="text">High Precision</text>
<text x="70" y="110" class="text">Tensor</text>
<!-- Arrow -->
<path d="M 120 97 L 145 97" class="arrow"/>
<!-- Compute Amax -->
<rect x="145" y="70" width="100" height="55" class="amax" rx="6"/>
<text x="195" y="93" class="text">Compute</text>
<text x="195" y="110" class="text">Amax</text>
<!-- Arrow -->
<path d="M 245 97 L 270 97" class="arrow"/>
<!-- Synchronize Amax -->
<rect x="270" y="70" width="100" height="55" class="amax" rx="6"/>
<text x="320" y="93" class="text">Synchronize</text>
<text x="320" y="110" class="text">Amax</text>
<!-- Arrow -->
<path d="M 370 97 L 395 97" class="arrow"/>
<!-- Compute s_global -->
<rect x="395" y="70" width="100" height="55" class="amax" rx="6"/>
<text x="445" y="93" class="text">Compute</text>
<text x="445" y="110" class="text">s_global</text>
<!-- Arrow -->
<path d="M 495 97 L 520 97" class="arrow"/>
<!-- Scale + Cast -->
<rect x="520" y="70" width="100" height="55" class="quantize" rx="6"/>
<text x="570" y="86" class="text">Scale + Cast</text>
<text x="570" y="103" class="text">(s_block,</text>
<text x="570" y="118" class="text">s_global)</text>
<!-- Arrow -->
<path d="M 620 97 L 645 97" class="arrow"/>
<!-- NVFP4 Tensor (intermediate) -->
<rect x="645" y="70" width="100" height="55" class="nvfp4" rx="6"/>
<text x="695" y="93" class="text">NVFP4</text>
<text x="695" y="110" class="text">Tensor</text>
<!-- Arrow -->
<path d="M 745 97 L 770 97" class="arrow"/>
<!-- All-Gather -->
<rect x="770" y="70" width="100" height="55" class="allgather" rx="6"/>
<text x="820" y="102" class="text">All-Gather</text>
<!-- Arrow -->
<path d="M 870 97 L 895 97" class="arrow"/>
<!-- NVFP4 Gathered Tensor -->
<rect x="895" y="70" width="130" height="55" class="nvfp4" rx="6"/>
<text x="960" y="93" class="text">NVFP4 Gathered</text>
<text x="960" y="110" class="text">Tensor</text>
</svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 60 500 450" width="100%" style="max-width: 500px;">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.title { font: bold 18px sans-serif; fill: #333; text-anchor: middle; }
.label { font: 14px sans-serif; fill: #333; text-anchor: middle; }
.small-text { font: 12px sans-serif; fill: #555; }
.dots-text { font: bold 24px sans-serif; fill: #333; text-anchor: middle; }
/* Tensor colors */
.fp8-tensor { fill: #87CEEB; stroke: #444; stroke-width: 2; }
.fp8-block { fill: #87CEEB; stroke: #555; stroke-width: 1.5; }
.fp8-block-alt { fill: #5F9FCC; stroke: #555; stroke-width: 1.5; }
/* Scaling factor colors */
.scale-factor { fill: #FFA500; stroke: #444; stroke-width: 2; }
.global-scale { fill: #FF6B6B; stroke: #444; stroke-width: 2; }
.grid-line { stroke: #444; stroke-width: 2; }
.boundary-line { stroke: #444; stroke-width: 2; }
</style>
</defs>
<!-- NVFP4 Scaling -->
<g id="nvfp4-scaling">
<text x="250" y="85" class="title">NVFP4 Hierarchical Scaling</text>
<text x="250" y="108" class="label" style="font-size: 12px;">(Block scaling + Global scaling)</text>
<!-- NVFP4 Tensor split into many small blocks (40×10) -->
<g id="tensor-blocks">
<!-- White backgrounds for dots areas - cross pattern -->
<rect x="250.0" y="140.0" width="40" height="120" fill="#FFFFFF" stroke="none"/>
<rect x="130.0" y="190.0" width="240" height="30" fill="#FFFFFF" stroke="none"/>
<!-- Blocks ONLY where they don't overlap with white cross -->
<rect x="130" y="140" width="40" height="10" class="fp8-block"/>
<rect x="170" y="140" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="140" width="40" height="10" class="fp8-block"/>
<rect x="290" y="140" width="40" height="10" class="fp8-block"/>
<rect x="330" y="140" width="40" height="10" class="fp8-block-alt"/>
<rect x="130" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="170" y="150" width="40" height="10" class="fp8-block"/>
<rect x="210" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="290" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="330" y="150" width="40" height="10" class="fp8-block"/>
<rect x="130" y="160" width="40" height="10" class="fp8-block"/>
<rect x="170" y="160" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="160" width="40" height="10" class="fp8-block"/>
<rect x="290" y="160" width="40" height="10" class="fp8-block"/>
<rect x="330" y="160" width="40" height="10" class="fp8-block-alt"/>
<rect x="130" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="170" y="170" width="40" height="10" class="fp8-block"/>
<rect x="210" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="290" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="330" y="170" width="40" height="10" class="fp8-block"/>
<rect x="130" y="180" width="40" height="10" class="fp8-block"/>
<rect x="170" y="180" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="180" width="40" height="10" class="fp8-block"/>
<rect x="290" y="180" width="40" height="10" class="fp8-block"/>
<rect x="330" y="180" width="40" height="10" class="fp8-block-alt"/>
<!-- Three dots in VERTICAL white bar -->
<text x="270" y="167.5" class="dots-text"></text>
<text x="270" y="242.5" class="dots-text"></text>
<!-- Three dots in HORIZONTAL white bar -->
<text x="190" y="205" class="dots-text"></text>
<text x="330" y="205" class="dots-text"></text>
<!-- ONE diagonal dot at intersection -->
<text x="270" y="205" class="dots-text" transform="rotate(45 270 205)"></text>
<!-- Bottom rows (y >= 220 after horizontal white bar) -->
<rect x="130" y="220" width="40" height="10" class="fp8-block"/>
<rect x="170" y="220" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="220" width="40" height="10" class="fp8-block"/>
<rect x="290" y="220" width="40" height="10" class="fp8-block"/>
<rect x="330" y="220" width="40" height="10" class="fp8-block-alt"/>
<rect x="130" y="230" width="40" height="10" class="fp8-block-alt"/>
<rect x="170" y="230" width="40" height="10" class="fp8-block"/>
<rect x="210" y="230" width="40" height="10" class="fp8-block-alt"/>
<rect x="290" y="230" width="40" height="10" class="fp8-block-alt"/>
<rect x="330" y="230" width="40" height="10" class="fp8-block"/>
<rect x="130" y="240" width="40" height="10" class="fp8-block"/>
<rect x="170" y="240" width="40" height="10" class="fp8-block-alt"/>
<rect x="210" y="240" width="40" height="10" class="fp8-block"/>
<rect x="290" y="240" width="40" height="10" class="fp8-block"/>
<rect x="330" y="240" width="40" height="10" class="fp8-block-alt"/>
<rect x="130" y="250" width="40" height="10" class="fp8-block-alt"/>
<rect x="170" y="250" width="40" height="10" class="fp8-block"/>
<rect x="210" y="250" width="40" height="10" class="fp8-block-alt"/>
<rect x="290" y="250" width="40" height="10" class="fp8-block-alt"/>
<rect x="330" y="250" width="40" height="10" class="fp8-block"/>
<!-- Main outline -->
<rect x="130.0" y="140.0" width="240" height="120" fill="none" stroke="#444" stroke-width="2"/>
</g>
<!-- Scaling factors tensor - 3+2 columns of 10px squares -->
<g id="scale-factors">
<!-- Orange background -->
<rect x="215" y="285" width="70" height="120" fill="#FFA500"/>
<!-- White backgrounds for dots areas - cross pattern -->
<rect x="245" y="285" width="20" height="120" fill="#FFFFFF" stroke="none"/>
<rect x="215" y="335" width="70" height="30" fill="#FFFFFF" stroke="none"/>
<!-- Grid lines showing 10x10 squares (3 left + 2 right columns) -->
<!-- Vertical lines every 10px (skipping white space) -->
<!-- Left 3 columns -->
<line x1="225" y1="285" x2="225" y2="335" class="grid-line" stroke-width="1"/>
<line x1="235" y1="285" x2="235" y2="335" class="grid-line" stroke-width="1"/>
<line x1="245" y1="285" x2="245" y2="335" class="grid-line" stroke-width="1"/>
<!-- Right 2 columns -->
<line x1="265" y1="285" x2="265" y2="335" class="grid-line" stroke-width="1"/>
<line x1="275" y1="285" x2="275" y2="335" class="grid-line" stroke-width="1"/>
<line x1="285" y1="285" x2="285" y2="335" class="grid-line" stroke-width="1"/>
<!-- Bottom sections -->
<line x1="225" y1="365" x2="225" y2="405" class="grid-line" stroke-width="1"/>
<line x1="235" y1="365" x2="235" y2="405" class="grid-line" stroke-width="1"/>
<line x1="245" y1="365" x2="245" y2="405" class="grid-line" stroke-width="1"/>
<line x1="265" y1="365" x2="265" y2="405" class="grid-line" stroke-width="1"/>
<line x1="275" y1="365" x2="275" y2="405" class="grid-line" stroke-width="1"/>
<line x1="285" y1="365" x2="285" y2="405" class="grid-line" stroke-width="1"/>
<!-- Horizontal lines every 10px -->
<line x1="215" y1="295" x2="245" y2="295" class="grid-line" stroke-width="1"/>
<line x1="265" y1="295" x2="285" y2="295" class="grid-line" stroke-width="1"/>
<line x1="215" y1="305" x2="245" y2="305" class="grid-line" stroke-width="1"/>
<line x1="265" y1="305" x2="285" y2="305" class="grid-line" stroke-width="1"/>
<line x1="215" y1="315" x2="245" y2="315" class="grid-line" stroke-width="1"/>
<line x1="265" y1="315" x2="285" y2="315" class="grid-line" stroke-width="1"/>
<line x1="215" y1="325" x2="245" y2="325" class="grid-line" stroke-width="1"/>
<line x1="265" y1="325" x2="285" y2="325" class="grid-line" stroke-width="1"/>
<!-- Top bottom boundaries -->
<line x1="215" y1="335" x2="245" y2="335" class="grid-line" stroke-width="1"/>
<line x1="265" y1="335" x2="285" y2="335" class="grid-line" stroke-width="1"/>
<line x1="215" y1="365" x2="245" y2="365" class="grid-line" stroke-width="1"/>
<line x1="265" y1="365" x2="285" y2="365" class="grid-line" stroke-width="1"/>
<line x1="215" y1="375" x2="245" y2="375" class="grid-line" stroke-width="1"/>
<line x1="265" y1="375" x2="285" y2="375" class="grid-line" stroke-width="1"/>
<line x1="215" y1="385" x2="245" y2="385" class="grid-line" stroke-width="1"/>
<line x1="265" y1="385" x2="285" y2="385" class="grid-line" stroke-width="1"/>
<line x1="215" y1="395" x2="245" y2="395" class="grid-line" stroke-width="1"/>
<line x1="265" y1="395" x2="285" y2="395" class="grid-line" stroke-width="1"/>
<!-- Bottom boundaries -->
<line x1="215" y1="405" x2="245" y2="405" class="grid-line" stroke-width="1"/>
<line x1="265" y1="405" x2="285" y2="405" class="grid-line" stroke-width="1"/>
<!-- Main outline -->
<rect x="215" y="285" width="70" height="120" fill="none" stroke="#444" stroke-width="2"/>
<!-- Three dots -->
<text x="255" y="312.5" class="dots-text" style="font-size: 14px;"></text>
<text x="255" y="387.5" class="dots-text" style="font-size: 14px;"></text>
<text x="230" y="350" class="dots-text" style="font-size: 14px;"></text>
<text x="275" y="350" class="dots-text" style="font-size: 14px;"></text>
<text x="255" y="350" class="dots-text" style="font-size: 14px;" transform="rotate(45 255 350)"></text>
</g>
<text x="250" y="430" class="small-text" text-anchor="middle">E4M3 scaling factors (one per 16 elements)</text>
<!-- Global Scaling Factor -->
<g id="global-scale" transform="translate(350, 320)">
<rect x="10" y="10" width="20" height="20" class="global-scale"/>
<text x="20" y="60" class="small-text" text-anchor="middle">Global Scale (FP32)</text>
<text x="20" y="75" class="small-text" text-anchor="middle">(one per tensor)</text>
<text x="-20" y="25" class="dots-text" style="font-size: 20px;">+</text>
</g>
</g>
</svg>
\ No newline at end of file
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 650 520" width="100%" style="max-width: 650px;">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.title { font: bold 18px sans-serif; fill: #333; text-anchor: middle; }
.small-text { font: 12px sans-serif; fill: #555; text-anchor: middle; }
.dots-text { font: bold 24px sans-serif; fill: #333; text-anchor: middle; }
.fp8-block { fill: #87CEEB; stroke: #555; stroke-width: 1.5; }
.fp8-block-alt { fill: #5F9FCC; stroke: #555; stroke-width: 1.5; }
.scale-block { fill: #FFA500; stroke: #555; stroke-width: 1.5; }
.global-scale { fill: #FF6B00; stroke: #444; stroke-width: 2; }
</style>
</defs>
<!-- ROWWISE SECTION -->
<text x="325" y="30" class="title">Rowwise (1×16 blocks)</text>
<!-- Rowwise Data [A, B] - 240x120 -->
<g id="rowwise-tensor">
<text x="160" y="55" class="small-text">Data [A, B]</text>
<!-- Top-left -->
<rect x="40" y="70" width="40" height="10" class="fp8-block"/>
<rect x="80" y="70" width="40" height="10" class="fp8-block-alt"/>
<rect x="40" y="80" width="40" height="10" class="fp8-block-alt"/>
<rect x="80" y="80" width="40" height="10" class="fp8-block"/>
<rect x="40" y="90" width="40" height="10" class="fp8-block"/>
<rect x="80" y="90" width="40" height="10" class="fp8-block-alt"/>
<!-- Top-right -->
<rect x="200" y="70" width="40" height="10" class="fp8-block-alt"/>
<rect x="240" y="70" width="40" height="10" class="fp8-block"/>
<rect x="200" y="80" width="40" height="10" class="fp8-block"/>
<rect x="240" y="80" width="40" height="10" class="fp8-block-alt"/>
<rect x="200" y="90" width="40" height="10" class="fp8-block-alt"/>
<rect x="240" y="90" width="40" height="10" class="fp8-block"/>
<!-- Bottom-left -->
<rect x="40" y="160" width="40" height="10" class="fp8-block-alt"/>
<rect x="80" y="160" width="40" height="10" class="fp8-block"/>
<rect x="40" y="170" width="40" height="10" class="fp8-block"/>
<rect x="80" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="40" y="180" width="40" height="10" class="fp8-block-alt"/>
<rect x="80" y="180" width="40" height="10" class="fp8-block"/>
<!-- Bottom-right -->
<rect x="200" y="160" width="40" height="10" class="fp8-block"/>
<rect x="240" y="160" width="40" height="10" class="fp8-block-alt"/>
<rect x="200" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="240" y="170" width="40" height="10" class="fp8-block"/>
<rect x="200" y="180" width="40" height="10" class="fp8-block"/>
<rect x="240" y="180" width="40" height="10" class="fp8-block-alt"/>
<text x="160" y="87" class="dots-text"></text>
<text x="160" y="177" class="dots-text"></text>
<text x="80" y="135" class="dots-text"></text>
<text x="240" y="135" class="dots-text"></text>
<rect x="40" y="70" width="240" height="120" fill="none" stroke="#444" stroke-width="2"/>
</g>
<!-- Rowwise s_block [A, B/16] - 90x120 -->
<g id="rowwise-scales">
<text x="425" y="55" class="small-text">s_block [A, B/16]</text>
<!-- Top-left -->
<rect x="380" y="70" width="10" height="10" class="scale-block"/>
<rect x="390" y="70" width="10" height="10" class="scale-block"/>
<rect x="380" y="80" width="10" height="10" class="scale-block"/>
<rect x="390" y="80" width="10" height="10" class="scale-block"/>
<rect x="380" y="90" width="10" height="10" class="scale-block"/>
<rect x="390" y="90" width="10" height="10" class="scale-block"/>
<!-- Top-right -->
<rect x="450" y="70" width="10" height="10" class="scale-block"/>
<rect x="460" y="70" width="10" height="10" class="scale-block"/>
<rect x="450" y="80" width="10" height="10" class="scale-block"/>
<rect x="460" y="80" width="10" height="10" class="scale-block"/>
<rect x="450" y="90" width="10" height="10" class="scale-block"/>
<rect x="460" y="90" width="10" height="10" class="scale-block"/>
<!-- Bottom-left -->
<rect x="380" y="160" width="10" height="10" class="scale-block"/>
<rect x="390" y="160" width="10" height="10" class="scale-block"/>
<rect x="380" y="170" width="10" height="10" class="scale-block"/>
<rect x="390" y="170" width="10" height="10" class="scale-block"/>
<rect x="380" y="180" width="10" height="10" class="scale-block"/>
<rect x="390" y="180" width="10" height="10" class="scale-block"/>
<!-- Bottom-right -->
<rect x="450" y="160" width="10" height="10" class="scale-block"/>
<rect x="460" y="160" width="10" height="10" class="scale-block"/>
<rect x="450" y="170" width="10" height="10" class="scale-block"/>
<rect x="460" y="170" width="10" height="10" class="scale-block"/>
<rect x="450" y="180" width="10" height="10" class="scale-block"/>
<rect x="460" y="180" width="10" height="10" class="scale-block"/>
<text x="425" y="87" class="dots-text"></text>
<text x="425" y="177" class="dots-text"></text>
<text x="390" y="135" class="dots-text"></text>
<text x="460" y="135" class="dots-text"></text>
<rect x="380" y="70" width="90" height="120" fill="none" stroke="#444" stroke-width="2"/>
</g>
<!-- Rowwise s_global -->
<g id="rowwise-global">
<text x="545" y="55" class="small-text">s_global</text>
<rect x="535" y="120" width="20" height="20" class="global-scale"/>
</g>
<!-- COLUMNWISE SECTION -->
<text x="325" y="230" class="title">Columnwise (16×1 blocks) — transposed storage</text>
<!-- Columnwise Data [B, A] - 120x240 (transposed) -->
<g id="colwise-tensor">
<text x="100" y="255" class="small-text">Data [B, A]</text>
<!-- Top-left blocks -->
<rect x="40" y="270" width="40" height="10" class="fp8-block"/>
<rect x="40" y="280" width="40" height="10" class="fp8-block-alt"/>
<rect x="40" y="290" width="40" height="10" class="fp8-block"/>
<rect x="40" y="300" width="40" height="10" class="fp8-block-alt"/>
<rect x="40" y="310" width="40" height="10" class="fp8-block"/>
<rect x="40" y="320" width="40" height="10" class="fp8-block-alt"/>
<!-- Top-right blocks -->
<rect x="120" y="270" width="40" height="10" class="fp8-block-alt"/>
<rect x="120" y="280" width="40" height="10" class="fp8-block"/>
<rect x="120" y="290" width="40" height="10" class="fp8-block-alt"/>
<rect x="120" y="300" width="40" height="10" class="fp8-block"/>
<rect x="120" y="310" width="40" height="10" class="fp8-block-alt"/>
<rect x="120" y="320" width="40" height="10" class="fp8-block"/>
<!-- Bottom-left blocks (stick to bottom: box ends at y=510) -->
<rect x="40" y="450" width="40" height="10" class="fp8-block-alt"/>
<rect x="40" y="460" width="40" height="10" class="fp8-block"/>
<rect x="40" y="470" width="40" height="10" class="fp8-block-alt"/>
<rect x="40" y="480" width="40" height="10" class="fp8-block"/>
<rect x="40" y="490" width="40" height="10" class="fp8-block-alt"/>
<rect x="40" y="500" width="40" height="10" class="fp8-block"/>
<!-- Bottom-right blocks -->
<rect x="120" y="450" width="40" height="10" class="fp8-block"/>
<rect x="120" y="460" width="40" height="10" class="fp8-block-alt"/>
<rect x="120" y="470" width="40" height="10" class="fp8-block"/>
<rect x="120" y="480" width="40" height="10" class="fp8-block-alt"/>
<rect x="120" y="490" width="40" height="10" class="fp8-block"/>
<rect x="120" y="500" width="40" height="10" class="fp8-block-alt"/>
<text x="100" y="307" class="dots-text"></text>
<text x="100" y="487" class="dots-text"></text>
<text x="60" y="395" class="dots-text"></text>
<text x="140" y="395" class="dots-text"></text>
<rect x="40" y="270" width="120" height="240" fill="none" stroke="#444" stroke-width="2"/>
</g>
<!-- Columnwise s_block [B, A/16] - 45x240, aligned with rowwise scales at x=380 -->
<g id="colwise-scales">
<text x="402" y="255" class="small-text">s_block [B, A/16]</text>
<!-- Top-left -->
<rect x="380" y="270" width="10" height="10" class="scale-block"/>
<rect x="380" y="280" width="10" height="10" class="scale-block"/>
<rect x="380" y="290" width="10" height="10" class="scale-block"/>
<rect x="380" y="300" width="10" height="10" class="scale-block"/>
<rect x="380" y="310" width="10" height="10" class="scale-block"/>
<rect x="380" y="320" width="10" height="10" class="scale-block"/>
<!-- Top-right -->
<rect x="415" y="270" width="10" height="10" class="scale-block"/>
<rect x="415" y="280" width="10" height="10" class="scale-block"/>
<rect x="415" y="290" width="10" height="10" class="scale-block"/>
<rect x="415" y="300" width="10" height="10" class="scale-block"/>
<rect x="415" y="310" width="10" height="10" class="scale-block"/>
<rect x="415" y="320" width="10" height="10" class="scale-block"/>
<!-- Bottom-left -->
<rect x="380" y="450" width="10" height="10" class="scale-block"/>
<rect x="380" y="460" width="10" height="10" class="scale-block"/>
<rect x="380" y="470" width="10" height="10" class="scale-block"/>
<rect x="380" y="480" width="10" height="10" class="scale-block"/>
<rect x="380" y="490" width="10" height="10" class="scale-block"/>
<rect x="380" y="500" width="10" height="10" class="scale-block"/>
<!-- Bottom-right -->
<rect x="415" y="450" width="10" height="10" class="scale-block"/>
<rect x="415" y="460" width="10" height="10" class="scale-block"/>
<rect x="415" y="470" width="10" height="10" class="scale-block"/>
<rect x="415" y="480" width="10" height="10" class="scale-block"/>
<rect x="415" y="490" width="10" height="10" class="scale-block"/>
<rect x="415" y="500" width="10" height="10" class="scale-block"/>
<text x="402" y="307" class="dots-text"></text>
<text x="402" y="487" class="dots-text"></text>
<text x="387" y="395" class="dots-text"></text>
<text x="420" y="395" class="dots-text"></text>
<rect x="380" y="270" width="45" height="240" fill="none" stroke="#444" stroke-width="2"/>
</g>
<!-- Columnwise s_global, aligned with rowwise at x=535 -->
<g id="colwise-global">
<text x="545" y="255" class="small-text">s_global</text>
<rect x="535" y="380" width="20" height="20" class="global-scale"/>
</g>
</svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 900 220">
<defs>
<style>
.sign-bit { fill: #9db4d0; stroke: #333; stroke-width: 1; }
.exponent-bit { fill: #d9a066; stroke: #333; stroke-width: 1; }
.mantissa-bit { fill: #a8d99c; stroke: #333; stroke-width: 1; }
.bit-text { fill: #000; text-anchor: middle; dominant-baseline: middle; font-size: 16px; }
.header-text { fill: #555; font-weight: normal; text-anchor: middle; font-size: 18px; }
.value-text { fill: #333; font-size: 18px; }
.format-label { fill: #333; font-weight: bold; text-anchor: middle; dominant-baseline: middle; font-size: 20px; }
</style>
</defs>
<!-- FP8 E4M3 Format (8 bits: 1 + 4 + 3) -->
<text x="60" y="60" class="format-label">FP8 E4M3</text>
<!-- Sign bit (1) -->
<rect x="140" y="45" width="18" height="30" class="sign-bit"/>
<text x="149" y="60" class="bit-text">0</text>
<!-- Exponent bits (4) -->
<rect x="163" y="45" width="18" height="30" class="exponent-bit"/>
<text x="172" y="60" class="bit-text">1</text>
<rect x="186" y="45" width="18" height="30" class="exponent-bit"/>
<text x="195" y="60" class="bit-text">0</text>
<rect x="209" y="45" width="18" height="30" class="exponent-bit"/>
<text x="218" y="60" class="bit-text">0</text>
<rect x="232" y="45" width="18" height="30" class="exponent-bit"/>
<text x="241" y="60" class="bit-text">0</text>
<!-- Mantissa bits (3) -->
<rect x="255" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="264" y="60" class="bit-text">1</text>
<rect x="278" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="287" y="60" class="bit-text">1</text>
<rect x="301" y="45" width="18" height="30" class="mantissa-bit"/>
<text x="310" y="60" class="bit-text">1</text>
<text x="355" y="60" class="value-text">(1 sign, 4 exp, 3 mantissa)</text>
<!-- FP8 E5M2 Format (8 bits: 1 + 5 + 2) -->
<text x="60" y="120" class="format-label">FP8 E5M2</text>
<!-- Sign bit (1) -->
<rect x="140" y="105" width="18" height="30" class="sign-bit"/>
<text x="149" y="120" class="bit-text">0</text>
<!-- Exponent bits (5) -->
<rect x="163" y="105" width="18" height="30" class="exponent-bit"/>
<text x="172" y="120" class="bit-text">1</text>
<rect x="186" y="105" width="18" height="30" class="exponent-bit"/>
<text x="195" y="120" class="bit-text">0</text>
<rect x="209" y="105" width="18" height="30" class="exponent-bit"/>
<text x="218" y="120" class="bit-text">0</text>
<rect x="232" y="105" width="18" height="30" class="exponent-bit"/>
<text x="241" y="120" class="bit-text">0</text>
<rect x="255" y="105" width="18" height="30" class="exponent-bit"/>
<text x="264" y="120" class="bit-text">0</text>
<!-- Mantissa bits (2) -->
<rect x="278" y="105" width="18" height="30" class="mantissa-bit"/>
<text x="287" y="120" class="bit-text">1</text>
<rect x="301" y="105" width="18" height="30" class="mantissa-bit"/>
<text x="310" y="120" class="bit-text">1</text>
<text x="355" y="120" class="value-text">(1 sign, 5 exp, 2 mantissa)</text>
<!-- NVFP4 E2M1 Format (4 bits: 1 + 2 + 1) -->
<text x="60" y="180" class="format-label">NVFP4</text>
<!-- Sign bit (1) -->
<rect x="140" y="165" width="18" height="30" class="sign-bit"/>
<text x="149" y="180" class="bit-text">0</text>
<!-- Exponent bits (2) -->
<rect x="163" y="165" width="18" height="30" class="exponent-bit"/>
<text x="172" y="180" class="bit-text">1</text>
<rect x="186" y="165" width="18" height="30" class="exponent-bit"/>
<text x="195" y="180" class="bit-text">0</text>
<!-- Mantissa bits (1) -->
<rect x="209" y="165" width="18" height="30" class="mantissa-bit"/>
<text x="218" y="180" class="bit-text">1</text>
<text x="355" y="180" class="value-text">(1 sign, 2 exp, 1 mantissa)</text>
</svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 900 340">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
/* Diagram-specific styles */
.input { fill: #e3f2fd; stroke: #1976d2; stroke-width: 2; }
.grad { fill: #fce4ec; stroke: #d81b60; stroke-width: 2; }
.rht { fill: #fff3e0; stroke: #f57c00; stroke-width: 2; }
.output { fill: #e8f5e9; stroke: #388e3c; stroke-width: 2; }
.divider { stroke: #bdbdbd; stroke-width: 2; stroke-dasharray: 6,4; }
/* Arrow override */
.arrow { marker-end: url(#arrowhead); }
</style>
<!-- Arrow marker -->
<marker id="arrowhead" markerWidth="8" markerHeight="8" refX="7" refY="3" orient="auto">
<polygon points="0 0, 8 3, 0 6" fill="#616161" />
</marker>
</defs>
<!-- Title -->
<text x="450" y="30" class="title">Random Hadamard Transform for WGRAD GEMM</text>
<!-- Divider -->
<line x1="450" y1="50" x2="450" y2="320" class="divider"/>
<!-- ═══════════════ LEFT SIDE: Without RHT ═══════════════ -->
<g id="without-rht">
<text x="225" y="70" class="section-title">Without RHT</text>
<!-- Top row: Activations → Quantize → GEMM -->
<!-- Activations -->
<rect x="40" y="100" width="90" height="45" rx="6" class="input"/>
<text x="85" y="127" class="text">Activations</text>
<!-- Arrow -->
<path d="M 130 122 L 175 122" class="arrow"/>
<!-- Quantize -->
<rect x="175" y="100" width="80" height="45" rx="6" class="quantize"/>
<text x="215" y="127" class="text">Quantize</text>
<!-- Arrow to GEMM -->
<path d="M 255 122 L 300 122" class="arrow"/>
<!-- GEMM -->
<rect x="300" y="80" width="85" height="90" rx="6" class="gemm"/>
<text x="342" y="118" class="text">WGRAD</text>
<text x="342" y="138" class="text">GEMM</text>
<!-- Bottom row: Output Grad → Quantize → GEMM -->
<!-- Output Grad -->
<rect x="40" y="170" width="90" height="45" rx="6" class="grad"/>
<text x="85" y="197" class="text">Output Grad</text>
<!-- Arrow -->
<path d="M 130 192 L 175 192" class="arrow"/>
<!-- Quantize 2 -->
<rect x="175" y="170" width="80" height="45" rx="6" class="quantize"/>
<text x="215" y="197" class="text">Quantize</text>
<!-- Arrow to GEMM (diagonal) -->
<path d="M 255 192 L 300 155" class="arrow"/>
<!-- Arrow from GEMM to output -->
<path d="M 342 170 L 342 245" class="arrow"/>
<!-- Weight Grad -->
<rect x="300" y="245" width="85" height="45" rx="6" class="output"/>
<text x="342" y="272" class="text">Weight Grad</text>
</g>
<!-- ═══════════════ RIGHT SIDE: With RHT ═══════════════ -->
<g id="with-rht">
<text x="675" y="70" class="section-title">With RHT</text>
<!-- Top row: Activations → RHT → Quantize → GEMM -->
<!-- Activations -->
<rect x="460" y="100" width="90" height="45" rx="6" class="input"/>
<text x="505" y="127" class="text">Activations</text>
<!-- Arrow -->
<path d="M 550 122 L 575 122" class="arrow"/>
<!-- RHT -->
<rect x="575" y="100" width="50" height="45" rx="6" class="rht"/>
<text x="600" y="127" class="text">RHT</text>
<!-- Arrow -->
<path d="M 625 122 L 650 122" class="arrow"/>
<!-- Quantize -->
<rect x="650" y="100" width="80" height="45" rx="6" class="quantize"/>
<text x="690" y="127" class="text">Quantize</text>
<!-- Arrow to GEMM -->
<path d="M 730 122 L 775 122" class="arrow"/>
<!-- GEMM -->
<rect x="775" y="80" width="85" height="90" rx="6" class="gemm"/>
<text x="817" y="118" class="text">WGRAD</text>
<text x="817" y="138" class="text">GEMM</text>
<!-- Bottom row: Output Grad → RHT → Quantize → GEMM -->
<!-- Output Grad -->
<rect x="460" y="170" width="90" height="45" rx="6" class="grad"/>
<text x="505" y="197" class="text">Output Grad</text>
<!-- Arrow -->
<path d="M 550 192 L 575 192" class="arrow"/>
<!-- RHT 2 -->
<rect x="575" y="170" width="50" height="45" rx="6" class="rht"/>
<text x="600" y="197" class="text">RHT</text>
<!-- Arrow -->
<path d="M 625 192 L 650 192" class="arrow"/>
<!-- Quantize 2 -->
<rect x="650" y="170" width="80" height="45" rx="6" class="quantize"/>
<text x="690" y="197" class="text">Quantize</text>
<!-- Arrow to GEMM (diagonal) -->
<path d="M 730 192 L 775 155" class="arrow"/>
<!-- Arrow from GEMM to output -->
<path d="M 817 170 L 817 245" class="arrow"/>
<!-- Weight Grad -->
<rect x="775" y="245" width="85" height="45" rx="6" class="output"/>
<text x="817" y="272" class="text">Weight Grad</text>
</g>
</svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1000 400">
<defs>
<style>
.axis { stroke: #333; stroke-width: 3; stroke-linecap: round; }
.tick { stroke: #333; stroke-width: 2; }
.tick-label { font-family: sans-serif; font-size: 22px; text-anchor: middle; fill: #333; font-weight: bold; }
.title { font-family: sans-serif; font-size: 24px; font-weight: bold; text-anchor: middle; fill: #333; }
.sub-label { font-family: sans-serif; font-size: 16px; text-anchor: middle; fill: #555; }
.value-point { fill: #e74c3c; stroke: #333; stroke-width: 2; }
.value-label { font-family: sans-serif; font-size: 22px; text-anchor: middle; fill: #e74c3c; font-weight: bold; }
.divider { stroke: #ccc; stroke-width: 2; stroke-dasharray: 5,5; }
.bar-bg { fill: #eee; stroke: #555; stroke-width: 1.5; }
.bar-fill-blue { fill: #87CEEB; stroke: #555; stroke-width: 1.5; }
.bar-fill-green { fill: #a8d99c; stroke: #555; stroke-width: 1.5; }
.percentage-text { font-family: sans-serif; font-weight: bold; font-size: 20px; }
</style>
</defs>
<!-- Divider -->
<line x1="500" y1="20" x2="500" y2="380" class="divider" />
<!-- LEFT SIDE: Deterministic Rounding -->
<g transform="translate(0,0)">
<text x="250" y="40" class="title">Round to Nearest</text>
<!-- Axis -->
<line x1="50" y1="150" x2="450" y2="150" class="axis" />
<!-- v1 -->
<line x1="100" y1="140" x2="100" y2="160" class="tick" />
<text x="100" y="185" class="tick-label">v₁</text>
<!-- v2 -->
<line x1="400" y1="140" x2="400" y2="160" class="tick" />
<text x="400" y="185" class="tick-label">v₂</text>
<!-- x (at 40% distance) -->
<circle cx="220" cy="150" r="8" class="value-point" />
<text x="220" y="120" class="value-label">x</text>
<!-- Visuals for deterministic rounding: Bars -->
<g transform="translate(50, 230)">
<!-- Bar for v1 (100%) -->
<text x="50" y="-10" class="sub-label">Round to v₁</text>
<rect x="10" y="0" width="80" height="100" rx="4" class="bar-bg" />
<rect x="10" y="0" width="80" height="100" rx="4" class="bar-fill-green" />
<text x="50" y="56" fill="#000" text-anchor="middle" class="percentage-text">100%</text>
<!-- Bar for v2 (0%) -->
<text x="350" y="-10" class="sub-label">Round to v₂</text>
<rect x="310" y="0" width="80" height="100" rx="4" class="bar-bg" />
<!-- 0% filled, so just bg visible -->
<text x="350" y="56" fill="#666" text-anchor="middle" class="percentage-text">0%</text>
</g>
</g>
<!-- RIGHT SIDE: Stochastic -->
<g transform="translate(500,0)">
<text x="250" y="40" class="title">Stochastic Rounding</text>
<!-- Axis -->
<line x1="50" y1="150" x2="450" y2="150" class="axis" />
<!-- v1 -->
<line x1="100" y1="140" x2="100" y2="160" class="tick" />
<text x="100" y="185" class="tick-label">v₁</text>
<!-- v2 -->
<line x1="400" y1="140" x2="400" y2="160" class="tick" />
<text x="400" y="185" class="tick-label">v₂</text>
<!-- x (at 40% distance) -->
<circle cx="220" cy="150" r="8" class="value-point" />
<text x="220" y="120" class="value-label">x</text>
<!-- Visuals for Stochastic: Bars -->
<g transform="translate(50, 230)">
<!-- Bar for v1 -->
<text x="50" y="-10" class="sub-label">Round to v₁</text>
<rect x="10" y="0" width="80" height="100" rx="4" class="bar-bg" />
<rect x="10" y="40" width="80" height="60" rx="4" class="bar-fill-blue" />
<text x="50" y="80" fill="#000" text-anchor="middle" class="percentage-text">60%</text>
<!-- Bar for v2 -->
<text x="350" y="-10" class="sub-label">Round to v₂</text>
<rect x="310" y="0" width="80" height="100" rx="4" class="bar-bg" />
<rect x="310" y="60" width="80" height="40" rx="4" class="bar-fill-blue" />
<text x="350" y="90" fill="#000" text-anchor="middle" class="percentage-text">40%</text>
</g>
</g>
</svg>
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Check for Blackwell or newer GPU
from transformer_engine.jax.quantize import get_device_compute_capability
assert (
get_device_compute_capability() >= 100
), f"NVFP4 requires SM100 (Blackwell) or later, got SM{get_device_compute_capability()}"
# START_NVFP4_EXAMPLE
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
from transformer_engine.jax.flax import DenseGeneral
from transformer_engine.common.recipe import NVFP4BlockScaling
# Define NVFP4 recipe
# 2D weight quantization and RHT are enabled by default
recipe = NVFP4BlockScaling()
# To disable features, use:
# recipe = NVFP4BlockScaling(disable_rht=True, disable_2d_quantization=True)
with te.autocast(enabled=True, recipe=recipe):
# Initialize layer and data
layer = DenseGeneral(features=1024)
key, sr_key = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16)
# NVFP4 requires sr_rng for stochastic rounding
rngs = {"sr_rng": sr_key}
var_collect = layer.init({"params": key, "sr_rng": sr_key}, x)
# Forward and backward pass
def loss_fn(var_collect):
output = layer.apply(var_collect, x, rngs=rngs)
return output.sum()
loss, grads = jax.value_and_grad(loss_fn)(var_collect)
# END_NVFP4_EXAMPLE
..
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
NVFP4
===================================
NVFP4 is the first 4-bit recipe introduced in Transformer Engine –
please refer to the `NVFP4 paper <https://arxiv.org/abs/2509.25149>`__ for more details.
It is a more complex recipe than the previous ones – apart from the new data format,
it introduces multiple features which help training stability.
Data Format
----------------------
The NVFP4 datatype consists of 1 sign bit, 2 exponent bits, and 1 mantissa bit (E2M1).
It can represent values of magnitude up to +/- 6.
NVFP4 uses a hierarchical block scaling approach where multiple scaling factors are combined to recover the high precision value.
.. raw:: html
:file: img/nvfp4_vs_fp8.svg
*Figure 1. Bit layout comparison between standard FP8 formats (E4M3 and E5M2) and NVFP4 (E2M1).*
The representation of an NVFP4 tensor element ``x`` is given by:
.. code-block:: python
x = x_e2m1 * s_block * s_global
where
* ``x_e2m1`` is the 4-bit value,
* ``s_block`` is a local **FP8 E4M3** scaling factor shared by a block of 16 consecutive elements,
* ``s_global`` is a global **FP32** scaling factor applied to the entire tensor.
**Scaling Factor Computation**
The scaling factors are computed as follows:
1. Global scaling factor (``s_global``):
.. code-block:: python
s_global = global_amax / (fp8_max * fp4_max)
# where:
# - global_amax: maximum absolute value across the entire tensor
# - fp8_max: maximum representable value in FP8 E4M3 (448.0)
# - fp4_max: maximum representable value in NVFP4 E2M1 (6.0)
2. Block scaling factor (``s_block``):
.. code-block:: python
s_block = (block_amax / fp4_max) / s_global
# where:
# - block_amax: maximum absolute value within the block
# - fp4_max: maximum representable value in NVFP4 E2M1 (6.0)
# - s_block is stored in FP8 E4M3 format
.. raw:: html
:file: img/nvfp4_hierarchical_scaling.svg
*Figure 2. NVFP4 hierarchical scaling structure showing the combination of block-level and global scaling factors.*
This hierarchical structure uses fine-grained block scaling to handle the tensor's dynamic range,
while the FP4 values represent the block-level dynamic range. The global scaling factor
aligns values to the representable range of the E4M3 × E2M1 combination.
**2D weight scaling**
NVFP4 can be:
* 1 dimensional - each block of 16 consecutive elements shares a scaling factor,
* 2 dimensional - each block of 16x16 elements shares a scaling factor.
By default, NVFP4 uses 2D scaling for weights and 1D scaling for activations and gradients.
Set ``disable_2d_quantization=True`` in the recipe configuration to force 1D scaling for weights as well (activations and gradients always use 1D).
The motivation for using 2D scaling for weights is to ensure that rowwise and columnwise
quantized tensors are numerically equivalent.
Please refer to the `NVFP4 paper <https://arxiv.org/abs/2509.25149>`__ for more details.
Stochastic Rounding
-------------------
Stochastic rounding is applied when casting scaled values to NVFP4 format. Instead of deterministic rounding
(always rounding to nearest even value), each scaled value is probabilistically rounded to one of the two
nearest representable NVFP4 values. The probability of rounding to a given value is inversely proportional to
the distance to that value, which ensures that the expected value of the quantized
tensor equals the original value, eliminating systematic quantization bias during training.
Stochastic rounding is hardware-accelerated using native GPU instructions introduced with the
Blackwell architecture.
.. raw:: html
:file: img/stochastic_rounding.svg
*Figure 3. Stochastic rounding illustration. Given a value* ``x`` *to be quantized, and the two nearest
representable NVFP4 values* ``v1`` *(lower) and* ``v2`` *(higher), deterministic rounding always
rounds to the nearest value, while stochastic rounding probabilistically rounds to either value.
If* ``x`` *is 40% of the way from* ``v1`` *to* ``v2``, *there is a 60% chance of rounding to* ``v1``
*and a 40% chance of rounding to* ``v2``.
Stochastic rounding is enabled only for gradients. It can be disabled by setting
``disable_stochastic_rounding=True`` in the recipe configuration.
Random Hadamard Transform
--------------------------
Random Hadamard Transform (RHT) applies an orthogonal rotation to the tensor **before quantization**,
smoothing outliers in the tensor distributions and making them easier to represent accurately in NVFP4.
RHT is applied to columnwise quantization of inputs and gradients, which are operands
for the **wgrad GEMM**. This GEMM is particularly sensitive
to quantization errors, hence the additional outlier smoothing.
RHT is supported only for BF16 inputs/gradients.
The transform is defined as:
.. math::
x' = x H
where :math:`H` is the RHT matrix defined below. The quantization scale factor is computed
from the rotated tensor :math:`x'`.
**Hadamard matrix**
The :math:`d \times d` Hadamard matrix has elements :math:`\pm 1` and satisfies :math:`H_d H_d^T = d I`.
When normalized by :math:`1/\sqrt{d}`, the matrix becomes orthogonal and can be applied
to both operands of a matrix multiplication:
.. math::
C = (AH)(H^T B) = AB
where the transforms cancel within the dot-product since :math:`H H^T = I`.
**Sign matrix**
In the RHT implementation, a :math:`d`-dimensional diagonal sign matrix :math:`S_d` is applied
together with the Hadamard matrix:
.. math::
H = \frac{1}{\sqrt{d}} S_d H_d
where diagonal entries of :math:`S_d` are :math:`\{-1, 1\}` and flip the signs of different rows of :math:`H_d`.
As described in the paper, a single random sign vector is shared across all linear layers throughout training.
In the implementation, this vector is fixed and the RHT matrix is computed once at initialization and cached.
**Tiled implementation**
The Hadamard transform is performed in a tiled approach along the last dimension of the tensor.
For an :math:`m \times k` tensor, the data is reshaped to :math:`(mk/d) \times d`
and multiplied by the :math:`d \times d` matrix :math:`H`. In this implementation, :math:`d = 16`.
.. raw:: html
:file: img/rht.svg
*Figure 4. WGRAD GEMM pipeline comparison: without RHT (left) and with RHT applied (right).*
Handling transposes
-------------------
Like :doc:`MXFP8 <../mxfp8/mxfp8>`, NVFP4 requires both rowwise and columnwise quantized tensors
for different GEMM operands. Unlike MXFP8 which supports multiple layouts (TN, NT, NN),
**NVFP4 GEMM only supports the TN layout**.
NVFP4 stores columnwise data and scaling factors in a **transposed layout**:
- **Rowwise**: data ``[A, B]`` with 1×16 horizontal blocks, ``scales`` shape ``[A, B/16]``
- **Columnwise**: data ``[B, A]`` (transposed) with 1×16 horizontal blocks, ``scales`` shape ``[B, A/16]``
Scale tensors are padded for hardware alignment: first dimension to a multiple of 128,
second dimension to a multiple of 4 (e.g. rowwise: ``[roundup(A, 128), roundup(B/16, 4)]``).
.. raw:: html
:file: img/nvfp4_row_col.svg
*Figure 5. NVFP4 rowwise vs columnwise quantization layout. Unlike MXFP8, columnwise scales are stored transposed.*
Distributed training
--------------------
**Amax reduction**
Block scaling factors (``s_block``) do not require synchronization between nodes,
as each scaling factor is local to its block of 16 elements.
However, the global scaling factor (``s_global``) requires amax synchronization for gathered tensors.
For tensors that are gathered (e.g., input and gradient in sequence parallelism),
amax reduction is performed before quantization.
If before synchronization there was ``amax_1`` on node 1,
``amax_2`` on node 2, etc., after synchronization there will be ``max(amax_1, amax_2, ...)`` on all nodes.
**Quantized all-gather**
NVFP4 all-gather is supported.
.. raw:: html
:file: img/nvfp4_all_gather.svg
*Figure 6. Quantization and all-gather flow for NVFP4 showing amax synchronization and hierarchical scaling.*
Examples
--------
Here's how to use NVFP4 recipe in PyTorch and JAX. The examples show how to configure features like 2D weight quantization and Random Hadamard Transform (RHT):
.. tabs::
.. tab:: PyTorch
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Requires SM100 (Blackwell) or later
</div>
.. literalinclude:: pytorch_nvfp4_example.py
:language: python
:start-after: # START_NVFP4_EXAMPLE
:end-before: # END_NVFP4_EXAMPLE
.. tab:: JAX
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Requires SM100 (Blackwell) or later
</div>
.. literalinclude:: jax_nvfp4_example.py
:language: python
:start-after: # START_NVFP4_EXAMPLE
:end-before: # END_NVFP4_EXAMPLE
Supported devices
-----------------
* **Training**: SM 10.0, SM 10.3
* **Inference**: SM 10.0+
----
Developer Notes
---------------
This section contains implementation details that may be useful for developers
but are not required for using NVFP4 in practice.
Swizzling scaling factors
^^^^^^^^^^^^^^^^^^^^^^^^^
NVFP4 requires swizzling of block scaling factors (``s_block``) before GEMM operations,
similar to :doc:`MXFP8 <../mxfp8/mxfp8>`. Key differences:
- Block size is 16 (vs 32 for MXFP8)
- Both rowwise and columnwise scaling factors are swizzled, but thanks to the transposed
columnwise layout, a single rowwise swizzle kernel handles both cases.
- Scaling factors are stored as FP8 E4M3 (vs E8M0 for MXFP8)
All-gather of columnwise tensors
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
All-gather of columnwise tensors is supported. To enable quantized all-gather,
all nodes must use the same ``s_global``, which is computed from the synchronized global amax.
This is automatically enabled for column-parallel and row-parallel linear layers.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
# Check for Blackwell or newer GPU
major, minor = torch.cuda.get_device_capability()
assert major >= 10, f"NVFP4 requires SM100 (Blackwell) or later, got SM{major}{minor}"
# START_NVFP4_EXAMPLE
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import NVFP4BlockScaling
# Define NVFP4 recipe
# 2D weight quantization and RHT are enabled by default
recipe = NVFP4BlockScaling()
# To disable features, use:
# recipe = NVFP4BlockScaling(disable_rht=True, disable_2d_quantization=True)
# Create a linear layer with bfloat16 parameters
layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
# Forward and backward pass
inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda")
with te.autocast(enabled=True, recipe=recipe):
output = layer(inp)
loss = output.sum()
loss.backward()
# END_NVFP4_EXAMPLE
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+
# START_FUSED_LAYERS
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
from transformer_engine.jax.flax import LayerNorm, DenseGeneral, LayerNormDenseGeneral
from transformer_engine.common.recipe import DelayedScaling
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16)
# Example 1: Separate LayerNorm and DenseGeneral layers
layer_norm = LayerNorm()
dense = DenseGeneral(features=1024)
# Initialize parameters
ln_params = layer_norm.init(key, x)
dense_params = dense.init(key, x)
# Two separate operations
normalized = layer_norm.apply(ln_params, x)
output_separate = dense.apply(dense_params, normalized)
# Example 2: Fused LayerNormDenseGeneral layer
fused_layer = LayerNormDenseGeneral(features=1024)
# Initialize and apply with FP8 autocast
recipe = DelayedScaling()
with te.autocast(enabled=True, recipe=recipe):
fused_params = fused_layer.init(key, x)
output_fused, _ = fused_layer.apply(fused_params, x) # Returns (output, ln_output)
# The fused layer is more efficient as it combines LayerNorm and quantization
# END_FUSED_LAYERS
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment