Commit 9df0c4a3 authored by wenjh's avatar wenjh
Browse files

Merge branch 'nv_main'

parents 0d874a4e f122b07d
recursive-include transformer_engine/common/include *.* recursive-include transformer_engine/common/include *.*
recursive-include build_tools *.py *.txt
...@@ -137,7 +137,7 @@ Flax ...@@ -137,7 +137,7 @@ Flax
for _ in range(10): for _ in range(10):
loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp) loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp)
For a more comprehensive tutorial, check out our `Quickstart Notebook <https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/quickstart.ipynb>`_. For a more comprehensive tutorial, check out our `Getting Started Guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/getting_started.html>`_.
.. overview-end-marker-do-not-remove .. overview-end-marker-do-not-remove
...@@ -175,15 +175,22 @@ For example to use the NGC PyTorch container interactively, ...@@ -175,15 +175,22 @@ For example to use the NGC PyTorch container interactively,
.. code-block:: bash .. code-block:: bash
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:25.08-py3 docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:26.01-py3
For example to use the NGC JAX container interactively, For example to use the NGC JAX container interactively,
.. code-block:: bash .. code-block:: bash
docker run --gpus all -it --rm nvcr.io/nvidia/jax:25.08-py3 docker run --gpus all -it --rm nvcr.io/nvidia/jax:26.01-py3
Where 25.08 (corresponding to August 2025 release) is the container version. Where 26.01 (corresponding to January 2026 release) is the container version.
We recommend updating to the latest NGC container available here:
* https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch
* https://catalog.ngc.nvidia.com/orgs/nvidia/containers/jax
If you run any examples, please ensure you are using a matching version of TransformerEngine. TransformerEngine is pre-built and packaged inside the containers with examples available at ``/opt/transformerengine`` or ``/opt/transformer-engine``. If you would like to use examples from TE main branch and are running into import errors, please try the latest pip package or building from source, although NGC containers are recommended for ease-of-use for most users.
**Benefits of using NGC containers:** **Benefits of using NGC containers:**
...@@ -308,6 +315,37 @@ Troubleshooting ...@@ -308,6 +315,37 @@ Troubleshooting
cd transformer_engine cd transformer_engine
pip install -v -v -v --no-build-isolation . pip install -v -v -v --no-build-isolation .
**Problems using UV or Virtual Environments:**
1. **Import Error:**
* **Symptoms:** Cannot import ``transformer_engine``
* **Solution:** Ensure your UV environment is active and that you have used ``uv pip install --no-build-isolation <te_pypi_package_or_wheel_or_source_dir>`` instead of a regular pip install to your system environment.
2. **cuDNN Sublibrary Loading Failed:**
* **Symptoms:** Errors at runtime with ``CUDNN_STATUS_SUBLIBRARY_LOADING_FAILED``
* **Solution:** This can occur when TE is built against the container's system installation of cuDNN, but pip packages inside the virtual environment pull in pip packages for ``nvidia-cudnn-cu12/cu13``. To resolve this, when building TE from source please specify the following environment variables to point to the cuDNN in your virtual environment.
.. code-block:: bash
export CUDNN_PATH=$(pwd)/.venv/lib/python3.12/site-packages/nvidia/cudnn
export CUDNN_HOME=$CUDNN_PATH
export LD_LIBRARY_PATH=$CUDNN_PATH/lib:$LD_LIBRARY_PATH
3. **Building Wheels:**
* **Symptoms:** Regular TE installs work correctly but UV wheel builds fail at runtime.
* **Solution:** Ensure that ``uv build --wheel --no-build-isolation -v`` is used during the wheel build as well as the pip installation of the wheel. Use ``-v`` for verbose output to verify that TE is not pulling in a mismatching version of PyTorch or JAX that differs from the UV environment's version.
**JAX-specific Common Issues and Solutions:**
1. **FFI Issues:**
* **Symptoms:** ``No registered implementation for custom call to <some_te_ffi> for platform CUDA``
* **Solution:** Ensure ``--no-build-isolation`` is used during installation. If pre-building wheels, ensure that the wheel is both built and installed with ``--no-build-isolation``. See "Problems using UV or Virtual Environments" above if using UV.
.. troubleshooting-end-marker-do-not-remove .. troubleshooting-end-marker-do-not-remove
Breaking Changes Breaking Changes
......
...@@ -266,9 +266,10 @@ def nvcc_path() -> Tuple[str, str]: ...@@ -266,9 +266,10 @@ def nvcc_path() -> Tuple[str, str]:
def get_cuda_include_dirs() -> Tuple[str, str]: def get_cuda_include_dirs() -> Tuple[str, str]:
"""Returns the CUDA header directory.""" """Returns the CUDA header directory."""
force_wheels = bool(int(os.getenv("NVTE_BUILD_USE_NVIDIA_WHEELS", "0")))
# If cuda is installed via toolkit, all necessary headers # If cuda is installed via toolkit, all necessary headers
# are bundled inside the top level cuda directory. # are bundled inside the top level cuda directory.
if cuda_toolkit_include_path() is not None: if not force_wheels and cuda_toolkit_include_path() is not None:
return [cuda_toolkit_include_path()] return [cuda_toolkit_include_path()]
# Use pip wheels to include all headers. # Use pip wheels to include all headers.
...@@ -277,7 +278,10 @@ def get_cuda_include_dirs() -> Tuple[str, str]: ...@@ -277,7 +278,10 @@ def get_cuda_include_dirs() -> Tuple[str, str]:
except ModuleNotFoundError as e: except ModuleNotFoundError as e:
raise RuntimeError("CUDA not found.") raise RuntimeError("CUDA not found.")
cuda_root = Path(nvidia.__file__).parent if nvidia.__file__ is not None:
cuda_root = Path(nvidia.__file__).parent
else:
cuda_root = Path(nvidia.__path__[0]) # namespace
return [ return [
subdir / "include" subdir / "include"
for subdir in cuda_root.iterdir() for subdir in cuda_root.iterdir()
......
/* Diagram color definitions for Transformer Engine documentation */
/* High precision (BF16/FP16) elements */
.hp {
fill: #ede7f6;
stroke: #673ab7;
stroke-width: 2;
}
/* FP8 precision elements */
.fp8 {
fill: #fff8e1;
stroke: #ffa726;
stroke-width: 2;
}
/* GEMM/computation operations */
.gemm {
fill: #ffe0b2;
stroke: #fb8c00;
stroke-width: 2.5;
}
/* Quantization operations */
.quantize {
fill: #e8f5e9;
stroke: #66bb6a;
stroke-width: 2;
}
/* Amax computation operations */
.amax {
fill: #e1f5fe;
stroke: #039be5;
stroke-width: 2;
}
/* Text styles */
.text {
font-family: 'Segoe UI', Arial, sans-serif;
font-size: 14px;
text-anchor: middle;
fill: #212121;
}
.small-text {
font-family: 'Segoe UI', Arial, sans-serif;
font-size: 14px;
text-anchor: middle;
fill: #757575;
}
.label {
font-family: 'Segoe UI', Arial, sans-serif;
font-size: 14px;
text-anchor: middle;
fill: #424242;
}
.title {
font-family: 'Segoe UI', Arial, sans-serif;
font-size: 18px;
font-weight: 600;
text-anchor: middle;
fill: #212121;
}
.section-title {
font-family: 'Segoe UI', Arial, sans-serif;
font-size: 15px;
font-weight: 600;
text-anchor: middle;
}
/* Arrows */
/* Note: marker-end references #arrowhead marker which must be defined in each SVG's <defs> section */
.arrow {
stroke: #616161;
stroke-width: 2;
fill: none;
marker-end: url(#arrowhead);
}
/* Additional box and element styles */
.box-blue {
fill: #e3f2fd;
stroke: #1976d2;
stroke-width: 2;
}
.box-orange {
fill: #fff3e0;
stroke: #f57c00;
stroke-width: 2;
}
.box-green {
fill: #c8e6c9;
stroke: #388e3c;
stroke-width: 2;
}
.box-dashed {
stroke-dasharray: 5,5;
}
/* LayerNorm specific */
.layernorm {
fill: #b3e5fc;
stroke: #0277bd;
stroke-width: 2.5;
}
/* Fused layers */
.fused {
fill: #b2dfdb;
stroke: #00695c;
stroke-width: 3;
}
/* Generic computation blocks */
.computation {
fill: #f5f5f5;
stroke: #757575;
stroke-width: 2;
}
/* FP32 precision (alternative red) */
.fp32 {
fill: #ffcdd2;
stroke: #d32f2f;
stroke-width: 2.5;
}
/* Custom styling for sphinx-tabs */
.sphinx-tabs {
margin-bottom: 1rem;
}
.sphinx-tabs-tab {
background-color: #f4f4f4;
border: 1px solid #ccc;
border-bottom: none;
padding: 0.5rem 1rem;
margin-right: 0.5rem;
cursor: pointer;
font-weight: 500;
transition: background-color 0.2s;
}
.sphinx-tabs-tab:hover {
background-color: #e0e0e0;
}
.sphinx-tabs-tab[aria-selected="true"] {
background-color: #76b900; /* NVIDIA green */
color: white;
border-color: #76b900;
margin-right: 0.5rem;
}
.sphinx-tabs-panel {
border: 1px solid #ccc;
padding: 1rem;
background-color: #f9f9f9;
}
/* Dark mode support for RTD theme */
.rst-content .sphinx-tabs-tab {
color: #333;
}
.rst-content .sphinx-tabs-tab[aria-selected="true"] {
color: white;
}
/* Responsive styling for SVG images */
/* Make all SVG images responsive */
.document svg,
.document object[type="image/svg+xml"],
.rst-content svg {
max-width: 100%;
height: auto;
display: block;
margin: 1em auto;
}
/* For raw HTML embedded SVGs */
.document .raw-html svg {
max-width: 100%;
height: auto;
width: 100%;
}
/* Ensure container doesn't overflow */
.document .raw-html {
max-width: 100%;
overflow-x: auto;
}
/* Figure containers with captions */
.svg-figure {
text-align: center;
margin: 20px auto;
}
.svg-figure img {
display: block;
margin: 0 auto;
height: auto;
}
/* Different width classes for figures */
.svg-figure.width-70 img {
width: 70%;
max-width: 100%;
}
.svg-figure.width-80 img {
width: 80%;
max-width: 100%;
}
.svg-figure.width-90 img {
width: 90%;
max-width: 100%;
}
.svg-figure.width-100 img {
width: 100%;
}
/* Figure captions */
.svg-caption {
font-style: italic;
margin-top: 10px;
color: #555;
font-size: 0.95em;
line-height: 1.4;
}
...@@ -67,6 +67,10 @@ ...@@ -67,6 +67,10 @@
overflow: visible !important; overflow: visible !important;
} }
.quant {
background-color: yellow !important;
}
</style> </style>
<style> <style>
a:link, a:visited { a:link, a:visited {
......
...@@ -84,8 +84,11 @@ html_show_sphinx = False ...@@ -84,8 +84,11 @@ html_show_sphinx = False
html_css_files = [ html_css_files = [
"css/nvidia_font.css", "css/nvidia_font.css",
"css/nvidia_footer.css", "css/nvidia_footer.css",
"css/rtabs.css",
"css/output-style.css", "css/output-style.css",
"css/diagram-colors.css",
"css/sphinx_tabs.css",
"css/svg-responsive.css",
"css/rtabs.css",
] ]
html_theme_options = { html_theme_options = {
......
...@@ -15,7 +15,7 @@ Transformer Engine provides a set of precision debug tools which allow you to ea ...@@ -15,7 +15,7 @@ Transformer Engine provides a set of precision debug tools which allow you to ea
- log the statistics for each of the tensors in every matrix multiply (GEMM) operation, - log the statistics for each of the tensors in every matrix multiply (GEMM) operation,
- run selected GEMMs in higher precision, - run selected GEMMs in higher precision,
- run current scaling - with one scaling factor per tensor - for particular GEMMs, - run current scaling - with one scaling factor per tensor - for particular GEMMs,
- test new precisions and integrate them with FP8 training, - test new precisions and integrate them with quantized training (FP8, NVFP4, etc.),
- ... and many more. - ... and many more.
There are 4 things one needs to do to use Transformer Engine debug features: There are 4 things one needs to do to use Transformer Engine debug features:
......
...@@ -8,7 +8,10 @@ Debug features ...@@ -8,7 +8,10 @@ Debug features
.. autoapiclass:: transformer_engine.debug.features.log_tensor_stats.LogTensorStats .. autoapiclass:: transformer_engine.debug.features.log_tensor_stats.LogTensorStats
.. autoapiclass:: transformer_engine.debug.features.log_fp8_tensor_stats.LogFp8TensorStats .. autoapiclass:: transformer_engine.debug.features.log_fp8_tensor_stats.LogFp8TensorStats
.. autoapiclass:: transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM .. autoapiclass:: transformer_engine.debug.features.log_nvfp4_tensor_stats.LogNvfp4TensorStats
.. autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer .. autoapiclass:: transformer_engine.debug.features.disable_quantization_gemm.DisableQuantizationGEMM
.. autoapiclass:: transformer_engine.debug.features.disable_quantization_layer.DisableQuantizationLayer
.. autoapiclass:: transformer_engine.debug.features.per_tensor_scaling.PerTensorScaling .. autoapiclass:: transformer_engine.debug.features.per_tensor_scaling.PerTensorScaling
.. autoapiclass:: transformer_engine.debug.features.fake_quant.FakeQuant .. autoapiclass:: transformer_engine.debug.features.fake_quant.FakeQuant
.. autoapiclass:: transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM
.. autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer
\ No newline at end of file
transformers==4.57.0
accelerate==1.10.0
peft==0.15.2
datasets==4.0.0
sentencepiece==0.2.1
...@@ -72,10 +72,15 @@ class TELlamaDecoderLayer(te.pytorch.TransformerLayer): ...@@ -72,10 +72,15 @@ class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
forward pass of the `TransformerLayer`. Also, make sure the output forward pass of the `TransformerLayer`. Also, make sure the output
format matches the output of the HF's `LlamaDecoderLayer`. format matches the output of the HF's `LlamaDecoderLayer`.
""" """
return ( # Handle case where hidden_states might be a tuple (from previous layer output)
super().forward( # This can happen with older versions of HuggingFace transformers
hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb if isinstance(hidden_states, tuple):
), hidden_states = hidden_states[0]
# Return tensor directly for HuggingFace transformers >= 4.57
# (older versions wrapped output in tuple and extracted with layer_outputs[0])
return super().forward(
hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb
) )
...@@ -162,7 +167,7 @@ def replace_params(hf_state_dict, te_state_dict, config): ...@@ -162,7 +167,7 @@ def replace_params(hf_state_dict, te_state_dict, config):
# collect all layer prefixes to update # collect all layer prefixes to update
all_layer_prefixes = set() all_layer_prefixes = set()
for param_key in hf_state_dict.keys(): for param_key in hf_state_dict.keys():
layer_prefix_pat = "model.layers.\d+." layer_prefix_pat = r"model.layers.\d+."
m = re.match(layer_prefix_pat, param_key) m = re.match(layer_prefix_pat, param_key)
if m is not None: if m is not None:
all_layer_prefixes.add(m.group()) all_layer_prefixes.add(m.group())
......
{ {
"cells": [ "cells": [
{
"cell_type": "markdown",
"id": "6a5b2993",
"metadata": {},
"source": [
"# Accelerating Hugging Face Llama 2 and 3 Fine-Tuning with Transformer Engine\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Goal</b>\n",
"\n",
"This tutorial showcases how to accelerate finetuning a full [Llama 2](https://huggingface.co/meta-llama/Llama-2-7b-hf) or [Llama 3](https://huggingface.co/meta-llama/Meta-Llama-3-8B) models from Hugging Face by using `TransformerLayer` from the [Transformer Engine library](https://github.com/NVIDIA/TransformerEngine) in `BF16` and `FP8` precisions.\n",
"\n",
"</div>\n"
]
},
{
"cell_type": "markdown",
"id": "331f476a",
"metadata": {},
"source": [
"## Dependencies for this tutorial\n",
"\n",
"Following files and media are necessary to effectively run this tutorial:\n",
"\n",
"1. `te_llama.py`\n",
" - This file contains the code to load a Hugging Face Llama 2 or Llama 3 checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `LlamaDecoderLayer`. This is used in the following two sections of the tutorial - \"Improvement 1\" and \"Improvement 2\".\n",
"2. `utils.py`\n",
" - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training and other miscellaneous tasks like restarting the jupyter notebook from within the cell. \n",
"3. `media/`\n",
" - This directory contains the images used in the following tutorial.\n",
"\n",
"These packages are necessary to run this tutorial:\n",
"`pytorch`, `transformer_engine`, `accelerate`, `transformers`, `peft`, `datasets`.\n",
"\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Note on running the tutorial with Llama 3 weights</b>\n",
"\n",
"This tutorial shows the cell outputs when run with Llama 2 7B weights. It can be run with Llama 3 8B weights simply by providing the directory with those weights (in Hugging Face format) instead of Llama 2 7B weights. These two models are almost identical, the biggest difference being the model dimension (the smallest Llama 3 model has 8B parameters, whereas the smallest Llama 2 has 7B), which enables this tutorial to work for both of them.\n",
"\n",
"</div>\n"
]
},
{
"cell_type": "markdown",
"id": "44abae4f",
"metadata": {},
"source": [
"## Table of contents\n",
"1. From \"Transformer\" to \"Llama\"\n",
"2. Hugging Face's `LlamaModel`\n",
" - Hugging Face's `LlamaDecoderLayer`\n",
"3. [Baseline] Running HF `LlamaModel` (Precision: `BF16`)\n",
"6. [Improvement 1] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)\n",
" - Transformer Engine's `TransformerLayer`\n",
" - `TransformerLayer` options explained\n",
" - Mapping weights from HF's `LlamaDecoderLayer` to TE's `TransformerLayer`\n",
"7. [Improvement 2] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)\n",
"8. Conclusion"
]
},
{
"cell_type": "markdown",
"id": "e37e2cc1",
"metadata": {},
"source": [
"## From \"Transformer\" to \"Llama\" \n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"media/transformer_llama.png\">\n",
" <figcaption> Fig 1: Llama visualized as a transformer. (generated with [Nvidia's AI-foundation models](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/ai-foundation/models/sdxl))</figcaption>\n",
"</figure>\n",
"\n",
"A flashback:\n",
"\n",
"- 2017: [\"Attention Is All You Need\"](https://arxiv.org/abs/1706.03762) paper introduced pioneering \"Transformer\" architecture and changed the NLP field forever.\n",
"- 2018-2020: Emergence of GPT model series that showed causal decoder architectures are great fit for pretraining, few-shot and zero-shot learning.\n",
"- Fast forward to 2023-2024: Following GPT-3/GPT-4 success stories, researchers and companies raced to produce the next best pretrained model that could further be finetuned for application-specific use-cases.\n",
"- February 2023: Meta releases [Llama 2](https://llama.meta.com/llama2) models (Large Language Model Meta AI). \n",
" - These models range from 7B to 70B parameters.\n",
" - LLaMA 2 was pretrained on 2 trillion tokens.\n",
"- April 2024: Meta releases [Llama 3](https://llama.meta.com/llama3) models.\n",
" - These models range from 8B to 70B parameters.\n",
" - LLaMA 3 was pretrained on 15 trillion tokens.\n",
"\n",
"For more information on Llama 2 consider reading the [Huggingface tutorial](https://huggingface.co/blog/llama2). As a quick summary, here are some of the important differences b/w the conventional transformer decoder architecture vs Llama 2 architecture:\n",
"\n",
"1. Decoder only model (causal language modeling and next word prediction)\n",
"2. RMSNorm in place of the LayerNorm\n",
"3. SwiGLU activation function\n",
"4. RoPE as positional embeddings \n",
"5. Grouped Query Attention for the 70B model\n",
"6. Trained on 4K context length\n",
"\n",
"Hugging Face also released a [tutorial about Llama 3](https://huggingface.co/blog/llama3). The key points are:\n",
"\n",
"1. Use of bigger tokenizer - 128256 vs 32K.\n",
"2. Grouped Query Attention is used also by smaller 8B model.\n",
"3. The context length increased to 8K for all models.\n",
"3. Llama 3 was trained on 8x more data than Llama 2.\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"media/transformer_vs_llama.svg\">\n",
" <figcaption> Fig 2: Comparing GPT and Llama architectures. </figcaption>\n",
"</figure>"
]
},
{
"cell_type": "markdown",
"id": "a110de1a",
"metadata": {},
"source": [
"## Hugging Face's `LlamaModel`\n",
"Hugging Face provides an open-source implementation of `Llama` model in [modeling_llama.py](https://github.com/huggingface/transformers/blob/3d2900e829ab16757632f9dde891f1947cfc4be0/src/transformers/models/llama/modeling_llama.py#L4).\n",
"\n",
"Here's a block diagram that shows how Llama model is implemented in the Hugging Face repo. Notice the modular encapsulated form and `LlamaDecoderLayer` at the core of the model implementation.\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"media/llama_for_causal_lm.svg\">\n",
" <figcaption> Fig 3: Causal Llama Model Block Diagram. </figcaption>\n",
"</figure>\n",
"\n",
"The above diagram translates to the following text output of the model in PyTorch. Notice that the core of the model has 32 `LlamaDecoderLayer`s. \n",
"\n",
"```\n",
"LlamaForCausalLM(\n",
" (model): LlamaModel(\n",
" (embed_tokens): Embedding(32000, 4096, padding_idx=0)\n",
" (layers): ModuleList(\n",
" (0-31): 32 x LlamaDecoderLayer(\n",
" (self_attn): LlamaFlashAttention2(\n",
" (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (v_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (rotary_emb): LlamaRotaryEmbedding()\n",
" )\n",
" (mlp): LlamaMLP(\n",
" (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
" (up_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
" (down_proj): Linear(in_features=11008, out_features=4096, bias=False)\n",
" (act_fn): SiLU()\n",
" )\n",
" (input_layernorm): LlamaRMSNorm()\n",
" (post_attention_layernorm): LlamaRMSNorm()\n",
" )\n",
" )\n",
" (norm): LlamaRMSNorm()\n",
" )\n",
" (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
")\n",
"```\n",
"\n",
"#### Hugging Face's `LlamaDecoderLayer`\n",
"\n",
"Let's take a closer look at `LlamaDecoderLayer`. It is composed of `input_layernorm`, `self_attn`, `post_attention_layernorm` and `mlp` modules. Each module has associated weights as shown in the diagram.\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"media/llama_zoom.svg\">\n",
" <figcaption> Fig 4: Causal Llama Model Block Diagram (with simplified illustration of the [LlamaDecoderLayer](https://github.com/huggingface/transformers/blob/e770f0316d2a9b787c9d1440f204fcb65e176682/src/transformers/models/llama/modeling_llama.py#L695)). </figcaption>\n",
"</figure>\n",
"\n",
"##### Self_Attn Layer\n",
"For simplicity in the block diagram illustration of the \"self_attn\" box, we omit the \"Grouped Query Attention\" operation and only showcase the modules which have associated weights.\n",
" \n",
"##### MLP Layer\n",
"\n",
"SwiGLU is an activation defined as follows in the [modeling_llama.py](https://github.com/huggingface/transformers/blob/7c4995f93d8d24aae05e1e43279c96dce736e5c8/src/transformers/models/llama/modeling_llama.py#L236) file in the Hugging Face github repo:\n",
"```\n",
"\"\"\"\n",
"1. `self.up_proj`, `self.gate_proj` and `self.down_proj` are \"Linear\" layers\n",
"2. `self.act_fn` is a \"Swish\" function\n",
"\n",
"\"\"\"\n",
"down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n",
"```\n",
"It requires a set of 3 weights as compared to 2 weights in conventional \"MLP\" layers e.g. in the traditional transformer or GPT architectures. This is also illustrated in the following figure:\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"media/swiglu.svg\">\n",
" <figcaption> Fig 5: A look inside the feedforward layer with <code>swiglu</code> activation function. </figcaption>\n",
"</figure>"
]
},
{
"cell_type": "markdown",
"id": "c9529229",
"metadata": {},
"source": [
"## [Baseline] Running HF `LlamaModel` (Precision: `BF16`)\n",
"\n",
"Llama 2 weights are loaded into the Hugging Face native implementation `LlamaForCausalLM` (refer to [modeling_llama.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py)). \n",
"\n",
"For this and other subsequent runs, the `batch_size` is `8`. The `LlamaDecoderLayer` is left unchanged in the baseline as follows:\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"media/llamadecoderlayer.svg\">\n",
" <figcaption> Fig 6: Revisiting \"LlamaDecoderLayer\". </figcaption>\n",
"</figure>\n",
"\n",
"<div class=\"alert alert-info\">\n",
"<b>Note</b>\n",
"\n",
"The baseline implementation will be run in `BF16` precision.\n",
"\n",
"</div>"
]
},
{
"cell_type": "markdown",
"id": "b38eb3ac",
"metadata": {},
"source": [
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Note</b>\n",
" \n",
"This tutorial loads and trains a Llama 3 8B or a Llama 2 7B model which takes up most of the GPU memory and therefore, we need to restart the jupyter notebook each time before running the following sections. A small utility method `restart_jupyter_notebook` is defined in the accompanying `utils.py` file. This function restarts the jupyter notebook so that the GPU memory is flushed before the model is loaded again from the checkpoint in order to avoid running into OOM (Out Of Memory) errors.\n",
"\n",
"If the utility doesn't work, comment this line `restart_jupyter_notebook()` in the following cell and manually restart the jupyter notebook before running the cell. Repeat the same for other sections in this tutorial.\n",
"\n",
"</div>\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2e9d7a8c",
"metadata": {},
"outputs": [
{ {
"name": "stdout", "cell_type": "markdown",
"output_type": "stream", "metadata": {},
"text": [ "source": [
"10 finetuning steps complete!\n", "# Accelerating Hugging Face Llama 2 and 3 Fine-Tuning with Transformer Engine\n",
"Average time taken per step: 248 milliseconds\n" "\n",
] "<div class=\"alert alert-info\">\n",
} "\n",
], "<b>Goal</b>\n",
"source": [ "\n",
"# Restart the notebook (to flush the GPU memory)\n", "This tutorial showcases how to accelerate finetuning a full [Llama 2](https://huggingface.co/meta-llama/Llama-2-7b-hf) or [Llama 3](https://huggingface.co/meta-llama/Meta-Llama-3-8B) models from Hugging Face by using `TransformerLayer` from the [Transformer Engine library](https://github.com/NVIDIA/TransformerEngine) in `BF16` and `FP8` precisions.\n",
"from utils import restart_jupyter_notebook\n", "\n",
"restart_jupyter_notebook()\n", "</div>\n"
"\n", ],
"\n", "id": "6a5b2993"
"# Import necessary packages, methods and variables\n", },
"from utils import *\n",
"\n",
"\n",
"# Provide Huggingface Access Token\n",
"hyperparams.hf_access_token = \"\"\n",
"assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
"\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"hyperparams.weights_cache_dir = \"\"\n",
"\n",
"# For Llama 2, uncomment this line (also set by default)\n",
"hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
"\n",
"# For Llama 3, uncomment this line\n",
"# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
"\n",
"hyperparams.mixed_precision = \"bf16\"\n",
"\n",
"\n",
"# Init the model and accelerator wrapper\n",
"model = init_baseline_model(hyperparams)\n",
"accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n",
"\n",
"\n",
"# Finetune the model\n",
"finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)"
]
},
{
"cell_type": "markdown",
"id": "4035ccb7",
"metadata": {},
"source": [
"Let's add this information in a table and keep comparing it with a few possible improvements in future sections:\n",
"\n",
"| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n",
"|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n",
"| HF (baseline) | BF16 | 248 | 1 |"
]
},
{
"cell_type": "markdown",
"id": "3db90dff",
"metadata": {},
"source": [
"## [Improvement 1] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)\n",
"\n",
"In addition to basic layers like `Linear` and `LayerNorm`, Transformer Engine offers larger modules like `MultiheadAttention` (combines \"LayerNorm\" and \"Self Attention\") and `LayerNormMLP` (combines \"LayerNorm\" and \"MLP\") that could replace their counterparts in the `LlamaDecoderLayer` and potentially provide a speedup. Transformer Engine also offers a full `TransformerLayer` (which further combines `MultiheadAttention` and `LayerNormMLP` layers) which could replace `LlamaDecoderLayer` and provide a speedup (with careful mapping of the weights since the name of the weights are different for those two layers). Let's take a closer look at Transformer Engine's `TransformerLayer`. \n",
"\n",
"#### Transformer Engine's `TransformerLayer`\n",
"\n",
"At a higher level, TE's `TransformerLayer` could be visualized as an apt replacement for the `LlamaDecoderLayer`. But the internals of the `TransformerLayer` are organized a bit differently. \n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"media/tellamadecoderlayer.svg\">\n",
" <figcaption> Fig 7: Transformer Engine's `TransformerLayer` </figcaption>\n",
"</figure>\n",
"\n",
"Just like Hugging Face's `LlamaDecoderLayer`, Transformer Engine's `TransformerLayer` encapsulates `self_attention` (as `MultiheadAttention`) and `mlp` (as `LayerNormMLP`). A major difference is that the two `Norm`s are included in the `MultiheadAttention` and `LayerNormMLP` layers as shown in the following output prompt:\n",
"\n",
"```\n",
"TransformerLayer(\n",
" (self_attention): MultiheadAttention(\n",
" (layernorm_qkv): LayerNormLinear()\n",
" (core_attention): DotProductAttention()\n",
" (proj): Linear()\n",
" )\n",
" (layernorm_mlp): LayerNormMLP()\n",
")\n",
"```\n",
"\n",
"Another difference is that Transformer Engine implements an efficient version of feedforward layer with SwiGLU in which the weights from the `up_proj` and `gate_proj` modules are merged together and SwiGLU is applied using a custom fused kernel. This is done so that only one big and efficient Matrix Multiplication operation is issued to the GPU instead of two smaller ones.\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"media/swiglu_te.svg\">\n",
" <figcaption> Fig 8: Abstract illustration of the SwiGLU implementation in Transformer Engine. </figcaption>\n",
"</figure>\n",
"\n",
"#### `TransformerLayer` options explained\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Note</b>\n",
" \n",
"Here, we go over some of the options in `TransformerLayer` that are needed for the tutorial. For a complete list of options, refer the [TransformerLayer API documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html?highlight=transformerlayer#transformer_engine.pytorch.TransformerLayer).\n",
"\n",
"</div>\n",
"\n",
"In the accompanying `te_llama.py` file, `TELlamaDecoderLayer` is defined as a wrapper over TE's `TransformerLayer` with a few needed options that make `TransformerLayer` a plug-in replacement for the HF's `LlamaDecoderLayer`.\n",
"\n",
"```\n",
"class TELlamaDecoderLayer(te.pytorch.TransformerLayer):\n",
" def __init__(self, config):\n",
" super().__init__(\n",
" config.hidden_size,\n",
" config.intermediate_size,\n",
" config.num_attention_heads,\n",
" bias=False,\n",
" layernorm_epsilon=config.rms_norm_eps,\n",
" hidden_dropout=0,\n",
" attention_dropout=0,\n",
" fuse_qkv_params=False,\n",
" normalization=\"RMSNorm\",\n",
" activation=\"swiglu\",\n",
" attn_input_format=\"bshd\",\n",
" num_gqa_groups=config.num_key_value_heads,\n",
" )\n",
" te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads)\n",
" self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()\n",
"```\n",
"\n",
"Here's a list summarizing each option briefly:\n",
"\n",
"1. `hidden_size`: size of each input sample.\n",
"2. `ffn_hidden_size`: intermediate size to which samples are projected.\n",
"3. `num_attention_heads`: number of attention heads in the transformer layer.\n",
"4. `bias`: switch to add additive biases to the submodule layers.\n",
"5. `layernorm_epsilon`: a value added to the denominator of layer normalization for numerical stability. Default is `1e-5`.\n",
"6. `hidden_dropout`: dropout probability for the dropout op after FC2 layer (fully connected layer no. 2). Default is `0.1`.\n",
"7. `attention_dropout`: dropout probability for the dropout op during multi-head attention. Default is `0.1`. \n",
"8. `fuse_qkv_params`: if set to True, TransformerLayer module exposes a single fused parameter for query-key-value. This enables optimizations such as QKV fusion without concatentations/splits and also enables the argument fuse_wgrad_accumulation.\n",
"9. `normalization`: type of normalization applied. Default is `LayerNorm`.\n",
"10. `activation`: type of activation used in the MLP block. Default is `gelu`.\n",
"11. `attn_input_format`: controls whether the dimensions of the intermediate hidden states is 'batch first' ('bshd') or 'sequence first' ('sbhd'). `s` stands for the sequence length, `b` batch size, `h` the number of heads, `d` head size. Note that these formats are very closely related to the `qkv_format` in the `MultiHeadAttention` and `DotProductAttention` modules.\n",
"12. `num_gqa_groups`: number of GQA groups in the transformer layer. Grouped Query Attention is described in [this paper](https://arxiv.org/pdf/2305.13245.pdf). This only affects the keys and values, not the querys. GQA-1 is equivalent to Multi-Query Attention ([MQA](https://arxiv.org/pdf/1911.02150.pdf)), while GQA-H is equivalent to MultiHead Attention, i.e. `num_gqa_groups = num_attention_heads`.\n",
"\n",
"\n",
"Further, note that `RotaryPositionEmbedding` is defined as part of the `TELlamaDecoderLayer` (wrapper around TE's `TransformerLayer`) itself since it expects this rope cache if RoPE is used in the model. \n",
"\n",
"Let's revisit how `LlamaDecoderLayer`s form the core of the decoder layer stack in HF's llama implementation:\n",
"```\n",
"ModuleList(\n",
" (0-31): 32 x LlamaDecoderLayer(\n",
" (self_attn): LlamaAttention(\n",
" (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (v_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (rotary_emb): LlamaRotaryEmbedding()\n",
" )\n",
" (mlp): LlamaMLP(\n",
" (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
" (up_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
" (down_proj): Linear(in_features=11008, out_features=4096, bias=False)\n",
" (act_fn): SiLU()\n",
" )\n",
" (input_layernorm): LlamaRMSNorm()\n",
" (post_attention_layernorm): LlamaRMSNorm()\n",
" )\n",
")\n",
"```\n",
"\n",
"A major portion of the Hugging Face model implementation (32 `LlamaDecoderLayer` layers) could be potentially replaced with Transformer Engine's `TransformerLayer` layers. Let's see how it is made possible.\n",
"\n",
"\n",
"#### Mapping weights from HF's `LlamaDecoderLayer` to TE's `TransformerLayer`\n",
"\n",
"Refer the accompanying file `te_llama.py` which provides a reference to create a Llama 2 model with TE's `TransformerLayer` after replacing HF's `LlamaDecoderLayer`.\n",
"\n",
"Briefly, following pieces of code are put together:\n",
"\n",
"1. `TELlamaDecoderLayer` is added as a wrapper for `TransformerLayer`. \n",
"```\n",
"class TELlamaDecoderLayer(te.pytorch.TransformerLayer):\n",
" \"\"\"\n",
" Wrapper class over TE's `TransformerLayer`. This makes the wrapper very\n",
" similar to HF's `LlamaDecoderLayer` and easier to replace it in the code.\n",
"\n",
" Args:\n",
" config: LlamaConfig\n",
" args: positional args (for compatibility with `LlamaDecoderLayer`)\n",
" kwargs: keyword args (for compatibility with `LlamaDecoderLayer`)\n",
" \"\"\"\n",
" def __init__(self, config, *args, **kwargs):\n",
" super().__init__(\n",
" hidden_size=config.hidden_size,\n",
" ffn_hidden_size=config.intermediate_size,\n",
" num_attention_heads=config.num_attention_heads,\n",
" bias=False,\n",
" layernorm_epsilon=config.rms_norm_eps,\n",
" hidden_dropout=0,\n",
" attention_dropout=0,\n",
" fuse_qkv_params=False,\n",
" normalization=\"RMSNorm\",\n",
" activation=\"swiglu\",\n",
" attn_input_format=\"bshd\",\n",
" )\n",
" te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads)\n",
" self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()\n",
"\n",
" def forward(self,\n",
" hidden_states,\n",
" *args,\n",
" attention_mask,\n",
" **kwargs):\n",
" \"\"\"\n",
" Custom forward to make sure we only pass relevant arguments to the\n",
" forward pass of the `TransformerLayer`. Also, make sure the output\n",
" format matches the output of the HF's `LlamaDecoderLayer`.\n",
" \"\"\"\n",
" return (super().forward(hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb),)\n",
"```\n",
"\n",
"2. Before creating a `LlamaForCausalLM`, `replace_decoder` context manager is used to monkey-patch `LlamaDecoderLayer` with `TELlamaDecoderLayer`.\n",
"\n",
"```\n",
"@contextmanager\n",
"def replace_decoder(te_decoder_cls):\n",
" \"\"\"\n",
" Replace `LlamaDecoderLayer` with custom `TELlamaDecoderLayer`.\n",
" \"\"\"\n",
" original_llama_decoder_cls = transformers.models.llama.modeling_llama.LlamaDecoderLayer\n",
" transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls\n",
" try:\n",
" yield\n",
" finally:\n",
" transformers.models.llama.modeling_llama.LlamaDecoderLayer = original_llama_decoder_cls\n",
".\n",
".\n",
".\n",
"class TELlamaForCausalLM:\n",
" \"\"\"\n",
" Causal LM created with `LlamaModel`. The underlying `LlamaDecoderLayer`\n",
" class is monkey-patched with `TELlamaDecoderLayer` class before\n",
" initializing the causal LM with `LlamaForCausalLM`.\n",
"\n",
" Args:\n",
" config: LlamaConfig\n",
" \"\"\"\n",
"\n",
" def __new__(cls, config: LlamaConfig):\n",
" with replace_decoder(te_decoder_cls=TELlamaDecoderLayer):\n",
" llama_for_causal_lm = LlamaForCausalLM(config)\n",
" return llama_for_causal_lm\n",
".\n",
".\n",
".\n",
"```\n",
"\n",
"3. A custom `pretrained_from_local` method is added that copies the weights from the checkpoint (which is meant for HF Llama implementation) to the modified `TELlamaForCausalLM` by carefully mapping the weights from the `LlamaDecoderLayer` (HF) to `TransformerLayer` (TE). The method `replace_params` maps and copies apt weights from `LlamaDecoderLayer` to the `TransformerLayer`. Refer to the following diagram for more details.\n",
"\n",
"```\n",
"def replace_params(hf_state_dict, te_state_dict):\n",
" # collect all layer prefixes to update\n",
" all_layer_prefixes = set()\n",
" for param_key in hf_state_dict.keys():\n",
" layer_prefix_pat = 'model.layers.\\d+.'\n",
" m = re.match(layer_prefix_pat, param_key)\n",
" if m is not None:\n",
" all_layer_prefixes.add(m.group())\n",
"\n",
" for layer_prefix in all_layer_prefixes:\n",
" # When loading weights into models with less number of layers, skip the\n",
" # copy if the corresponding layer doesn't exist in TE model\n",
" if layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight' in te_state_dict:\n",
" te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'input_layernorm.weight'].data[:]\n",
"\n",
" if layer_prefix + 'self_attention.layernorm_qkv.query_weight' in te_state_dict:\n",
" te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.query_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.q_proj.weight'].data[:]\n",
"\n",
" if layer_prefix + 'self_attention.layernorm_qkv.key_weight' in te_state_dict:\n",
" te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.k_proj.weight'].data[:]\n",
" .\n",
" .\n",
" .\n",
"\n",
" return all_layer_prefixes\n",
"```\n",
"\n",
"The following figure shows how the weights get mapped from the HF's `LlamaDecoderLayer` to TE's `TransformerLayer`.\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"media/weight_swap.svg\">\n",
" <figcaption> Fig 9: Replace `LlamaDecoderLayer` with `TransformerLayer`. </figcaption>\n",
"</figure>\n",
"\n",
"After initializing the modified Llama model this way, the core decoder layers get changed to `TELlamaDecoderLayer` (wrapper around `TransformerLayer`) as shown in the following output:\n",
"```\n",
"ModuleList(\n",
" (0-31): 32 x TELlamaDecoderLayer(\n",
" (self_attention): MultiheadAttention(\n",
" (layernorm_qkv): LayerNormLinear()\n",
" (core_attention): DotProductAttention(\n",
" (flash_attention): FlashAttention()\n",
" (fused_attention): FusedAttention()\n",
" (unfused_attention): UnfusedDotProductAttention(\n",
" (scale_mask_softmax): FusedScaleMaskSoftmax()\n",
" (attention_dropout): Dropout(p=0, inplace=False)\n",
" )\n",
" )\n",
" (proj): Linear()\n",
" )\n",
" (layernorm_mlp): LayerNormMLP()\n",
" )\n",
")\n",
"```\n",
"\n",
"In summary, the model gets changed as follows with a large chunk of the implementation (core decoder layers) coming from Transformer Engine.\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"media/model_change.svg\">\n",
" <figcaption> Fig 10: Language model after the HF's `LlamaDecoderLayer`s are replaced with TE's `TransformerLayer`s. </figcaption>\n",
"</figure>\n",
"\n",
"\n",
"<div class=\"alert alert-info\">\n",
"<b>Note</b>\n",
"\n",
"Let's first run this \"TELlama\" implementation in `BF16` precision.\n",
"</div>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "bdb34b91",
"metadata": {},
"outputs": [
{ {
"name": "stdout", "cell_type": "markdown",
"output_type": "stream", "metadata": {},
"text": [ "source": [
"10 finetuning steps complete!\n", "## Dependencies for this tutorial\n",
"Average time taken per step: 185 milliseconds\n" "\n",
] "Following files and media are necessary to effectively run this tutorial:\n",
} "\n",
], "1. `te_llama.py`\n",
"source": [ " - This file contains the code to load a Hugging Face Llama 2 or Llama 3 checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `LlamaDecoderLayer`. This is used in the following two sections of the tutorial - \"Improvement 1\" and \"Improvement 2\".\n",
"# Restart the notebook (to flush the GPU memory)\n", "2. `utils.py`\n",
"from utils import restart_jupyter_notebook\n", " - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training and other miscellaneous tasks like restarting the jupyter notebook from within the cell. \n",
"restart_jupyter_notebook()\n", "3. `requirements.txt`\n",
"\n", " - This file contains the necessary Python packages for this tutorial.\n",
"\n", "4. `media/`\n",
"# Import necessary packages, methods and variables\n", " - This directory contains the images used in the following tutorial.\n",
"from utils import *\n", "\n",
"\n", "\n",
"\n", "<div class=\"alert alert-info\">\n",
"# Provide Huggingface Access Token\n", "\n",
"hyperparams.hf_access_token = \"\"\n", "<b>Note on running the tutorial with Llama 3 weights</b>\n",
"assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n", "\n",
"\n", "This tutorial shows the cell outputs when run with Llama 2 7B weights. It can be run with Llama 3 8B weights simply by providing the directory with those weights (in Hugging Face format) instead of Llama 2 7B weights. These two models are almost identical, the biggest difference being the model dimension (the smallest Llama 3 model has 8B parameters, whereas the smallest Llama 2 has 7B), which enables this tutorial to work for both of them.\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n", "\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", "</div>\n",
"hyperparams.weights_cache_dir = \"\"\n", ""
"\n", ],
"# For Llama 2, uncomment this line (also set by default)\n", "id": "331f476a"
"hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n", },
"\n", {
"# For Llama 3, uncomment this line\n", "cell_type": "markdown",
"# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n", "metadata": {},
"\n", "source": [
"hyperparams.mixed_precision = \"bf16\"\n", "### Setup\n",
"\n", "\n",
"\n", "Install the required Python packages using the following command:"
"# Init the model and accelerator wrapper\n", ],
"model = init_te_llama_model(hyperparams)\n", "id": "b56526b3"
"accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", },
"\n", {
"\n", "cell_type": "code",
"# Finetune the model\n", "metadata": {},
"finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" "source": [
] "# Uncomment and run this cell when running the tutorial for the first time\n",
}, "# %pip install -r requirements.txt"
{ ],
"cell_type": "markdown", "id": "099697e2",
"id": "0c9fbd65", "execution_count": null,
"metadata": {}, "outputs": []
"source": [ },
"Compared to the \"baseline\" implementation, we see that using Transformer Engine's `TransformerLayer` in place of Huggging Face's `LlamaDecoderLayer` gives a speedup of **34%** even when using only BF16 precision!\n",
"\n",
"| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n",
"|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n",
"| HF (baseline) | BF16 | 248 | 1 |\n",
"| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 185 | 1.34 |"
]
},
{
"cell_type": "markdown",
"id": "98cd8efb",
"metadata": {},
"source": [
"## [Improvement 2] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)\n",
"\n",
"Now that most of the HF Llama model implementation (`LlamaDecoderLayer`s) has been swapped with Transformer Engine implementation (`TELlamaDecoderLayer` or `TransformerLayer`), let's see how finetuning in `FP8` precision helps improve performance.\n",
"\n",
"#### How to run the model in `FP8` precision\n",
"\n",
"After the substitution, the model can be run in `FP8` precision by the following change over the previous BF16 runs. (For more information, refer the corresponding `wrap_with_accelerator` function in the accompanying `utils.py` file).\n",
"\n",
"```\n",
"# Specify the `FP8RecipeKwargs` (additional argument required to run in `fp8` precision)\n",
"fp8_kwarg_handler = [FP8RecipeKwargs(backend=\"te\")]\n",
"\n",
"# Pass the `FP8RecipeKwargs` to the `Accelerator` init call\n",
"accelerator = Accelerator(\n",
" ...\n",
" kwargs_handlers=fp8_kwarg_handler\n",
")\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "772c6f22",
"metadata": {},
"outputs": [
{ {
"name": "stdout", "cell_type": "markdown",
"output_type": "stream", "metadata": {},
"text": [ "source": [
"10 finetuning steps complete!\n", "## Table of contents\n",
"Average time taken per step: 160 milliseconds\n" "1. From \"Transformer\" to \"Llama\"\n",
] "2. Hugging Face's `LlamaModel`\n",
" - Hugging Face's `LlamaDecoderLayer`\n",
"3. [Baseline] Running HF `LlamaModel` (Precision: `BF16`)\n",
"6. [Improvement 1] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)\n",
" - Transformer Engine's `TransformerLayer`\n",
" - `TransformerLayer` options explained\n",
" - Mapping weights from HF's `LlamaDecoderLayer` to TE's `TransformerLayer`\n",
"7. [Improvement 2] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)\n",
"8. Conclusion"
],
"id": "44abae4f"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## From \"Transformer\" to \"Llama\" \n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"media/transformer_llama.png\">\n",
" <figcaption> Fig 1: Llama visualized as a transformer. (generated with [Nvidia's AI-foundation models](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/ai-foundation/models/sdxl))</figcaption>\n",
"</figure>\n",
"\n",
"A flashback:\n",
"\n",
"- 2017: [\"Attention Is All You Need\"](https://arxiv.org/abs/1706.03762) paper introduced pioneering \"Transformer\" architecture and changed the NLP field forever.\n",
"- 2018-2020: Emergence of GPT model series that showed causal decoder architectures are great fit for pretraining, few-shot and zero-shot learning.\n",
"- Fast forward to 2023-2024: Following GPT-3/GPT-4 success stories, researchers and companies raced to produce the next best pretrained model that could further be finetuned for application-specific use-cases.\n",
"- February 2023: Meta releases [Llama 2](https://llama.meta.com/llama2) models (Large Language Model Meta AI). \n",
" - These models range from 7B to 70B parameters.\n",
" - LLaMA 2 was pretrained on 2 trillion tokens.\n",
"- April 2024: Meta releases [Llama 3](https://llama.meta.com/llama3) models.\n",
" - These models range from 8B to 70B parameters.\n",
" - LLaMA 3 was pretrained on 15 trillion tokens.\n",
"\n",
"For more information on Llama 2 consider reading the [Huggingface tutorial](https://huggingface.co/blog/llama2). As a quick summary, here are some of the important differences b/w the conventional transformer decoder architecture vs Llama 2 architecture:\n",
"\n",
"1. Decoder only model (causal language modeling and next word prediction)\n",
"2. RMSNorm in place of the LayerNorm\n",
"3. SwiGLU activation function\n",
"4. RoPE as positional embeddings \n",
"5. Grouped Query Attention for the 70B model\n",
"6. Trained on 4K context length\n",
"\n",
"Hugging Face also released a [tutorial about Llama 3](https://huggingface.co/blog/llama3). The key points are:\n",
"\n",
"1. Use of bigger tokenizer - 128256 vs 32K.\n",
"2. Grouped Query Attention is used also by smaller 8B model.\n",
"3. The context length increased to 8K for all models.\n",
"3. Llama 3 was trained on 8x more data than Llama 2.\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"media/transformer_vs_llama.svg\">\n",
" <figcaption> Fig 2: Comparing GPT and Llama architectures. </figcaption>\n",
"</figure>"
],
"id": "e37e2cc1"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Hugging Face's `LlamaModel`\n",
"Hugging Face provides an open-source implementation of `Llama` model in [modeling_llama.py](https://github.com/huggingface/transformers/blob/3d2900e829ab16757632f9dde891f1947cfc4be0/src/transformers/models/llama/modeling_llama.py#L4).\n",
"\n",
"Here's a block diagram that shows how Llama model is implemented in the Hugging Face repo. Notice the modular encapsulated form and `LlamaDecoderLayer` at the core of the model implementation.\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"media/llama_for_causal_lm.svg\">\n",
" <figcaption> Fig 3: Causal Llama Model Block Diagram. </figcaption>\n",
"</figure>\n",
"\n",
"The above diagram translates to the following text output of the model in PyTorch. Notice that the core of the model has 32 `LlamaDecoderLayer`s. \n",
"\n",
"```\n",
"LlamaForCausalLM(\n",
" (model): LlamaModel(\n",
" (embed_tokens): Embedding(32000, 4096, padding_idx=0)\n",
" (layers): ModuleList(\n",
" (0-31): 32 x LlamaDecoderLayer(\n",
" (self_attn): LlamaFlashAttention2(\n",
" (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (v_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (rotary_emb): LlamaRotaryEmbedding()\n",
" )\n",
" (mlp): LlamaMLP(\n",
" (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
" (up_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
" (down_proj): Linear(in_features=11008, out_features=4096, bias=False)\n",
" (act_fn): SiLU()\n",
" )\n",
" (input_layernorm): LlamaRMSNorm()\n",
" (post_attention_layernorm): LlamaRMSNorm()\n",
" )\n",
" )\n",
" (norm): LlamaRMSNorm()\n",
" )\n",
" (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
")\n",
"```\n",
"\n",
"### Hugging Face's `LlamaDecoderLayer`\n",
"\n",
"Let's take a closer look at `LlamaDecoderLayer`. It is composed of `input_layernorm`, `self_attn`, `post_attention_layernorm` and `mlp` modules. Each module has associated weights as shown in the diagram.\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"media/llama_zoom.svg\">\n",
" <figcaption> Fig 4: Causal Llama Model Block Diagram (with simplified illustration of the [LlamaDecoderLayer](https://github.com/huggingface/transformers/blob/e770f0316d2a9b787c9d1440f204fcb65e176682/src/transformers/models/llama/modeling_llama.py#L695)). </figcaption>\n",
"</figure>\n",
"\n",
"#### Self_Attn Layer\n",
"For simplicity in the block diagram illustration of the \"self_attn\" box, we omit the \"Grouped Query Attention\" operation and only showcase the modules which have associated weights.\n",
" \n",
"#### MLP Layer\n",
"\n",
"SwiGLU is an activation defined as follows in the [modeling_llama.py](https://github.com/huggingface/transformers/blob/7c4995f93d8d24aae05e1e43279c96dce736e5c8/src/transformers/models/llama/modeling_llama.py#L236) file in the Hugging Face github repo:\n",
"```\n",
"\"\"\"\n",
"1. `self.up_proj`, `self.gate_proj` and `self.down_proj` are \"Linear\" layers\n",
"2. `self.act_fn` is a \"Swish\" function\n",
"\n",
"\"\"\"\n",
"down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n",
"```\n",
"It requires a set of 3 weights as compared to 2 weights in conventional \"MLP\" layers e.g. in the traditional transformer or GPT architectures. This is also illustrated in the following figure:\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"media/swiglu.svg\">\n",
" <figcaption> Fig 5: A look inside the feedforward layer with <code>swiglu</code> activation function. </figcaption>\n",
"</figure>"
],
"id": "a110de1a"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## [Baseline] Running HF `LlamaModel` (Precision: `BF16`)\n",
"\n",
"Llama 2 weights are loaded into the Hugging Face native implementation `LlamaForCausalLM` (refer to [modeling_llama.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py)). \n",
"\n",
"For this and other subsequent runs, the `batch_size` is `8`. The `LlamaDecoderLayer` is left unchanged in the baseline as follows:\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"media/llamadecoderlayer.svg\">\n",
" <figcaption> Fig 6: Revisiting \"LlamaDecoderLayer\". </figcaption>\n",
"</figure>\n",
"\n",
"<div class=\"alert alert-info\">\n",
"<b>Note</b>\n",
"\n",
"The baseline implementation will be run in `BF16` precision.\n",
"\n",
"</div>"
],
"id": "c9529229"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Note</b>\n",
" \n",
"This tutorial loads and trains a Llama 3 8B or a Llama 2 7B model which takes up most of the GPU memory and therefore, we need to restart the jupyter notebook each time before running the following sections. A small utility method `restart_jupyter_notebook` is defined in the accompanying `utils.py` file. This function restarts the jupyter notebook so that the GPU memory is flushed before the model is loaded again from the checkpoint in order to avoid running into OOM (Out Of Memory) errors.\n",
"\n",
"If the utility doesn't work, comment this line `restart_jupyter_notebook()` in the following cell and manually restart the jupyter notebook before running the cell. Repeat the same for other sections in this tutorial.\n",
"\n",
"</div>\n"
],
"id": "b38eb3ac"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Restart the notebook (to flush the GPU memory)\n",
"from utils import restart_jupyter_notebook\n",
"restart_jupyter_notebook()\n",
"\n",
"\n",
"# Import necessary packages, methods and variables\n",
"from utils import *\n",
"\n",
"\n",
"# Provide Huggingface Access Token\n",
"hyperparams.hf_access_token = \"\"\n",
"assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
"\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"hyperparams.weights_cache_dir = \"\"\n",
"\n",
"# For Llama 2, uncomment this line (also set by default)\n",
"hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
"\n",
"# For Llama 3, uncomment this line\n",
"# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
"\n",
"hyperparams.mixed_precision = \"bf16\"\n",
"\n",
"\n",
"# Init the model and accelerator wrapper\n",
"model = init_baseline_model(hyperparams)\n",
"accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n",
"\n",
"\n",
"# Finetune the model\n",
"finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"10 finetuning steps complete!\n",
"Average time taken per step: 248 milliseconds\n"
]
}
],
"id": "2e9d7a8c"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's add this information in a table and keep comparing it with a few possible improvements in future sections:\n",
"\n",
"| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n",
"|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n",
"| HF (baseline) | BF16 | 248 | 1 |"
],
"id": "4035ccb7"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## [Improvement 1] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)\n",
"\n",
"In addition to basic layers like `Linear` and `LayerNorm`, Transformer Engine offers larger modules like `MultiheadAttention` (combines \"LayerNorm\" and \"Self Attention\") and `LayerNormMLP` (combines \"LayerNorm\" and \"MLP\") that could replace their counterparts in the `LlamaDecoderLayer` and potentially provide a speedup. Transformer Engine also offers a full `TransformerLayer` (which further combines `MultiheadAttention` and `LayerNormMLP` layers) which could replace `LlamaDecoderLayer` and provide a speedup (with careful mapping of the weights since the name of the weights are different for those two layers). Let's take a closer look at Transformer Engine's `TransformerLayer`. \n",
"\n",
"### Transformer Engine's `TransformerLayer`\n",
"\n",
"At a higher level, TE's `TransformerLayer` could be visualized as an apt replacement for the `LlamaDecoderLayer`. But the internals of the `TransformerLayer` are organized a bit differently. \n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"media/tellamadecoderlayer.svg\">\n",
" <figcaption> Fig 7: Transformer Engine's `TransformerLayer` </figcaption>\n",
"</figure>\n",
"\n",
"Just like Hugging Face's `LlamaDecoderLayer`, Transformer Engine's `TransformerLayer` encapsulates `self_attention` (as `MultiheadAttention`) and `mlp` (as `LayerNormMLP`). A major difference is that the two `Norm`s are included in the `MultiheadAttention` and `LayerNormMLP` layers as shown in the following output prompt:\n",
"\n",
"```\n",
"TransformerLayer(\n",
" (self_attention): MultiheadAttention(\n",
" (layernorm_qkv): LayerNormLinear()\n",
" (core_attention): DotProductAttention()\n",
" (proj): Linear()\n",
" )\n",
" (layernorm_mlp): LayerNormMLP()\n",
")\n",
"```\n",
"\n",
"Another difference is that Transformer Engine implements an efficient version of feedforward layer with SwiGLU in which the weights from the `up_proj` and `gate_proj` modules are merged together and SwiGLU is applied using a custom fused kernel. This is done so that only one big and efficient Matrix Multiplication operation is issued to the GPU instead of two smaller ones.\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"media/swiglu_te.svg\">\n",
" <figcaption> Fig 8: Abstract illustration of the SwiGLU implementation in Transformer Engine. </figcaption>\n",
"</figure>\n",
"\n",
"### `TransformerLayer` options explained\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Note</b>\n",
" \n",
"Here, we go over some of the options in `TransformerLayer` that are needed for the tutorial. For a complete list of options, refer the [TransformerLayer API documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html?highlight=transformerlayer#transformer_engine.pytorch.TransformerLayer).\n",
"\n",
"</div>\n",
"\n",
"In the accompanying `te_llama.py` file, `TELlamaDecoderLayer` is defined as a wrapper over TE's `TransformerLayer` with a few needed options that make `TransformerLayer` a plug-in replacement for the HF's `LlamaDecoderLayer`.\n",
"\n",
"```\n",
"class TELlamaDecoderLayer(te.pytorch.TransformerLayer):\n",
" def __init__(self, config):\n",
" super().__init__(\n",
" config.hidden_size,\n",
" config.intermediate_size,\n",
" config.num_attention_heads,\n",
" bias=False,\n",
" layernorm_epsilon=config.rms_norm_eps,\n",
" hidden_dropout=0,\n",
" attention_dropout=0,\n",
" fuse_qkv_params=False,\n",
" normalization=\"RMSNorm\",\n",
" activation=\"swiglu\",\n",
" attn_input_format=\"bshd\",\n",
" num_gqa_groups=config.num_key_value_heads,\n",
" )\n",
" te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads)\n",
" self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()\n",
"```\n",
"\n",
"Here's a list summarizing each option briefly:\n",
"\n",
"1. `hidden_size`: size of each input sample.\n",
"2. `ffn_hidden_size`: intermediate size to which samples are projected.\n",
"3. `num_attention_heads`: number of attention heads in the transformer layer.\n",
"4. `bias`: switch to add additive biases to the submodule layers.\n",
"5. `layernorm_epsilon`: a value added to the denominator of layer normalization for numerical stability. Default is `1e-5`.\n",
"6. `hidden_dropout`: dropout probability for the dropout op after FC2 layer (fully connected layer no. 2). Default is `0.1`.\n",
"7. `attention_dropout`: dropout probability for the dropout op during multi-head attention. Default is `0.1`. \n",
"8. `fuse_qkv_params`: if set to True, TransformerLayer module exposes a single fused parameter for query-key-value. This enables optimizations such as QKV fusion without concatentations/splits and also enables the argument fuse_wgrad_accumulation.\n",
"9. `normalization`: type of normalization applied. Default is `LayerNorm`.\n",
"10. `activation`: type of activation used in the MLP block. Default is `gelu`.\n",
"11. `attn_input_format`: controls whether the dimensions of the intermediate hidden states is 'batch first' ('bshd') or 'sequence first' ('sbhd'). `s` stands for the sequence length, `b` batch size, `h` the number of heads, `d` head size. Note that these formats are very closely related to the `qkv_format` in the `MultiHeadAttention` and `DotProductAttention` modules.\n",
"12. `num_gqa_groups`: number of GQA groups in the transformer layer. Grouped Query Attention is described in [this paper](https://arxiv.org/pdf/2305.13245.pdf). This only affects the keys and values, not the querys. GQA-1 is equivalent to Multi-Query Attention ([MQA](https://arxiv.org/pdf/1911.02150.pdf)), while GQA-H is equivalent to MultiHead Attention, i.e. `num_gqa_groups = num_attention_heads`.\n",
"\n",
"\n",
"Further, note that `RotaryPositionEmbedding` is defined as part of the `TELlamaDecoderLayer` (wrapper around TE's `TransformerLayer`) itself since it expects this rope cache if RoPE is used in the model. \n",
"\n",
"Let's revisit how `LlamaDecoderLayer`s form the core of the decoder layer stack in HF's llama implementation:\n",
"```\n",
"ModuleList(\n",
" (0-31): 32 x LlamaDecoderLayer(\n",
" (self_attn): LlamaAttention(\n",
" (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (v_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (rotary_emb): LlamaRotaryEmbedding()\n",
" )\n",
" (mlp): LlamaMLP(\n",
" (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
" (up_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
" (down_proj): Linear(in_features=11008, out_features=4096, bias=False)\n",
" (act_fn): SiLU()\n",
" )\n",
" (input_layernorm): LlamaRMSNorm()\n",
" (post_attention_layernorm): LlamaRMSNorm()\n",
" )\n",
")\n",
"```\n",
"\n",
"A major portion of the Hugging Face model implementation (32 `LlamaDecoderLayer` layers) could be potentially replaced with Transformer Engine's `TransformerLayer` layers. Let's see how it is made possible.\n",
"\n",
"\n",
"### Mapping weights from HF's `LlamaDecoderLayer` to TE's `TransformerLayer`\n",
"\n",
"Refer the accompanying file `te_llama.py` which provides a reference to create a Llama 2 model with TE's `TransformerLayer` after replacing HF's `LlamaDecoderLayer`.\n",
"\n",
"Briefly, following pieces of code are put together:\n",
"\n",
"1. `TELlamaDecoderLayer` is added as a wrapper for `TransformerLayer`. \n",
"```\n",
"class TELlamaDecoderLayer(te.pytorch.TransformerLayer):\n",
" \"\"\"\n",
" Wrapper class over TE's `TransformerLayer`. This makes the wrapper very\n",
" similar to HF's `LlamaDecoderLayer` and easier to replace it in the code.\n",
"\n",
" Args:\n",
" config: LlamaConfig\n",
" args: positional args (for compatibility with `LlamaDecoderLayer`)\n",
" kwargs: keyword args (for compatibility with `LlamaDecoderLayer`)\n",
" \"\"\"\n",
" def __init__(self, config, *args, **kwargs):\n",
" super().__init__(\n",
" hidden_size=config.hidden_size,\n",
" ffn_hidden_size=config.intermediate_size,\n",
" num_attention_heads=config.num_attention_heads,\n",
" bias=False,\n",
" layernorm_epsilon=config.rms_norm_eps,\n",
" hidden_dropout=0,\n",
" attention_dropout=0,\n",
" fuse_qkv_params=False,\n",
" normalization=\"RMSNorm\",\n",
" activation=\"swiglu\",\n",
" attn_input_format=\"bshd\",\n",
" )\n",
" te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads)\n",
" self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()\n",
"\n",
" def forward(self,\n",
" hidden_states,\n",
" *args,\n",
" attention_mask,\n",
" **kwargs):\n",
" \"\"\"\n",
" Custom forward to make sure we only pass relevant arguments to the\n",
" forward pass of the `TransformerLayer`. Also, make sure the output\n",
" format matches the output of the HF's `LlamaDecoderLayer`.\n",
" \"\"\"\n",
" return (super().forward(hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb),)\n",
"```\n",
"\n",
"2. Before creating a `LlamaForCausalLM`, `replace_decoder` context manager is used to monkey-patch `LlamaDecoderLayer` with `TELlamaDecoderLayer`.\n",
"\n",
"```\n",
"@contextmanager\n",
"def replace_decoder(te_decoder_cls):\n",
" \"\"\"\n",
" Replace `LlamaDecoderLayer` with custom `TELlamaDecoderLayer`.\n",
" \"\"\"\n",
" original_llama_decoder_cls = transformers.models.llama.modeling_llama.LlamaDecoderLayer\n",
" transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls\n",
" try:\n",
" yield\n",
" finally:\n",
" transformers.models.llama.modeling_llama.LlamaDecoderLayer = original_llama_decoder_cls\n",
".\n",
".\n",
".\n",
"class TELlamaForCausalLM:\n",
" \"\"\"\n",
" Causal LM created with `LlamaModel`. The underlying `LlamaDecoderLayer`\n",
" class is monkey-patched with `TELlamaDecoderLayer` class before\n",
" initializing the causal LM with `LlamaForCausalLM`.\n",
"\n",
" Args:\n",
" config: LlamaConfig\n",
" \"\"\"\n",
"\n",
" def __new__(cls, config: LlamaConfig):\n",
" with replace_decoder(te_decoder_cls=TELlamaDecoderLayer):\n",
" llama_for_causal_lm = LlamaForCausalLM(config)\n",
" return llama_for_causal_lm\n",
".\n",
".\n",
".\n",
"```\n",
"\n",
"3. A custom `pretrained_from_local` method is added that copies the weights from the checkpoint (which is meant for HF Llama implementation) to the modified `TELlamaForCausalLM` by carefully mapping the weights from the `LlamaDecoderLayer` (HF) to `TransformerLayer` (TE). The method `replace_params` maps and copies apt weights from `LlamaDecoderLayer` to the `TransformerLayer`. Refer to the following diagram for more details.\n",
"\n",
"```\n",
"def replace_params(hf_state_dict, te_state_dict):\n",
" # collect all layer prefixes to update\n",
" all_layer_prefixes = set()\n",
" for param_key in hf_state_dict.keys():\n",
" layer_prefix_pat = 'model.layers.\\d+.'\n",
" m = re.match(layer_prefix_pat, param_key)\n",
" if m is not None:\n",
" all_layer_prefixes.add(m.group())\n",
"\n",
" for layer_prefix in all_layer_prefixes:\n",
" # When loading weights into models with less number of layers, skip the\n",
" # copy if the corresponding layer doesn't exist in TE model\n",
" if layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight' in te_state_dict:\n",
" te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'input_layernorm.weight'].data[:]\n",
"\n",
" if layer_prefix + 'self_attention.layernorm_qkv.query_weight' in te_state_dict:\n",
" te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.query_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.q_proj.weight'].data[:]\n",
"\n",
" if layer_prefix + 'self_attention.layernorm_qkv.key_weight' in te_state_dict:\n",
" te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.k_proj.weight'].data[:]\n",
" .\n",
" .\n",
" .\n",
"\n",
" return all_layer_prefixes\n",
"```\n",
"\n",
"The following figure shows how the weights get mapped from the HF's `LlamaDecoderLayer` to TE's `TransformerLayer`.\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"media/weight_swap.svg\">\n",
" <figcaption> Fig 9: Replace `LlamaDecoderLayer` with `TransformerLayer`. </figcaption>\n",
"</figure>\n",
"\n",
"After initializing the modified Llama model this way, the core decoder layers get changed to `TELlamaDecoderLayer` (wrapper around `TransformerLayer`) as shown in the following output:\n",
"```\n",
"ModuleList(\n",
" (0-31): 32 x TELlamaDecoderLayer(\n",
" (self_attention): MultiheadAttention(\n",
" (layernorm_qkv): LayerNormLinear()\n",
" (core_attention): DotProductAttention(\n",
" (flash_attention): FlashAttention()\n",
" (fused_attention): FusedAttention()\n",
" (unfused_attention): UnfusedDotProductAttention(\n",
" (scale_mask_softmax): FusedScaleMaskSoftmax()\n",
" (attention_dropout): Dropout(p=0, inplace=False)\n",
" )\n",
" )\n",
" (proj): Linear()\n",
" )\n",
" (layernorm_mlp): LayerNormMLP()\n",
" )\n",
")\n",
"```\n",
"\n",
"In summary, the model gets changed as follows with a large chunk of the implementation (core decoder layers) coming from Transformer Engine.\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"media/model_change.svg\">\n",
" <figcaption> Fig 10: Language model after the HF's `LlamaDecoderLayer`s are replaced with TE's `TransformerLayer`s. </figcaption>\n",
"</figure>\n",
"\n",
"\n",
"<div class=\"alert alert-info\">\n",
"<b>Note</b>\n",
"\n",
"Let's first run this \"TELlama\" implementation in `BF16` precision.\n",
"</div>"
],
"id": "3db90dff"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Restart the notebook (to flush the GPU memory)\n",
"from utils import restart_jupyter_notebook\n",
"restart_jupyter_notebook()\n",
"\n",
"\n",
"# Import necessary packages, methods and variables\n",
"from utils import *\n",
"\n",
"\n",
"# Provide Huggingface Access Token\n",
"hyperparams.hf_access_token = \"\"\n",
"assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
"\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"hyperparams.weights_cache_dir = \"\"\n",
"\n",
"# For Llama 2, uncomment this line (also set by default)\n",
"hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
"\n",
"# For Llama 3, uncomment this line\n",
"# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
"\n",
"hyperparams.mixed_precision = \"bf16\"\n",
"\n",
"\n",
"# Init the model and accelerator wrapper\n",
"model = init_te_llama_model(hyperparams)\n",
"accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n",
"\n",
"\n",
"# Finetune the model\n",
"finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"10 finetuning steps complete!\n",
"Average time taken per step: 185 milliseconds\n"
]
}
],
"id": "bdb34b91"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compared to the \"baseline\" implementation, we see that using Transformer Engine's `TransformerLayer` in place of Huggging Face's `LlamaDecoderLayer` gives a speedup of **34%** even when using only BF16 precision!\n",
"\n",
"| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n",
"|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n",
"| HF (baseline) | BF16 | 248 | 1 |\n",
"| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 185 | 1.34 |"
],
"id": "0c9fbd65"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## [Improvement 2] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)\n",
"\n",
"Now that most of the HF Llama model implementation (`LlamaDecoderLayer`s) has been swapped with Transformer Engine implementation (`TELlamaDecoderLayer` or `TransformerLayer`), let's see how finetuning in `FP8` precision helps improve performance.\n",
"\n",
"### How to run the model in `FP8` precision\n",
"\n",
"After the substitution, the model can be run in `FP8` precision by the following change over the previous BF16 runs. (For more information, refer the corresponding `wrap_with_accelerator` function in the accompanying `utils.py` file).\n",
"\n",
"```\n",
"# Specify the `FP8RecipeKwargs` (additional argument required to run in `fp8` precision)\n",
"fp8_kwarg_handler = [FP8RecipeKwargs(backend=\"te\")]\n",
"\n",
"# Pass the `FP8RecipeKwargs` to the `Accelerator` init call\n",
"accelerator = Accelerator(\n",
" ...\n",
" kwargs_handlers=fp8_kwarg_handler\n",
")\n",
"```"
],
"id": "98cd8efb"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Restart the notebook (to flush the GPU memory)\n",
"from utils import restart_jupyter_notebook\n",
"restart_jupyter_notebook()\n",
"\n",
"\n",
"# Import necessary packages, methods and variables\n",
"from utils import *\n",
"\n",
"\n",
"# Provide Huggingface Access Token\n",
"hyperparams.hf_access_token = \"\"\n",
"assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
"\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"hyperparams.weights_cache_dir = \"\"\n",
"\n",
"# For Llama 2, uncomment this line (also set by default)\n",
"hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
"\n",
"# For Llama 3, uncomment this line\n",
"# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
"\n",
"hyperparams.mixed_precision = \"fp8\"\n",
"\n",
"\n",
"# Init the model and accelerator wrapper\n",
"model = init_te_llama_model(hyperparams)\n",
"accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n",
"\n",
"\n",
"# Finetune the model\n",
"finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"10 finetuning steps complete!\n",
"Average time taken per step: 160 milliseconds\n"
]
}
],
"id": "772c6f22"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n",
"|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n",
"| HF (baseline) | BF16 | 248 | 1 |\n",
"| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 185 | 1.34 |\n",
"| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | FP8 | 160 | 1.55 |\n",
"\n",
"\n",
"After turning on FP8 precision, we get even more speedup of **55%** (with Llama 2 7B)!\n",
"\n",
"### Llama 3 performance results\n",
"Running the same tutorial with **Llama 3 8B** yields the following performance numbers:\n",
"\n",
"| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n",
"|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n",
"| HF (baseline) | BF16 | 270 | 1 |\n",
"| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 217 | 1.24 |\n",
"| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | FP8 | 185 | 1.46 |\n",
"\n",
"For Llama 3 8B, we get the most speedup of **46%** with FP8 precision!\n",
"\n"
],
"id": "e7cf9c3a"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conclusion\n",
"\n",
"Using `TransformerLayer` module from Transformer Engine as a substitute for Hugging Face's `LlamaDecoderLayer` provides a speedup over Hugging Face's native Llama 2 and Llama 3 implementations. This needs careful initialization of the model such that the model weights (which are meant for `LlamaDecoderLayer`) are correctly mapped to their counterparts in TE's `TransformerLayer`. Even with `BF16` precision, `TransformerLayer` provides a speedup over the baseline implementation. With `FP8` precision, the speed up is even more pronounced!"
],
"id": "95d6c42b"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
} }
],
"source": [
"# Restart the notebook (to flush the GPU memory)\n",
"from utils import restart_jupyter_notebook\n",
"restart_jupyter_notebook()\n",
"\n",
"\n",
"# Import necessary packages, methods and variables\n",
"from utils import *\n",
"\n",
"\n",
"# Provide Huggingface Access Token\n",
"hyperparams.hf_access_token = \"\"\n",
"assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
"\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"hyperparams.weights_cache_dir = \"\"\n",
"\n",
"# For Llama 2, uncomment this line (also set by default)\n",
"hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
"\n",
"# For Llama 3, uncomment this line\n",
"# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
"\n",
"hyperparams.mixed_precision = \"fp8\"\n",
"\n",
"\n",
"# Init the model and accelerator wrapper\n",
"model = init_te_llama_model(hyperparams)\n",
"accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n",
"\n",
"\n",
"# Finetune the model\n",
"finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)"
]
},
{
"cell_type": "markdown",
"id": "e7cf9c3a",
"metadata": {},
"source": [
"| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n",
"|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n",
"| HF (baseline) | BF16 | 248 | 1 |\n",
"| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 185 | 1.34 |\n",
"| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | FP8 | 160 | 1.55 |\n",
"\n",
"\n",
"After turning on FP8 precision, we get even more speedup of **55%** (with Llama 2 7B)!\n",
"\n",
"#### Llama 3 performance results\n",
"Running the same tutorial with **Llama 3 8B** yields the following performance numbers:\n",
"\n",
"| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n",
"|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n",
"| HF (baseline) | BF16 | 270 | 1 |\n",
"| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 217 | 1.24 |\n",
"| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | FP8 | 185 | 1.46 |\n",
"\n",
"For Llama 3 8B, we get the most speedup of **46%** with FP8 precision!\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "95d6c42b",
"metadata": {},
"source": [
"## Conclusion\n",
"\n",
"Using `TransformerLayer` module from Transformer Engine as a substitute for Hugging Face's `LlamaDecoderLayer` provides a speedup over Hugging Face's native Llama 2 and Llama 3 implementations. This needs careful initialization of the model such that the model weights (which are meant for `LlamaDecoderLayer`) are correctly mapped to their counterparts in TE's `TransformerLayer`. Even with `BF16` precision, `TransformerLayer` provides a speedup over the baseline implementation. With `FP8` precision, the speed up is even more pronounced!"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}, },
"language_info": { "nbformat": 4,
"codemirror_mode": { "nbformat_minor": 5
"name": "ipython", }
"version": 3 \ No newline at end of file
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
..
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
FP8 Blockwise Scaling
===================================
.. warning::
``Float8BlockScaling`` is **currently not supported** in JAX.
FP8 Blockwise Scaling recipe is inspired by the quantization scheme used to train the `DeepSeek-v3 model <https://arxiv.org/abs/2412.19437>`__ –
the first open-source large-scale LLM trained entirely in FP8 precision.
Unlike the previous recipes, it assigns a dedicated scaling factor to each block of elements.
Data Format
--------------------------
The representation of an FP8 tensor element ``x`` in blockwise precision is given by:
.. code-block:: python
x = x_fp8 * s_block
where
* ``x_fp8`` is the FP8 value (E4M3 or E5M2),
* ``s_block`` is a local **FP32** scaling factor shared by a block of elements.
.. raw:: html
:file: img/combined_scaling.svg
*Figure 1. Top: Comparison of standard FP8 scaling (left) using a single scaling factor per tensor versus
FP8 blockwise scaling in 1 dimension (right) using multiple scaling factors, one per block of 128 elements.
Bottom: FP8 blockwise scaling in 2 dimensions where each 128×128 block in the data tensor has a corresponding
scaling factor.*
**FP8 format**
Unlike FP8 Current/Delayed Scaling, E4M3 is used by default for both forward and backward passes.
Tensor-scaled recipes used E5M2 for gradients due to its higher dynamic range,
but with multiple scaling factors per tensor the dynamic range requirement is lowered, so E4M3 is usually sufficient.
The ``fp8_format`` parameter also supports ``HYBRID`` mode (E4M3 for forward, E5M2 for backward).
Pure E5M2 training is not supported.
**Block size**
Block size is 128.
Blocks can be:
* one dimensional – containing 128 consecutive values,
* two dimensional – containing tiles of 128×128 values.
By default:
* activations use 1D scaling (``x_block_scaling_dim=1``),
* weights use 2D scaling (``w_block_scaling_dim=2``),
* gradients use 1D scaling (``grad_block_scaling_dim=1``).
These can be changed in the recipe, but 2D × 2D GEMMs are not supported
– at most one operand can use 2D scaling.
One-dimensional scaling is more granular, but 2D scaling offers two advantages:
* *Performance*: On Hopper, block-scaled GEMMs are software-emulated. GEMMs with mixed
1D/2D scaled tensors have lower overhead than pure 1D scaled GEMMs.
* *Numerical stability*: 2D scaling behaves better when transposed (details in the next section).
There are some assumptions on the dimensions of the tensor (for both 1D and 2D scaling):
* the tensor must have at least 2 dimensions,
* the last dimension must be divisible by 128,
* the product of all dimensions except the last must be divisible by 128.
**Scaling factors**
Scaling factors are stored as 32-bit floating point numbers.
By default, they are constrained to powers of 2 (utilizing the 8 exponent bits of FP32).
On Hopper, this constraint can be relaxed by setting the environment variable ``NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1``.
On Blackwell, only powers of 2 are supported.
Each block's scaling factor is computed through the following steps:
1. Find the maximum absolute value (``amax_block``) across all elements in the block
(128 consecutive values for 1D blocks, or 128×128 values for 2D blocks).
2. Calculate ``s_block = max_fp8 / amax_block``, where ``max_fp8`` is
the maximum representable value in the FP8 format (448 for E4M3, 57344 for E5M2).
3. If the power-of-2 constraint is enabled, round down to the nearest power of 2
by zeroing out the mantissa bits, retaining only the sign and exponent.
4. Multiply each element in the block by ``s_block`` before converting to FP8.
This approach ensures that the largest value in each block fits within the FP8 representable range without overflow.
Handling transposes
------------------------
On Hopper, columnwise tensor access requires data to be transposed in memory.
For 1D scaling, the block direction must align with the access pattern:
* *Rowwise access*: 1 scaling factor per 128 consecutive elements in a row.
* *Columnwise access*: 1 scaling factor per 128 consecutive elements in a row of the transposed tensor,
corresponding to 128 consecutive elements in a column of the original tensor.
For 2D scaling, each 128×128 tile has one scaling factor regardless of access direction.
This is illustrated below:
.. raw:: html
:file: img/transpose_handling.svg
*Figure 2. Quantization directions for original and transposed tensors.*
Note that for 1D scaling, the rowwise and columnwise quantized tensors may be numerically different,
so the gradient computation may be affected. This issue is not present for 2D scaling.
Activations and weights use the rowwise version in the forward pass and the columnwise version in the backward pass.
Experiments have shown that 2D scaling for weights is more helpful for numerical stability than for activations,
so by default 1D scaling is used for activations – as it is more granular – and 2D scaling is used for weights.
Unlike FP8 Current/Delayed Scaling, transposing a 1D quantized tensor is not supported.
Rowwise and columnwise blocks cover different sets of elements, so their scaling factors differ.
Both versions must be quantized separately from the high-precision source.
For 2D scaling, columnwise data can be created from rowwise data by transposing
both the quantized data and the scaling factors. Each 128×128 block covers the same
elements regardless of access direction, so the scaling factors remain valid.
Distributed training
-----------------------
**Scale synchronization**
The blockwise scaled tensor does not need any scale synchronization among the nodes.
This is because each scaling factor is local to its 128 or 128×128 element block,
unlike FP8 Current/Delayed Scaling where a single global scale applies to the entire tensor, even when sharded.
**Quantized all-gather**
FP8 Blockwise Scaling all-gather is supported.
Examples
--------
Here's how to use the FP8 Blockwise Scaling recipe in PyTorch and JAX:
.. note::
Requires SM90 (Hopper) or later.
.. tabs::
.. tab:: PyTorch
.. literalinclude:: pytorch_blockwise_scaling_example.py
:language: python
:start-after: # START_BLOCKWISE_SCALING_EXAMPLE
:end-before: # END_BLOCKWISE_SCALING_EXAMPLE
.. tab:: JAX
``Float8BlockScaling`` is **not currently supported** in JAX.
Supported devices
-----------------
Hopper (SM 9.0)
Blackwell and later (SM >= 10.0) – the recipe is emulated with MXFP8. Note that MXFP8 is the preferred recipe on Blackwell.
Only scaling factors that are powers of 2 are supported.
----
Developer Notes
---------------
This section contains implementation details that may be useful for developers
but are not required for using FP8 Blockwise Scaling in practice.
Swizzle of scaling factors
^^^^^^^^^^^^^^^^^^^^^^^^^^
FP8 Blockwise Scaling supports all-gather of both rowwise and columnwise tensors.
To support that, it implements different data layouts for communication (all-gather)
and computation (GEMM). We refer to the conversion between these formats as *swizzling*.
A tensor of shape ``[A, B]`` can exist in two formats:
**Compact format** (used for all-gather):
The all-gather primitive only supports gathering non-transposed shards into a non-transposed full tensor,
so all tensor components in this layout are stored without transposition.
Moreover, all component tensors are stored without padding.
.. list-table::
:widths: 30 70
:header-rows: 1
* - Component
- Shape
* - rowwise data
- ``[A, B]``
* - columnwise data
- ``[A, B]``
* - rowwise scales
- ``[A, B/128]``
* - columnwise scales
- ``[A/128, B]``
**GEMM-ready format** (used for computation):
Tensors are transposed and padded as required by the GEMM kernel.
.. list-table::
:widths: 30 70
:header-rows: 1
* - Component
- Shape
* - rowwise data
- ``[A, B]``
* - columnwise data
- ``[B, A]`` (transposed)
* - rowwise scales
- ``[B/128, pad4(A)]`` (transposed, padded)
* - columnwise scales
- ``[A/128, pad4(B)]`` (padded)
Swizzling converts from compact to GEMM-ready format. This can be fused with quantization
when no all-gather is needed, or performed separately after all-gather.
.. raw:: html
:file: img/blockwise_swizzle_flow.svg
*Figure 3. FP8 Blockwise Scaling swizzle paths. Top: With all-gather communication – quantization produces
compact format, then swizzle is performed separately after communication. Bottom: Without all-gather –
quantize and swizzle are fused into a single operation, directly producing GEMM-ready format.*
All-gather of columnwise tensors
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
All-gather of columnwise tensors is supported and necessary because:
- columnwise quantized tensors cannot be computed from rowwise quantized ones,
- gathering high-precision tensors is avoided in most cases for performance reasons.
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1150 400" width="100%" style="max-width: 900px;">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
/* Diagram-specific styles */
.input-box { fill: #f3e5f5; stroke: #7b1fa2; stroke-width: 2.5; }
.blockwise-box { fill: #e3f2fd; stroke: #1976d2; stroke-width: 2.5; }
.fp8-tile { fill: #bbdefb; stroke: #1565c0; stroke-width: 1.5; }
.scale-tile { fill: #a5d6a7; stroke: #388e3c; stroke-width: 1.5; }
.scale-swizzled { fill: #ffb74d; stroke: #e65100; stroke-width: 1.5; }
.swizzle-box { fill: #fff3e0; stroke: #f57c00; stroke-width: 2; }
.quantize-box { fill: #ede7f6; stroke: #5e35b1; stroke-width: 2; }
.quantize-fused-box { fill: #d1c4e9; stroke: #5e35b1; stroke-width: 2.5; }
.comm-box { fill: #fff9c4; stroke: #f57f17; stroke-width: 2; }
.gemm-box { fill: #c8e6c9; stroke: #388e3c; stroke-width: 2; }
/* Arrow override */
.arrow { marker-end: url(#arrowhead); stroke: #616161; stroke-width: 1.5; fill: none; }
</style>
<!-- Arrow marker -->
<marker id="arrowhead" markerWidth="6" markerHeight="6" refX="5" refY="2" orient="auto">
<polygon points="0 0, 6 2, 0 4" fill="#616161" />
</marker>
</defs>
<!-- Section 1: With Communication (Separate Swizzle) -->
<g id="with-communication">
<!-- Step 0: Input Tensor -->
<g id="input-fp32-tensor-1">
<text x="80" y="25" class="text" text-anchor="middle" font-weight="600">Input Tensor</text>
<rect x="20" y="40" width="120" height="110" rx="6" class="input-box"/>
<text x="80" y="100" class="text" text-anchor="middle" fill="#fff">FP32/BF16</text>
</g>
<!-- Arrow 0 -->
<path d="M 140 95 L 175 95" class="arrow"/>
<!-- Step 1: Quantize -->
<rect x="175" y="60" width="80" height="70" rx="6" class="quantize-box"/>
<text x="215" y="100" class="text">Quantize</text>
<!-- Arrow 1 -->
<path d="M 255 95 L 290 95" class="arrow"/>
<!-- Step 2: Blockwise Tensor (Compact) -->
<g id="blockwise-tensor-compact">
<text x="375" y="25" class="text" text-anchor="middle" font-weight="600">FP8 (Compact)</text>
<rect x="290" y="40" width="170" height="110" rx="6" class="blockwise-box"/>
<!-- FP32 Scales sub-tile (green) -->
<rect x="305" y="52" width="140" height="32" rx="3" class="scale-tile"/>
<text x="375" y="73" class="text" text-anchor="middle" fill="#fff">FP32 Scales</text>
<!-- FP8 Data sub-tile -->
<rect x="305" y="92" width="140" height="45" rx="3" class="fp8-tile"/>
<text x="375" y="120" class="text" fill="#fff">FP8 Data</text>
</g>
<!-- Arrow 2 -->
<path d="M 460 95 L 495 95" class="arrow"/>
<!-- Step 3: Communication -->
<rect x="495" y="60" width="100" height="70" rx="6" class="comm-box"/>
<text x="545" y="100" class="text">All-Gather</text>
<!-- Arrow 3 -->
<path d="M 595 95 L 630 95" class="arrow"/>
<!-- Step 4: Swizzle -->
<rect x="630" y="60" width="90" height="70" rx="6" class="swizzle-box"/>
<text x="675" y="100" class="text">Swizzle</text>
<!-- Arrow 4 -->
<path d="M 720 95 L 755 95" class="arrow"/>
<!-- Step 5: Blockwise Tensor (GEMM Ready) -->
<g id="swizzled-tensor-1">
<text x="840" y="25" class="text" text-anchor="middle" font-weight="600">FP8 (GEMM Ready)</text>
<rect x="755" y="40" width="170" height="110" rx="6" class="blockwise-box"/>
<!-- Swizzled Scales sub-tile (orange) -->
<rect x="770" y="52" width="140" height="32" rx="3" class="scale-swizzled"/>
<text x="840" y="73" class="text" text-anchor="middle" fill="#fff">Swizzled Scales</text>
<!-- FP8 Data sub-tile -->
<rect x="770" y="92" width="140" height="45" rx="3" class="fp8-tile"/>
<text x="840" y="120" class="text" fill="#fff">FP8 Data</text>
</g>
<!-- Arrow 5 -->
<path d="M 925 95 L 960 95" class="arrow"/>
<!-- Step 6: GEMM -->
<rect x="960" y="60" width="80" height="70" rx="6" class="gemm-box"/>
<text x="1000" y="100" class="text">GEMM</text>
</g>
<!-- Separator Line -->
<line x1="20" y1="185" x2="1050" y2="185" stroke="#bdbdbd" stroke-width="1" stroke-dasharray="8,4"/>
<!-- Section 2: Without Communication (Fused Quantize + Swizzle) -->
<g id="without-communication" transform="translate(0, 170)">
<!-- Step 0: Input Tensor -->
<g id="input-fp32-tensor-2">
<text x="80" y="45" class="text" text-anchor="middle" font-weight="600">Input Tensor</text>
<rect x="20" y="60" width="120" height="110" rx="6" class="input-box"/>
<text x="80" y="120" class="text" text-anchor="middle" fill="#fff">FP32/BF16</text>
</g>
<!-- Arrow 0 -->
<path d="M 140 115 L 190 115" class="arrow"/>
<!-- Step 1: Fused Quantize + Swizzle -->
<rect x="190" y="70" width="120" height="90" rx="6" class="quantize-fused-box"/>
<text x="250" y="105" class="text">Quantize</text>
<text x="250" y="122" class="text">+</text>
<text x="250" y="139" class="text">Swizzle</text>
<!-- Arrow 1 -->
<path d="M 310 115 L 360 115" class="arrow"/>
<!-- Step 2: Blockwise Tensor (GEMM Ready) - directly produced -->
<g id="swizzled-tensor-2">
<text x="455" y="45" class="text" text-anchor="middle" font-weight="600">FP8 (GEMM Ready)</text>
<rect x="360" y="60" width="190" height="110" rx="6" class="blockwise-box"/>
<!-- Swizzled Scales sub-tile (orange) -->
<rect x="378" y="72" width="155" height="32" rx="3" class="scale-swizzled"/>
<text x="455" y="93" class="text" text-anchor="middle" fill="#fff">Swizzled Scales</text>
<!-- FP8 Data sub-tile -->
<rect x="378" y="112" width="155" height="45" rx="3" class="fp8-tile"/>
<text x="455" y="140" class="text" fill="#fff">FP8 Data</text>
</g>
<!-- Arrow 2 -->
<path d="M 550 115 L 600 115" class="arrow"/>
<!-- Step 3: GEMM -->
<rect x="600" y="80" width="80" height="70" rx="6" class="gemm-box"/>
<text x="640" y="120" class="text">GEMM</text>
</g>
</svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 55 900 715">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.title { font: bold 18px sans-serif; fill: #333; text-anchor: middle; }
.dots-text { font: bold 24px sans-serif; fill: #333; text-anchor: middle; }
/* Tensor colors */
.fp8-tensor { fill: #87CEEB; stroke: #444; stroke-width: 2; }
.fp8-block { fill: #87CEEB; stroke: #555; stroke-width: 1.5; }
.fp8-block-alt { fill: #5F9FCC; stroke: #555; stroke-width: 1.5; }
/* Scaling factor colors */
.scale-factor { fill: #FFA500; stroke: #444; stroke-width: 2; }
.grid-line { stroke: #444; stroke-width: 2; }
.boundary-line { stroke: #444; stroke-width: 2; }
</style>
</defs>
<!-- FIRST IMAGE: Standard vs Blockwise Scaling -->
<!-- LEFT SIDE: Standard FP8 Scaling -->
<g id="standard-scaling">
<text x="225" y="85" class="title">Delayed/Current FP8 Scaling</text>
<text x="225" y="108" class="label">(Single scaling factor per tensor)</text>
<!-- FP8 Tensor - solid blue with white cross -->
<g id="left-tensor">
<!-- Solid blue background -->
<rect x="105" y="140" width="240" height="120" class="fp8-tensor"/>
<!-- White backgrounds for dots areas - cross pattern -->
<rect x="225.0" y="140.0" width="40" height="120" fill="#FFFFFF" stroke="none"/>
<rect x="105.0" y="190.0" width="240" height="30" fill="#FFFFFF" stroke="none"/>
<!-- Three dots in VERTICAL white bar -->
<text x="245" y="167.5" class="dots-text"></text>
<text x="245" y="242.5" class="dots-text"></text>
<!-- Three dots in HORIZONTAL white bar -->
<text x="165" y="205" class="dots-text"></text>
<text x="305" y="205" class="dots-text"></text>
<!-- ONE diagonal dot at intersection -->
<text x="245" y="205" class="dots-text" transform="rotate(45 245 205)"></text>
<!-- Main outline -->
<rect x="105.0" y="140.0" width="240" height="120" fill="none" stroke="#444" stroke-width="2"/>
</g>
<!-- Single scaling factor - one 10x10 square -->
<rect x="220" y="285" width="10" height="10" class="scale-factor" stroke="#444" stroke-width="1"/>
<text x="225" y="315" class="small-text" text-anchor="middle">1 scaling factor</text>
</g>
<!-- RIGHT SIDE: FP8 Blockwise Scaling -->
<g id="blockwise-scaling">
<text x="675" y="85" class="title">Blockwise FP8 Scaling – 1 dimension</text>
<text x="675" y="108" class="label">(One scaling factor per 128 elements)</text>
<!-- FP8 Tensor split into many small blocks (40×10) - EXACT coordinates from Python script -->
<g id="tensor-blocks">
<!-- White backgrounds for dots areas - cross pattern -->
<rect x="675.0" y="140.0" width="40" height="120" fill="#FFFFFF" stroke="none"/>
<rect x="555.0" y="190.0" width="240" height="30" fill="#FFFFFF" stroke="none"/>
<!-- Blocks ONLY where they don't overlap with white cross (from Python script) -->
<rect x="555" y="140" width="40" height="10" class="fp8-block"/>
<rect x="595" y="140" width="40" height="10" class="fp8-block-alt"/>
<rect x="635" y="140" width="40" height="10" class="fp8-block"/>
<rect x="715" y="140" width="40" height="10" class="fp8-block"/>
<rect x="755" y="140" width="40" height="10" class="fp8-block-alt"/>
<rect x="555" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="595" y="150" width="40" height="10" class="fp8-block"/>
<rect x="635" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="715" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="755" y="150" width="40" height="10" class="fp8-block"/>
<rect x="555" y="160" width="40" height="10" class="fp8-block"/>
<rect x="595" y="160" width="40" height="10" class="fp8-block-alt"/>
<rect x="635" y="160" width="40" height="10" class="fp8-block"/>
<rect x="715" y="160" width="40" height="10" class="fp8-block"/>
<rect x="755" y="160" width="40" height="10" class="fp8-block-alt"/>
<rect x="555" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="595" y="170" width="40" height="10" class="fp8-block"/>
<rect x="635" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="715" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="755" y="170" width="40" height="10" class="fp8-block"/>
<rect x="555" y="180" width="40" height="10" class="fp8-block"/>
<rect x="595" y="180" width="40" height="10" class="fp8-block-alt"/>
<rect x="635" y="180" width="40" height="10" class="fp8-block"/>
<rect x="715" y="180" width="40" height="10" class="fp8-block"/>
<rect x="755" y="180" width="40" height="10" class="fp8-block-alt"/>
<!-- Three dots in VERTICAL white bar -->
<text x="695" y="167.5" class="dots-text"></text>
<text x="695" y="242.5" class="dots-text"></text>
<!-- Three dots in HORIZONTAL white bar -->
<text x="615" y="205" class="dots-text"></text>
<text x="755" y="205" class="dots-text"></text>
<!-- ONE diagonal dot at intersection -->
<text x="695" y="205" class="dots-text" transform="rotate(45 695 205)"></text>
<!-- Bottom rows (y >= 220 after horizontal white bar) -->
<rect x="555" y="220" width="40" height="10" class="fp8-block"/>
<rect x="595" y="220" width="40" height="10" class="fp8-block-alt"/>
<rect x="635" y="220" width="40" height="10" class="fp8-block"/>
<rect x="715" y="220" width="40" height="10" class="fp8-block"/>
<rect x="755" y="220" width="40" height="10" class="fp8-block-alt"/>
<rect x="555" y="230" width="40" height="10" class="fp8-block-alt"/>
<rect x="595" y="230" width="40" height="10" class="fp8-block"/>
<rect x="635" y="230" width="40" height="10" class="fp8-block-alt"/>
<rect x="715" y="230" width="40" height="10" class="fp8-block-alt"/>
<rect x="755" y="230" width="40" height="10" class="fp8-block"/>
<rect x="555" y="240" width="40" height="10" class="fp8-block"/>
<rect x="595" y="240" width="40" height="10" class="fp8-block-alt"/>
<rect x="635" y="240" width="40" height="10" class="fp8-block"/>
<rect x="715" y="240" width="40" height="10" class="fp8-block"/>
<rect x="755" y="240" width="40" height="10" class="fp8-block-alt"/>
<rect x="555" y="250" width="40" height="10" class="fp8-block-alt"/>
<rect x="595" y="250" width="40" height="10" class="fp8-block"/>
<rect x="635" y="250" width="40" height="10" class="fp8-block-alt"/>
<rect x="715" y="250" width="40" height="10" class="fp8-block-alt"/>
<rect x="755" y="250" width="40" height="10" class="fp8-block"/>
<!-- Main outline -->
<rect x="555.0" y="140.0" width="240" height="120" fill="none" stroke="#444" stroke-width="2"/>
</g>
<!-- Scaling factors tensor - 3+2 columns of 10px squares -->
<g id="scale-factors">
<!-- Orange background -->
<rect x="640" y="285" width="70" height="120" fill="#FFA500"/>
<!-- White backgrounds for dots areas - cross pattern -->
<rect x="670" y="285" width="20" height="120" fill="#FFFFFF" stroke="none"/>
<rect x="640" y="335" width="70" height="30" fill="#FFFFFF" stroke="none"/>
<!-- Grid lines showing 10x10 squares (3 left + 2 right columns) -->
<!-- Vertical lines every 10px (skipping white space) -->
<!-- Left 3 columns (640-670) -->
<line x1="650" y1="285" x2="650" y2="335" class="grid-line" stroke-width="1"/>
<line x1="660" y1="285" x2="660" y2="335" class="grid-line" stroke-width="1"/>
<line x1="670" y1="285" x2="670" y2="335" class="grid-line" stroke-width="1"/>
<!-- Right 2 columns (690-710) -->
<line x1="690" y1="285" x2="690" y2="335" class="grid-line" stroke-width="1"/>
<line x1="700" y1="285" x2="700" y2="335" class="grid-line" stroke-width="1"/>
<line x1="710" y1="285" x2="710" y2="335" class="grid-line" stroke-width="1"/>
<!-- Bottom sections -->
<line x1="650" y1="365" x2="650" y2="405" class="grid-line" stroke-width="1"/>
<line x1="660" y1="365" x2="660" y2="405" class="grid-line" stroke-width="1"/>
<line x1="670" y1="365" x2="670" y2="405" class="grid-line" stroke-width="1"/>
<line x1="690" y1="365" x2="690" y2="405" class="grid-line" stroke-width="1"/>
<line x1="700" y1="365" x2="700" y2="405" class="grid-line" stroke-width="1"/>
<line x1="710" y1="365" x2="710" y2="405" class="grid-line" stroke-width="1"/>
<!-- Horizontal lines every 10px -->
<line x1="640" y1="295" x2="670" y2="295" class="grid-line" stroke-width="1"/>
<line x1="690" y1="295" x2="710" y2="295" class="grid-line" stroke-width="1"/>
<line x1="640" y1="305" x2="670" y2="305" class="grid-line" stroke-width="1"/>
<line x1="690" y1="305" x2="710" y2="305" class="grid-line" stroke-width="1"/>
<line x1="640" y1="315" x2="670" y2="315" class="grid-line" stroke-width="1"/>
<line x1="690" y1="315" x2="710" y2="315" class="grid-line" stroke-width="1"/>
<line x1="640" y1="325" x2="670" y2="325" class="grid-line" stroke-width="1"/>
<line x1="690" y1="325" x2="710" y2="325" class="grid-line" stroke-width="1"/>
<!-- Top bottom boundaries -->
<line x1="640" y1="335" x2="670" y2="335" class="grid-line" stroke-width="1"/>
<line x1="690" y1="335" x2="710" y2="335" class="grid-line" stroke-width="1"/>
<line x1="640" y1="365" x2="670" y2="365" class="grid-line" stroke-width="1"/>
<line x1="690" y1="365" x2="710" y2="365" class="grid-line" stroke-width="1"/>
<line x1="640" y1="375" x2="670" y2="375" class="grid-line" stroke-width="1"/>
<line x1="690" y1="375" x2="710" y2="375" class="grid-line" stroke-width="1"/>
<line x1="640" y1="385" x2="670" y2="385" class="grid-line" stroke-width="1"/>
<line x1="690" y1="385" x2="710" y2="385" class="grid-line" stroke-width="1"/>
<line x1="640" y1="395" x2="670" y2="395" class="grid-line" stroke-width="1"/>
<line x1="690" y1="395" x2="710" y2="395" class="grid-line" stroke-width="1"/>
<!-- Bottom boundaries -->
<line x1="640" y1="405" x2="670" y2="405" class="grid-line" stroke-width="1"/>
<line x1="690" y1="405" x2="710" y2="405" class="grid-line" stroke-width="1"/>
<!-- Main outline -->
<rect x="640" y="285" width="70" height="120" fill="none" stroke="#444" stroke-width="2"/>
<!-- Three dots -->
<text x="680" y="312.5" class="dots-text" style="font-size: 14px;"></text>
<text x="680" y="387.5" class="dots-text" style="font-size: 14px;"></text>
<text x="655" y="350" class="dots-text" style="font-size: 14px;"></text>
<text x="700" y="350" class="dots-text" style="font-size: 14px;"></text>
<text x="680" y="350" class="dots-text" style="font-size: 14px;" transform="rotate(45 680 350)"></text>
</g>
<text x="675" y="430" class="small-text" text-anchor="middle">Scaling factors (one per block)</text>
</g>
<!-- SECOND IMAGE: 2D Blockwise Scaling -->
<!-- Main Title -->
<text x="450" y="470" class="title">Blockwise FP8 Scaling – 2 dimensions</text>
<text x="450" y="495" class="label">(One scaling factor per 128x128 block of elements)</text>
<!-- TOP: DATA TENSOR (20x20 blocks, with 3 extra columns on right) -->
<g id="data-tensor">
<!-- Background for entire tensor -->
<rect x="390" y="525" width="180" height="120" class="fp8-tensor"/>
<!-- White space for gaps (cross pattern) -->
<rect x="450" y="525" width="20" height="120" fill="#FFFFFF" stroke="none"/>
<rect x="390" y="585" width="180" height="20" fill="#FFFFFF" stroke="none"/>
<!-- Grid Lines (every 20px) -->
<!-- Vertical Lines Left (x=410, 430) -->
<line x1="410" y1="525" x2="410" y2="585" class="grid-line" stroke-width="1"/>
<line x1="430" y1="525" x2="430" y2="585" class="grid-line" stroke-width="1"/>
<line x1="410" y1="605" x2="410" y2="645" class="grid-line" stroke-width="1"/>
<line x1="430" y1="605" x2="430" y2="645" class="grid-line" stroke-width="1"/>
<!-- Vertical Lines Right (x=490, 510, 530, 550) -->
<line x1="490" y1="525" x2="490" y2="585" class="grid-line" stroke-width="1"/>
<line x1="490" y1="605" x2="490" y2="645" class="grid-line" stroke-width="1"/>
<line x1="510" y1="525" x2="510" y2="585" class="grid-line" stroke-width="1"/>
<line x1="510" y1="605" x2="510" y2="645" class="grid-line" stroke-width="1"/>
<line x1="530" y1="525" x2="530" y2="585" class="grid-line" stroke-width="1"/>
<line x1="530" y1="605" x2="530" y2="645" class="grid-line" stroke-width="1"/>
<line x1="550" y1="525" x2="550" y2="585" class="grid-line" stroke-width="1"/>
<line x1="550" y1="605" x2="550" y2="645" class="grid-line" stroke-width="1"/>
<!-- Horizontal Lines Top (y=545, 565) -->
<line x1="390" y1="545" x2="450" y2="545" class="grid-line" stroke-width="1"/>
<line x1="470" y1="545" x2="570" y2="545" class="grid-line" stroke-width="1"/>
<line x1="390" y1="565" x2="450" y2="565" class="grid-line" stroke-width="1"/>
<line x1="470" y1="565" x2="570" y2="565" class="grid-line" stroke-width="1"/>
<!-- Horizontal Lines Bottom (y=625) -->
<line x1="390" y1="625" x2="450" y2="625" class="grid-line" stroke-width="1"/>
<line x1="470" y1="625" x2="570" y2="625" class="grid-line" stroke-width="1"/>
<!-- Dots / Ellipses -->
<!-- Horizontal dots in gap -->
<text x="460" y="552" class="dots-text" style="font-size: 14px;"></text>
<text x="460" y="632" class="dots-text" style="font-size: 14px;"></text>
<!-- Vertical dots in gap -->
<text x="420" y="597" class="dots-text" style="font-size: 14px;"></text>
<text x="540" y="597" class="dots-text" style="font-size: 14px;"></text>
<!-- Diagonal dot -->
<text x="460" y="597" class="dots-text" style="font-size: 14px;" transform="rotate(45 460 597)"></text>
<!-- Boundaries around white spaces (excluding center intersection) -->
<!-- Vertical boundaries - broken at horizontal white space -->
<line x1="450" y1="525" x2="450" y2="585" class="boundary-line"/>
<line x1="450" y1="605" x2="450" y2="645" class="boundary-line"/>
<line x1="470" y1="525" x2="470" y2="585" class="boundary-line"/>
<line x1="470" y1="605" x2="470" y2="645" class="boundary-line"/>
<!-- Horizontal boundaries - broken at vertical white space -->
<line x1="390" y1="585" x2="450" y2="585" class="boundary-line"/>
<line x1="470" y1="585" x2="570" y2="585" class="boundary-line"/>
<line x1="390" y1="605" x2="450" y2="605" class="boundary-line"/>
<line x1="470" y1="605" x2="570" y2="605" class="boundary-line"/>
<!-- Main outline -->
<rect x="390" y="525" width="180" height="120" fill="none" stroke="#444" stroke-width="2"/>
</g>
<!-- BOTTOM: SCALING FACTORS (10x10 blocks, with 3 extra columns on right) -->
<g id="scaling-factors-2d">
<!-- Background for entire scaling tensor -->
<rect x="420" y="675" width="90" height="60" class="scale-factor"/>
<!-- White space for gaps (cross pattern) -->
<rect x="450" y="675" width="10" height="60" fill="#FFFFFF" stroke="none"/>
<rect x="420" y="705" width="90" height="10" fill="#FFFFFF" stroke="none"/>
<!-- Grid Lines (every 10px) -->
<!-- Vertical Left -->
<line x1="430" y1="675" x2="430" y2="705" class="grid-line" stroke-width="1"/>
<line x1="440" y1="675" x2="440" y2="705" class="grid-line" stroke-width="1"/>
<line x1="430" y1="715" x2="430" y2="735" class="grid-line" stroke-width="1"/>
<line x1="440" y1="715" x2="440" y2="735" class="grid-line" stroke-width="1"/>
<!-- Vertical Right -->
<line x1="470" y1="675" x2="470" y2="705" class="grid-line" stroke-width="1"/>
<line x1="470" y1="715" x2="470" y2="735" class="grid-line" stroke-width="1"/>
<line x1="480" y1="675" x2="480" y2="705" class="grid-line" stroke-width="1"/>
<line x1="480" y1="715" x2="480" y2="735" class="grid-line" stroke-width="1"/>
<line x1="490" y1="675" x2="490" y2="705" class="grid-line" stroke-width="1"/>
<line x1="490" y1="715" x2="490" y2="735" class="grid-line" stroke-width="1"/>
<line x1="500" y1="675" x2="500" y2="705" class="grid-line" stroke-width="1"/>
<line x1="500" y1="715" x2="500" y2="735" class="grid-line" stroke-width="1"/>
<!-- Horizontal Top -->
<line x1="420" y1="685" x2="450" y2="685" class="grid-line" stroke-width="1"/>
<line x1="460" y1="685" x2="510" y2="685" class="grid-line" stroke-width="1"/>
<line x1="420" y1="695" x2="450" y2="695" class="grid-line" stroke-width="1"/>
<line x1="460" y1="695" x2="510" y2="695" class="grid-line" stroke-width="1"/>
<!-- Horizontal Bottom -->
<line x1="420" y1="725" x2="450" y2="725" class="grid-line" stroke-width="1"/>
<line x1="460" y1="725" x2="510" y2="725" class="grid-line" stroke-width="1"/>
<!-- Dots -->
<text x="455" y="692" class="dots-text" style="font-size: 12px;"></text>
<text x="455" y="727" class="dots-text" style="font-size: 12px;"></text>
<text x="435" y="711" class="dots-text" style="font-size: 12px;"></text>
<text x="490" y="711" class="dots-text" style="font-size: 12px;"></text>
<text x="455" y="711" class="dots-text" style="font-size: 12px;" transform="rotate(45 455 711)"></text>
<!-- Boundaries around white spaces (excluding center intersection) -->
<!-- Vertical boundaries - broken at horizontal white space -->
<line x1="450" y1="675" x2="450" y2="705" class="boundary-line"/>
<line x1="450" y1="715" x2="450" y2="735" class="boundary-line"/>
<line x1="460" y1="675" x2="460" y2="705" class="boundary-line"/>
<line x1="460" y1="715" x2="460" y2="735" class="boundary-line"/>
<!-- Horizontal boundaries - broken at vertical white space -->
<line x1="420" y1="705" x2="450" y2="705" class="boundary-line"/>
<line x1="460" y1="705" x2="510" y2="705" class="boundary-line"/>
<line x1="420" y1="715" x2="450" y2="715" class="boundary-line"/>
<line x1="460" y1="715" x2="510" y2="715" class="boundary-line"/>
<!-- Main outline -->
<rect x="420" y="675" width="90" height="60" fill="none" stroke="#444" stroke-width="2"/>
<text x="465" y="755" class="small-text" text-anchor="middle">Scaling factors (1 per 2D block)</text>
</g>
</svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 900 640">
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.title { font: bold 16px sans-serif; fill: #333; text-anchor: middle; }
.label { font: 14px sans-serif; fill: #333; text-anchor: middle; }
.small-text { font: 12px sans-serif; fill: #555; }
.dots-text { font: bold 24px sans-serif; fill: #333; text-anchor: middle; }
/* Tensor colors */
.fp8-block { fill: #87CEEB; stroke: #555; stroke-width: 1.5; }
.fp8-block-alt { fill: #5F9FCC; stroke: #555; stroke-width: 1.5; }
</style>
</defs>
<!-- Section title for 1D -->
<text x="450" y="25" class="title" style="font-size: 18px; font-weight: bold;">1D Blockwise Scaling</text>
<!-- LEFT SIDE: Original 1D Blockwise (Rowwise Quantization) -->
<g id="rowwise-quantization">
<text x="225" y="50" class="title">Rowwise Quantization</text>
<!-- FP8 Tensor with horizontal stripes -->
<g id="left-tensor">
<!-- White backgrounds for dots areas - cross pattern -->
<rect x="225.0" y="100.0" width="40" height="120" fill="#FFFFFF" stroke="none"/>
<rect x="105.0" y="150.0" width="240" height="30" fill="#FFFFFF" stroke="none"/>
<!-- Horizontal blocks (40×10 each) - rows of alternating colors -->
<!-- Top section (before horizontal gap) -->
<rect x="105" y="100" width="40" height="10" class="fp8-block"/>
<rect x="145" y="100" width="40" height="10" class="fp8-block-alt"/>
<rect x="185" y="100" width="40" height="10" class="fp8-block"/>
<rect x="265" y="100" width="40" height="10" class="fp8-block"/>
<rect x="305" y="100" width="40" height="10" class="fp8-block-alt"/>
<rect x="105" y="110" width="40" height="10" class="fp8-block-alt"/>
<rect x="145" y="110" width="40" height="10" class="fp8-block"/>
<rect x="185" y="110" width="40" height="10" class="fp8-block-alt"/>
<rect x="265" y="110" width="40" height="10" class="fp8-block-alt"/>
<rect x="305" y="110" width="40" height="10" class="fp8-block"/>
<rect x="105" y="120" width="40" height="10" class="fp8-block"/>
<rect x="145" y="120" width="40" height="10" class="fp8-block-alt"/>
<rect x="185" y="120" width="40" height="10" class="fp8-block"/>
<rect x="265" y="120" width="40" height="10" class="fp8-block"/>
<rect x="305" y="120" width="40" height="10" class="fp8-block-alt"/>
<rect x="105" y="130" width="40" height="10" class="fp8-block-alt"/>
<rect x="145" y="130" width="40" height="10" class="fp8-block"/>
<rect x="185" y="130" width="40" height="10" class="fp8-block-alt"/>
<rect x="265" y="130" width="40" height="10" class="fp8-block-alt"/>
<rect x="305" y="130" width="40" height="10" class="fp8-block"/>
<rect x="105" y="140" width="40" height="10" class="fp8-block"/>
<rect x="145" y="140" width="40" height="10" class="fp8-block-alt"/>
<rect x="185" y="140" width="40" height="10" class="fp8-block"/>
<rect x="265" y="140" width="40" height="10" class="fp8-block"/>
<rect x="305" y="140" width="40" height="10" class="fp8-block-alt"/>
<!-- Three dots in VERTICAL white bar -->
<text x="245" y="127.5" class="dots-text"></text>
<text x="245" y="202.5" class="dots-text"></text>
<!-- Three dots in HORIZONTAL white bar -->
<text x="165" y="165" class="dots-text"></text>
<text x="305" y="165" class="dots-text"></text>
<!-- ONE diagonal dot at intersection -->
<text x="245" y="165" class="dots-text" transform="rotate(45 245 165)"></text>
<!-- Bottom section (after horizontal gap) -->
<rect x="105" y="180" width="40" height="10" class="fp8-block"/>
<rect x="145" y="180" width="40" height="10" class="fp8-block-alt"/>
<rect x="185" y="180" width="40" height="10" class="fp8-block"/>
<rect x="265" y="180" width="40" height="10" class="fp8-block"/>
<rect x="305" y="180" width="40" height="10" class="fp8-block-alt"/>
<rect x="105" y="190" width="40" height="10" class="fp8-block-alt"/>
<rect x="145" y="190" width="40" height="10" class="fp8-block"/>
<rect x="185" y="190" width="40" height="10" class="fp8-block-alt"/>
<rect x="265" y="190" width="40" height="10" class="fp8-block-alt"/>
<rect x="305" y="190" width="40" height="10" class="fp8-block"/>
<rect x="105" y="200" width="40" height="10" class="fp8-block"/>
<rect x="145" y="200" width="40" height="10" class="fp8-block-alt"/>
<rect x="185" y="200" width="40" height="10" class="fp8-block"/>
<rect x="265" y="200" width="40" height="10" class="fp8-block"/>
<rect x="305" y="200" width="40" height="10" class="fp8-block-alt"/>
<rect x="105" y="210" width="40" height="10" class="fp8-block-alt"/>
<rect x="145" y="210" width="40" height="10" class="fp8-block"/>
<rect x="185" y="210" width="40" height="10" class="fp8-block-alt"/>
<rect x="265" y="210" width="40" height="10" class="fp8-block-alt"/>
<rect x="305" y="210" width="40" height="10" class="fp8-block"/>
<!-- Main outline -->
<rect x="105.0" y="100.0" width="240" height="120" fill="none" stroke="#444" stroke-width="2"/>
</g>
</g>
<!-- RIGHT SIDE: Transposed (Columnwise Quantization) -->
<g id="columnwise-quantization">
<text x="625" y="50" class="title">Columnwise Quantization</text>
<!-- FP8 Tensor - transposed shape (120 wide × 240 tall) with HORIZONTAL stripes -->
<g id="right-tensor">
<!-- White backgrounds for dots areas - cross pattern -->
<rect x="645.0" y="100.0" width="40" height="240" fill="#FFFFFF" stroke="none"/>
<rect x="565.0" y="260.0" width="120" height="30" fill="#FFFFFF" stroke="none"/>
<!-- Horizontal stripes 40×10 (same as rowwise) -->
<!-- Top section (before horizontal gap) - 16 rows of 10px each = 160px -->
<rect x="565" y="100" width="40" height="10" class="fp8-block"/>
<rect x="605" y="100" width="40" height="10" class="fp8-block"/>
<rect x="565" y="110" width="40" height="10" class="fp8-block-alt"/>
<rect x="605" y="110" width="40" height="10" class="fp8-block-alt"/>
<rect x="565" y="120" width="40" height="10" class="fp8-block"/>
<rect x="605" y="120" width="40" height="10" class="fp8-block"/>
<rect x="565" y="130" width="40" height="10" class="fp8-block-alt"/>
<rect x="605" y="130" width="40" height="10" class="fp8-block-alt"/>
<rect x="565" y="140" width="40" height="10" class="fp8-block"/>
<rect x="605" y="140" width="40" height="10" class="fp8-block"/>
<rect x="565" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="605" y="150" width="40" height="10" class="fp8-block-alt"/>
<rect x="565" y="160" width="40" height="10" class="fp8-block"/>
<rect x="605" y="160" width="40" height="10" class="fp8-block"/>
<rect x="565" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="605" y="170" width="40" height="10" class="fp8-block-alt"/>
<rect x="565" y="180" width="40" height="10" class="fp8-block"/>
<rect x="605" y="180" width="40" height="10" class="fp8-block"/>
<rect x="565" y="190" width="40" height="10" class="fp8-block-alt"/>
<rect x="605" y="190" width="40" height="10" class="fp8-block-alt"/>
<rect x="565" y="200" width="40" height="10" class="fp8-block"/>
<rect x="605" y="200" width="40" height="10" class="fp8-block"/>
<rect x="565" y="210" width="40" height="10" class="fp8-block-alt"/>
<rect x="605" y="210" width="40" height="10" class="fp8-block-alt"/>
<rect x="565" y="220" width="40" height="10" class="fp8-block"/>
<rect x="605" y="220" width="40" height="10" class="fp8-block"/>
<rect x="565" y="230" width="40" height="10" class="fp8-block-alt"/>
<rect x="605" y="230" width="40" height="10" class="fp8-block-alt"/>
<rect x="565" y="240" width="40" height="10" class="fp8-block"/>
<rect x="605" y="240" width="40" height="10" class="fp8-block"/>
<rect x="565" y="250" width="40" height="10" class="fp8-block-alt"/>
<rect x="605" y="250" width="40" height="10" class="fp8-block-alt"/>
<!-- Three dots in VERTICAL white bar -->
<text x="665" y="200" class="dots-text"></text>
<text x="665" y="330" class="dots-text"></text>
<!-- Three dots in HORIZONTAL white bar -->
<text x="605" y="275" class="dots-text"></text>
<!-- ONE diagonal dot at intersection -->
<text x="665" y="275" class="dots-text" transform="rotate(45 665 275)"></text>
<!-- Bottom section (after horizontal gap) - 5 rows of 10px each = 50px -->
<rect x="565" y="290" width="40" height="10" class="fp8-block"/>
<rect x="605" y="290" width="40" height="10" class="fp8-block"/>
<rect x="565" y="300" width="40" height="10" class="fp8-block-alt"/>
<rect x="605" y="300" width="40" height="10" class="fp8-block-alt"/>
<rect x="565" y="310" width="40" height="10" class="fp8-block"/>
<rect x="605" y="310" width="40" height="10" class="fp8-block"/>
<rect x="565" y="320" width="40" height="10" class="fp8-block-alt"/>
<rect x="605" y="320" width="40" height="10" class="fp8-block-alt"/>
<rect x="565" y="330" width="40" height="10" class="fp8-block"/>
<rect x="605" y="330" width="40" height="10" class="fp8-block"/>
<!-- Main outline -->
<rect x="565.0" y="100.0" width="120" height="240" fill="none" stroke="#444" stroke-width="2"/>
</g>
</g>
<!-- SECTION 2: 2D Blockwise Scaling -->
<!-- Section title for 2D -->
<text x="450" y="380" class="title" style="font-size: 18px; font-weight: bold;">2D Blockwise Scaling</text>
<!-- LEFT SIDE: Original 2D Blockwise (copied from combined_scaling.svg) -->
<g id="2d-original">
<text x="280" y="405" class="title">Rowwise Quantization</text>
<!-- TOP: DATA TENSOR (20x20 blocks, with 3 extra columns on right) -->
<g id="data-tensor-left">
<!-- Background for entire tensor -->
<rect x="190" y="445" width="180" height="120" fill="#87CEEB" stroke="#444" stroke-width="2"/>
<!-- White space for gaps (cross pattern) -->
<rect x="250" y="445" width="20" height="120" fill="#FFFFFF" stroke="none"/>
<rect x="190" y="505" width="180" height="20" fill="#FFFFFF" stroke="none"/>
<!-- Grid Lines (every 20px) -->
<!-- Vertical Lines Left (x=210, 230) -->
<line x1="210" y1="445" x2="210" y2="505" stroke="#444" stroke-width="1"/>
<line x1="230" y1="445" x2="230" y2="505" stroke="#444" stroke-width="1"/>
<line x1="210" y1="525" x2="210" y2="565" stroke="#444" stroke-width="1"/>
<line x1="230" y1="525" x2="230" y2="565" stroke="#444" stroke-width="1"/>
<!-- Vertical Lines Right (x=290, 310, 330, 350) -->
<line x1="290" y1="445" x2="290" y2="505" stroke="#444" stroke-width="1"/>
<line x1="290" y1="525" x2="290" y2="565" stroke="#444" stroke-width="1"/>
<line x1="310" y1="445" x2="310" y2="505" stroke="#444" stroke-width="1"/>
<line x1="310" y1="525" x2="310" y2="565" stroke="#444" stroke-width="1"/>
<line x1="330" y1="445" x2="330" y2="505" stroke="#444" stroke-width="1"/>
<line x1="330" y1="525" x2="330" y2="565" stroke="#444" stroke-width="1"/>
<line x1="350" y1="445" x2="350" y2="505" stroke="#444" stroke-width="1"/>
<line x1="350" y1="525" x2="350" y2="565" stroke="#444" stroke-width="1"/>
<!-- Horizontal Lines Top (y=465, 485) -->
<line x1="190" y1="465" x2="250" y2="465" stroke="#444" stroke-width="1"/>
<line x1="270" y1="465" x2="370" y2="465" stroke="#444" stroke-width="1"/>
<line x1="190" y1="485" x2="250" y2="485" stroke="#444" stroke-width="1"/>
<line x1="270" y1="485" x2="370" y2="485" stroke="#444" stroke-width="1"/>
<!-- Horizontal Lines Bottom (y=545) -->
<line x1="190" y1="545" x2="250" y2="545" stroke="#444" stroke-width="1"/>
<line x1="270" y1="545" x2="370" y2="545" stroke="#444" stroke-width="1"/>
<!-- Dots / Ellipses -->
<!-- Horizontal dots in gap -->
<text x="260" y="472" class="dots-text" style="font-size: 14px;"></text>
<text x="260" y="552" class="dots-text" style="font-size: 14px;"></text>
<!-- Vertical dots in gap -->
<text x="220" y="517" class="dots-text" style="font-size: 14px;"></text>
<text x="340" y="517" class="dots-text" style="font-size: 14px;"></text>
<!-- Diagonal dot -->
<text x="260" y="517" class="dots-text" style="font-size: 14px;" transform="rotate(45 260 517)"></text>
<!-- Boundaries around white spaces (excluding center intersection) -->
<!-- Vertical boundaries - broken at horizontal white space -->
<line x1="250" y1="445" x2="250" y2="505" stroke="#444" stroke-width="2"/>
<line x1="250" y1="525" x2="250" y2="565" stroke="#444" stroke-width="2"/>
<line x1="270" y1="445" x2="270" y2="505" stroke="#444" stroke-width="2"/>
<line x1="270" y1="525" x2="270" y2="565" stroke="#444" stroke-width="2"/>
<!-- Horizontal boundaries - broken at vertical white space -->
<line x1="190" y1="505" x2="250" y2="505" stroke="#444" stroke-width="2"/>
<line x1="270" y1="505" x2="370" y2="505" stroke="#444" stroke-width="2"/>
<line x1="190" y1="525" x2="250" y2="525" stroke="#444" stroke-width="2"/>
<line x1="270" y1="525" x2="370" y2="525" stroke="#444" stroke-width="2"/>
<!-- Main outline -->
<rect x="190" y="445" width="180" height="120" fill="none" stroke="#444" stroke-width="2"/>
</g>
</g>
<!-- RIGHT SIDE: Transposed 2D Blockwise -->
<g id="2d-transposed">
<text x="605" y="405" class="title">Columnwise Quantization</text>
<!-- DATA TENSOR TRANSPOSED (120x180 instead of 180x120) -->
<g id="data-tensor-right">
<!-- Background for entire tensor -->
<rect x="545" y="435" width="120" height="180" fill="#87CEEB" stroke="#444" stroke-width="2"/>
<!-- White space for gaps (cross pattern) - TRANSPOSED -->
<!-- Original: X structure (180): 60 + 20 + 100 → Y structure (180): 60 + 20 + 100 -->
<!-- Original: Y structure (120): 60 + 20 + 40 → X structure (120): 60 + 20 + 40 -->
<rect x="545" y="495" width="120" height="20" fill="#FFFFFF" stroke="none"/>
<rect x="605" y="435" width="20" height="180" fill="#FFFFFF" stroke="none"/>
<!-- Grid Lines (every 20px) - TRANSPOSED -->
<!-- Original vertical lines at x=210, 230 become horizontal at y=455, 475 -->
<line x1="545" y1="455" x2="605" y2="455" stroke="#444" stroke-width="1"/>
<line x1="625" y1="455" x2="665" y2="455" stroke="#444" stroke-width="1"/>
<line x1="545" y1="475" x2="605" y2="475" stroke="#444" stroke-width="1"/>
<line x1="625" y1="475" x2="665" y2="475" stroke="#444" stroke-width="1"/>
<!-- Original vertical lines at x=290, 310, 330, 350 become horizontal at y=535, 555, 575, 595 -->
<line x1="545" y1="535" x2="605" y2="535" stroke="#444" stroke-width="1"/>
<line x1="625" y1="535" x2="665" y2="535" stroke="#444" stroke-width="1"/>
<line x1="545" y1="555" x2="605" y2="555" stroke="#444" stroke-width="1"/>
<line x1="625" y1="555" x2="665" y2="555" stroke="#444" stroke-width="1"/>
<line x1="545" y1="575" x2="605" y2="575" stroke="#444" stroke-width="1"/>
<line x1="625" y1="575" x2="665" y2="575" stroke="#444" stroke-width="1"/>
<line x1="545" y1="595" x2="605" y2="595" stroke="#444" stroke-width="1"/>
<line x1="625" y1="595" x2="665" y2="595" stroke="#444" stroke-width="1"/>
<!-- Original horizontal lines at y=465, 485 become vertical at x=565, 585 -->
<line x1="565" y1="435" x2="565" y2="495" stroke="#444" stroke-width="1"/>
<line x1="565" y1="515" x2="565" y2="615" stroke="#444" stroke-width="1"/>
<line x1="585" y1="435" x2="585" y2="495" stroke="#444" stroke-width="1"/>
<line x1="585" y1="515" x2="585" y2="615" stroke="#444" stroke-width="1"/>
<!-- Original horizontal line at y=545 becomes vertical at x=605, 625, 645 -->
<line x1="605" y1="435" x2="605" y2="495" stroke="#444" stroke-width="1"/>
<line x1="605" y1="515" x2="605" y2="615" stroke="#444" stroke-width="1"/>
<line x1="625" y1="435" x2="625" y2="495" stroke="#444" stroke-width="1"/>
<line x1="625" y1="515" x2="625" y2="615" stroke="#444" stroke-width="1"/>
<line x1="645" y1="435" x2="645" y2="495" stroke="#444" stroke-width="1"/>
<line x1="645" y1="515" x2="645" y2="615" stroke="#444" stroke-width="1"/>
<!-- Dots / Ellipses - TRANSPOSED -->
<!-- Original: horizontal dots at (260, 472) and (260, 552) in vertical gap -->
<!-- Offsets: (70, 27) and (70, 107) → transposed to (27+545, 70+435) and (107+545, 70+435) -->
<text x="572" y="505" class="dots-text" style="font-size: 14px;"></text>
<text x="652" y="505" class="dots-text" style="font-size: 14px;"></text>
<!-- Original: vertical dots at (220, 517) and (340, 517) in horizontal gap -->
<!-- Offsets: (30, 72) and (150, 72) → transposed to (72+545, 30+435) and (72+545, 150+435) -->
<text x="617" y="465" class="dots-text" style="font-size: 14px;"></text>
<text x="617" y="585" class="dots-text" style="font-size: 14px;"></text>
<!-- Diagonal dot at (260, 517) → offset (70, 72) → transposed to (72+545, 70+435) -->
<text x="617" y="505" class="dots-text" style="font-size: 14px;" transform="rotate(45 617 505)"></text>
<!-- Boundaries around white spaces - TRANSPOSED -->
<!-- Original vertical boundaries (x=250, x=270) become horizontal boundaries (y=495, y=515) -->
<line x1="545" y1="495" x2="605" y2="495" stroke="#444" stroke-width="2"/>
<line x1="625" y1="495" x2="665" y2="495" stroke="#444" stroke-width="2"/>
<line x1="545" y1="515" x2="605" y2="515" stroke="#444" stroke-width="2"/>
<line x1="625" y1="515" x2="665" y2="515" stroke="#444" stroke-width="2"/>
<!-- Original horizontal boundaries (y=505, y=525) become vertical boundaries (x=605, x=625) -->
<line x1="605" y1="435" x2="605" y2="495" stroke="#444" stroke-width="2"/>
<line x1="605" y1="515" x2="605" y2="615" stroke="#444" stroke-width="2"/>
<line x1="625" y1="435" x2="625" y2="495" stroke="#444" stroke-width="2"/>
<line x1="625" y1="515" x2="625" y2="615" stroke="#444" stroke-width="2"/>
<!-- Main outline -->
<rect x="545" y="435" width="120" height="180" fill="none" stroke="#444" stroke-width="2"/>
</g>
</g>
</svg>
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
# Check for Hopper or newer GPU
major, minor = torch.cuda.get_device_capability()
assert major >= 9, f"FP8 Blockwise Scaling requires SM90 (Hopper) or later, got SM{major}{minor}"
# START_BLOCKWISE_SCALING_EXAMPLE
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8BlockScaling
# Create FP8 Blockwise Scaling recipe
recipe = Float8BlockScaling(
fp8_format=te.common.recipe.Format.E4M3, # E4M3 or HYBRID (default: E4M3)
x_block_scaling_dim=1, # 1D scaling for activations (default: 1)
w_block_scaling_dim=2, # 2D scaling for weights (default: 2)
grad_block_scaling_dim=1, # 1D scaling for gradients (default: 1)
)
# Create a linear layer with bfloat16 parameters
layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
# Forward and backward pass
inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda")
with te.autocast(enabled=True, recipe=recipe):
output = layer(inp)
loss = output.sum()
loss.backward()
# END_BLOCKWISE_SCALING_EXAMPLE
..
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
FP8 Current Scaling
===================================
FP8 current scaling recipe is the simplest low precision recipe provided by Transformer Engine.
To understand how this recipe works, we first need to examine what the FP8 data type is and how it differs from other floating point formats.
FP8 data type
-------------
The FP8 datatype, introduced in Hopper architecture, is actually 2 distinct datatypes, useful in different parts of the training of neural networks:
* E4M3 -- consists of 1 sign bit, 4 exponent bits and 3 bits of mantissa. It can store values up to +/-448 and ``nan``.
* E5M2 -- consists of 1 sign bit, 5 exponent bits and 2 bits of mantissa. It can store values up to +/-57344, +/- ``inf`` and ``nan``. The tradeoff of the increased dynamic range is lower precision of the stored values.
.. raw:: html
:file: img/fp8_formats.svg
*Figure 1: Structure of the floating point datatypes. All of the values shown (in FP16, BF16, FP8 E4M3 and FP8 E5M2) are the closest representations of value 0.3952.*
**E4M3 and E5M2 usage in training**
By default, Transformer Engine uses a hybrid approach:
* *Forward pass* - activations and weights require more precision, so E4M3 datatype is used to store them.
* *Backward pass* - gradients are less susceptible to precision loss but require higher dynamic range, so E5M2 datatype is preferred.
The user can configure this behavior via the ``fp8_format`` parameter of the recipe.
Scaling factors
---------------
Limited dynamic range of FP8 datatype is insufficient for many tensors.
To address this, values in the tensor are scaled. FP8 Current Scaling recipe uses one **FP32** scale factor per tensor. The representation of a tensor element ``x`` in FP8 precision is given by:
.. code-block:: python
x = x_fp8 * s
where
* ``x_fp8`` is the FP8 value (E4M3 or E5M2),
* ``s`` is a global **FP32** scaling factor applied to the entire tensor.
**FP8 Current Scaling quantization**
Let's take a closer look at how quantization to FP8 with scaling factor is implemented in
the FP8 Current Scaling recipe.
.. raw:: html
:file: img/fp8_scaling_concept.svg
*Figure 3: Quantization to FP8 consists of amax (absolute maximum) computation, scaling to fit the FP8 range and casting to the respective FP8 format.*
Quantization to FP8 consists of 3 steps:
1. Computation of the absolute maximum value of the tensor - we refer to it as ``amax``.
2. Applying the scaling factor of ``fp8_max / amax`` to the tensor, to fit it into the FP8 range
3. Casting into the respective FP8 format using *Round To Nearest Even (RTNE)*. Values round to the nearest representable FP8 value. When exactly halfway between two values, rounds to the one with even mantissa to minimize systematic bias.
**Performance analysis**
Quantization is a memory-bound operation that requires reading the tensor twice:
* First read: compute ``amax`` across all elements.
* Second read: apply the scaling factor and cast to FP8.
This is a significant overhead compared to other recipes, which typically require only a single memory read.
.. raw:: html
:file: img/fp8_cast_process.svg
*Figure 4: FP8 quantization with current scaling recipe - two tensor reads are needed, one to compute amax and one to apply the scaling factor and cast to FP8.*
Transpose handling
------------------
*Ada and Hopper*
On Ada and Hopper, the backward pass requires a transposed FP8 tensor.
The columnwise layout is physically different from the rowwise layout, so a transpose operation is needed.
All 3 options from :ref:`Performance Considerations Transpose handling section <handling_transposes>` are supported.
*Blackwell and later*
Blackwell hardware supports multiple GEMM layouts natively, eliminating the need for explicit transposes.
The rowwise and columnwise tensors share the same physical memory layout.
.. figure:: ../performance_considerations/img/hopper_vs_blackwell_layout.svg
:align: center
:alt: Comparison of rowwise and columnwise tensor layouts on Blackwell vs Hopper
*Figure 6: On Blackwell, rowwise and columnwise usages share the same memory layout. On Hopper, columnwise usage requires a physical transpose.*
Distributed training
--------------------
**Quantized all-gather**
FP8 all-gather is supported on all architectures (Ada and later).
**Amax reduction**
Tensors that are gathered across nodes (e.g. input and gradient in sequence parallelism) require amax synchronization before quantization.
Each node computes its local ``amax``, then a reduction produces the global maximum across all nodes.
All nodes use this synchronized amax to compute identical scaling factors, enabling quantized all-gather.
.. raw:: html
:file: img/fp8_current_scaling_all_gather.svg
*Figure 7: Quantization and all-gather flow for FP8 current scaling showing amax computation and synchronization.*
Supported devices
-----------------
Ada and later (SM 8.9+)
Examples
--------
Here's how to use FP8 Current Scaling recipe in PyTorch and JAX:
.. tabs::
.. tab:: PyTorch
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Requires SM89 (Ada) or later
</div>
.. literalinclude:: pytorch_current_scaling_example.py
:language: python
:start-after: # START_CURRENT_SCALING_EXAMPLE
:end-before: # END_CURRENT_SCALING_EXAMPLE
.. tab:: JAX
.. raw:: html
<div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
Requires SM89 (Ada) or later
</div>
.. literalinclude:: jax_current_scaling_example.py
:language: python
:start-after: # START_CURRENT_SCALING_EXAMPLE
:end-before: # END_CURRENT_SCALING_EXAMPLE
----
Developer Notes
---------------
This section contains implementation details that may be useful for developers
but are not required for using FP8 Current Scaling in practice.
All-gather of columnwise tensors
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
On Blackwell and later, rowwise and columnwise tensors share the same memory layout,
so all-gather of columnwise tensors is directly supported.
For Hopper and Ada, all-gather of transposed FP8 tensors is not supported.
The rowwise tensor is gathered first, then transposed to columnwise format.
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment