Commit 063ef88d authored by wenjh's avatar wenjh
Browse files

Merge nv main up to v2.10.0.dev0


Signed-off-by: wenjh's avatarwenjh <wenjh@sugon.com>
parents 91670b05 5624dbb4
......@@ -19,7 +19,7 @@ jobs:
run: |
apt-get update
apt-get install -y git python3.9 pip cudnn9-cuda-12
pip install cmake==3.21.0 pybind11[global] ninja
pip install cmake==3.21.0 pybind11[global] ninja nvidia-mathdx==25.1.1
- name: 'Checkout'
uses: actions/checkout@v3
with:
......@@ -43,7 +43,7 @@ jobs:
run: |
apt-get update
apt-get install -y git python3.9 pip cudnn9-cuda-12
pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript
pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript nvidia-mathdx==25.1.1
- name: 'Checkout'
uses: actions/checkout@v3
with:
......@@ -63,7 +63,7 @@ jobs:
options: --user root
steps:
- name: 'Dependencies'
run: pip install pybind11[global]
run: pip install pybind11[global] nvidia-mathdx==25.1.1
- name: 'Checkout'
uses: actions/checkout@v3
with:
......@@ -83,7 +83,7 @@ jobs:
options: --user root
steps:
- name: 'Dependencies'
run: pip install torch pybind11[global] einops onnxscript
run: pip install torch pybind11[global] einops onnxscript nvidia-mathdx==25.1.1
- name: 'Checkout'
uses: actions/checkout@v3
with:
......
......@@ -38,3 +38,9 @@ repos:
entry: clang-format -i
args: ["-style=file"]
files: ^transformer_engine.*\.(c|cc|cxx|cpp|cu|cuh|h|hpp)$
- repo: https://github.com/netromdk/vermin
rev: c75aca72f4e85c6e47252139e8695f1c8b5f9ae3
hooks:
- id: vermin
args: ['-t=3.10', '--violations']
......@@ -12,6 +12,14 @@ Transformer Engine
Latest News
===========
* [09/2025] `Pretraining Large Language Models with NVFP4 <https://www.arxiv.org/pdf/2509.25149>`_
* [09/2025] `Native FP8 Mixed Precision Training for Ling 2.0, Open Sourced! <https://huggingface.co/blog/im0qianqian/ling-mini-2-fp8-mixed-precision-training-solution>`_
* [09/2025] `Faster Training Throughput in FP8 Precision with NVIDIA NeMo <https://developer.nvidia.com/blog/faster-training-throughput-in-fp8-precision-with-nvidia-nemo/>`_
* [08/2025] `How we built DeepL's next-generation LLMs with FP8 for training and inference <https://www.deepl.com/en/blog/tech/next-generation-llm-fp8-training>`_
* [08/2025] `NVFP4 Trains with Precision of 16-bit and Speed and Efficiency of 4-bit <https://developer.nvidia.com/blog/nvfp4-trains-with-precision-of-16-bit-and-speed-and-efficiency-of-4-bit/>`_
* [06/2025] `Floating Point 8: An Introduction to Efficient, Lower-Precision AI Training <https://developer.nvidia.com/blog/floating-point-8-an-introduction-to-efficient-lower-precision-ai-training/>`_
* [05/2025] `Advanced Optimization Strategies for LLM Training on NVIDIA Grace Hopper <https://developer.nvidia.com/blog/advanced-optimization-strategies-for-llm-training-on-nvidia-grace-hopper/>`_
* [03/2025] `Stable and Scalable FP8 Deep Learning Training on Blackwell | GTC 2025 <https://www.nvidia.com/en-us/on-demand/session/gtc25-s72778/>`_
* [03/2025] `Measure and Improve AI Workload Performance with NVIDIA DGX Cloud Benchmarking <https://developer.nvidia.com/blog/measure-and-improve-ai-workload-performance-with-nvidia-dgx-cloud-benchmarking/>`_
......@@ -86,7 +94,7 @@ PyTorch
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)
# Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
with te.autocast(enabled=True, recipe=fp8_recipe):
out = model(inp)
loss = out.sum()
......@@ -121,7 +129,7 @@ Flax
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)
# Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
with te.autocast(enabled=True, recipe=fp8_recipe):
model = te_flax.DenseGeneral(features=HIDDEN)
def loss_fn(params, other_vars, inp):
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import argparse
import torch
import pandas as pd
import torch.utils.benchmark as benchmark
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
import transformer_engine.pytorch.cpp_extensions as ext
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
scale_padding_to = 1
permute_scale = False
TORCH_TO_TE_FLOAT_MAP = {
torch.bfloat16: tex.DType.kBFloat16,
}
def run_kernel(shape, stochastic_rounding: bool, input_dtype=torch.bfloat16):
# Generate random input data
M, K = shape
x = torch.randn([M, K], dtype=input_dtype, device="cuda")
assert shape[0] % 16 == 0, "Shape must be divisible by 16"
assert shape[1] % 16 == 0, "Shape must be divisible by 16"
# Quantize
nvfp4_quantizer = NVFP4Quantizer(
fp4_dtype=tex.DType.kFloat4E2M1,
rowwise=True,
columnwise=True,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=True,
with_post_rht_amax=True,
with_random_sign_mask=True,
stochastic_rounding=stochastic_rounding,
)
x_nvfp4_sut = nvfp4_quantizer.make_empty(
(M, K), dtype=x.dtype, device=x.device, requires_grad=False
)
x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)
with torch.no_grad():
stmt = "kernel_func(input, output)"
globals_dict = {
"kernel_func": nvfp4_quantizer.update_quantized,
"input": x,
"output": x_nvfp4_sut,
}
timing = benchmark.Timer(
stmt=stmt,
globals=globals_dict,
num_threads=1,
).blocked_autorange(min_run_time=5)
print(timing)
timing_us = timing.median * 1e6
input_nbytes = shape[0] * shape[1] * 2 # bf16
output_nbytes = shape[0] * shape[1] // 2 # //2 for fp4
sf_nbytes = shape[0] * shape[1] // 16 # //16 for 1 byte per 16 elems
total_nbytes = (
0
+ input_nbytes
* 3 # Reading input for Amax(x)&Amax(RHT(x.T)), Reading input for Cast(x), Reaindg input for Cast(RHT(x.T))
+ 2 * 4 # Output 2 * float for scale & amax
+ 2 * 4 # Input 2 * float
+ output_nbytes * 2 # Output from Cast(x) and Cast(RHT(x.T))
+ sf_nbytes * 2 # Scale factor
)
throughput_GBps = total_nbytes / (1024 * 1024 * 1024) / (timing_us / 1e6)
print(
f"Stochastic rounding: {stochastic_rounding}, Total: {total_nbytes} bytes, Throughput:"
f" {throughput_GBps} GB/s"
)
return timing_us, throughput_GBps
# Nsight Compute Profiling Command:
# ncu -f -o block_scaled_1d_cast_transpose_kernel --set=full --kernel-name "block_scaled_1d_cast_transpose_kernel" -s 5 -c 5 python benchmark_cast_transpose_1d_block.py --profile
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--profile", action="store_true", help="Enable profiling mode")
args = parser.parse_args()
if args.profile:
print("Profiling is enabled.")
else:
print("Profiling is disabled.")
shapes = [
(8192, 5120),
(8192, 10240),
(8192, 2560),
(8192, 11328),
(8192, 512),
(8192, 3584),
(5120, 8192),
(10240, 8192),
(2560, 8192),
(11328, 8192),
(512, 8192),
(3584, 8192),
(4096, 16384),
(14336, 16384),
]
if args.profile:
shapes = [
(16384, 6144),
]
data = []
for stochastic_rounding in [True]: # , False]:
for shape in shapes:
print(
f"Running benchmark_func with shape {shape} and stochastic_rounding"
f" {stochastic_rounding}"
)
timing_us, throughput_GBps = run_kernel(shape, stochastic_rounding)
data.append(
[
"benchmark_func",
shape,
stochastic_rounding,
timing_us,
throughput_GBps,
]
)
df = pd.DataFrame(
data=data,
columns=[
"kernel",
"shape",
"stochastic_rounding",
"timing_us",
"throughput(GB/s)",
],
)
print(df)
df.to_csv("benchmark_cast_nvfp4.csv", index=False)
......@@ -6,11 +6,10 @@ import argparse
import torch
import torch.utils.benchmark as benchmark
import pandas as pd
import pathlib
from transformer_engine.pytorch.module import GroupedLinear
from transformer_engine.common.recipe import Float8BlockScaling, MXFP8BlockScaling
from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
from transformer_engine.pytorch.quantization import autocast, FP8GlobalStateManager
from contextlib import nullcontext
"""
......@@ -51,9 +50,7 @@ fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
def run_linear_multiple_steps(layer, x, m_splits, mode, gradient, run_num_steps=1, recipe=None):
assert mode in ["fwd_only", "fwd_bwd"]
fp8_context = (
fp8_autocast(enabled=True, fp8_recipe=recipe) if recipe is not None else nullcontext()
)
fp8_context = autocast(enabled=True, fp8_recipe=recipe) if recipe is not None else nullcontext()
# print(f"fp8_context: {fp8_context} and is it nullcontext? {isinstance(fp8_context, nullcontext)}")
if mode == "fwd_only":
......
......@@ -59,6 +59,7 @@ class CMakeExtension(setuptools.Extension):
build_dir,
f"-DPython_EXECUTABLE={sys.executable}",
f"-DPython_INCLUDE_DIR={sysconfig.get_path('include')}",
f"-DPython_SITEARCH={sysconfig.get_path('platlib')}",
f"-DCMAKE_BUILD_TYPE={build_type}",
f"-DCMAKE_INSTALL_PREFIX={install_dir}",
]
......
......@@ -87,4 +87,5 @@ def setup_jax_extension(
sources=[str(path) for path in sources],
include_dirs=[str(path) for path in include_dirs],
extra_compile_args=cxx_flags,
libraries=["nccl"],
)
......@@ -19,7 +19,7 @@ def install_requirements() -> List[str]:
def test_requirements() -> List[str]:
"""Test dependencies for TE/JAX extensions."""
return ["numpy", "torchvision", "transformers"]
return ["numpy", "torchvision", "transformers", "torchao==0.13"]
def setup_pytorch_extension(
......
......@@ -12,12 +12,31 @@ import re
import shutil
import subprocess
import sys
import platform
from pathlib import Path
from importlib.metadata import version as get_version
from subprocess import CalledProcessError
from typing import List, Optional, Tuple, Union
# Needs to stay consistent with .pre-commit-config.yaml config.
def min_python_version() -> Tuple[int]:
"""Minimum supported Python version."""
return (3, 10, 0)
def min_python_version_str() -> str:
"""String representing minimum supported Python version."""
return ".".join(map(str, min_python_version()))
if sys.version_info < min_python_version():
raise RuntimeError(
f"Transformer Engine requires Python {min_python_version_str()} or newer, "
f"but found Python {platform.python_version()}."
)
@functools.lru_cache(maxsize=None)
def debug_build_enabled() -> bool:
"""Whether to build with a debug configuration"""
......@@ -272,15 +291,18 @@ def get_cuda_include_dirs() -> Tuple[str, str]:
@functools.lru_cache(maxsize=None)
def cuda_archs() -> str:
archs = os.getenv("NVTE_CUDA_ARCHS")
if archs is None:
version = cuda_version()
if os.getenv("NVTE_CUDA_ARCHS") is None:
if version >= (13, 0):
os.environ["NVTE_CUDA_ARCHS"] = "75;80;89;90;100;120"
archs = "75;80;89;90;100;100a;103a;120"
elif version >= (12, 9):
archs = "70;80;89;90;100;100a;103a;120"
elif version >= (12, 8):
os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90;100;120"
archs = "70;80;89;90;100;100a;120"
else:
os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90"
return os.getenv("NVTE_CUDA_ARCHS")
archs = "70;80;89;90"
return archs
def cuda_version() -> Tuple[int, ...]:
......
......@@ -12,6 +12,8 @@ Common API
.. autoapiclass:: transformer_engine.common.recipe.MXFP8BlockScaling(fp8_format=Format.E4M3)
.. autoapiclass:: transformer_engine.common.recipe.NVFP4BlockScaling(fp4_format=Format.E2M1)
.. autoapiclass:: transformer_engine.common.recipe.Float8CurrentScaling(fp8_format=Format.HYBRID)
.. autoapiclass:: transformer_engine.common.recipe.Float8BlockScaling(fp8_format=Format.E4M3)
......@@ -30,6 +30,7 @@ Modules
.. autoapifunction:: transformer_engine.jax.fp8_autocast
.. autoapifunction:: transformer_engine.jax.autocast
.. autoapifunction:: transformer_engine.jax.update_collections
......
......@@ -41,8 +41,28 @@ pyTorch
.. autoapifunction:: transformer_engine.pytorch.fp8_model_init
.. autoapifunction:: transformer_engine.pytorch.autocast
.. autoapifunction:: transformer_engine.pytorch.quantized_model_init
.. autoapifunction:: transformer_engine.pytorch.checkpoint
.. autoapifunction:: transformer_engine.pytorch.is_fp8_available
.. autoapifunction:: transformer_engine.pytorch.is_mxfp8_available
.. autoapifunction:: transformer_engine.pytorch.is_fp8_block_scaling_available
.. autoapifunction:: transformer_engine.pytorch.is_nvfp4_available
.. autoapifunction:: transformer_engine.pytorch.is_bf16_available
.. autoapifunction:: transformer_engine.pytorch.get_cudnn_version
.. autoapifunction:: transformer_engine.pytorch.get_device_compute_capability
.. autoapifunction:: transformer_engine.pytorch.get_default_recipe
.. autoapifunction:: transformer_engine.pytorch.make_graphed_callables
.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context
......
......@@ -69,7 +69,7 @@ Let's look at a simple example of training a Transformer layer using Transformer
for epoch in range(5):
transformer_layer.train()
optimizer.zero_grad()
with te.fp8_autocast(enabled=True):
with te.autocast(enabled=True):
output = transformer_layer(dummy_input)
loss = criterion(output, dummy_target)
loss.backward()
......
......@@ -71,7 +71,7 @@
" amax_compute_algo=\"max\",\n",
")\n",
"# Training step\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
"with te.autocast(enabled=True, recipe=fp8_recipe):\n",
" y = basic_transformer(x, attention_mask=None)\n",
"y.backward(dy)\n",
"\n",
......@@ -81,7 +81,7 @@
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
" fp8_autocast_kwargs = { \"enabled\": True, \"fp8_recipe\": fp8_recipe },\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe },\n",
")"
]
},
......@@ -135,7 +135,7 @@
"\n",
"Enabling data parallelism with Transformer Engine is similar to enabling data parallelism with standard PyTorch models: simply wrap the modules with [torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html). Transformer Engine modules also have native support for tensor and sequence parallelism. If the user provides a process group for tensor parallelism, the modules will distribute the data and perform communication internally. If sequence parallelism is enabled, it will be applied for operations that are not amenable to tensor parallelism and it will use the tensor-parallel process group.\n",
"\n",
"One important consideration for multi-GPU FP8 training is how to synchronize the FP8 scaling factors between GPUs. If tensor parallelism is enabled, the scales must be synchronized over the tensor-parallel group. However, synchronizing over both the data-parallel and tensor-parallel groups is recommended for the best convergence. This can be configured with the **fp8_group** argument in the [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager."
"One important consideration for multi-GPU FP8 training is how to synchronize the FP8 scaling factors between GPUs. If tensor parallelism is enabled, the scales must be synchronized over the tensor-parallel group. However, synchronizing over both the data-parallel and tensor-parallel groups is recommended for the best convergence. This can be configured with the **fp8_group** argument in the [autocast](../api/pytorch.rst#transformer_engine.pytorch.autocast) context manager."
]
},
{
......@@ -169,7 +169,7 @@
")\n",
"\n",
"# Training step\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=world_group):\n",
"with te.autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=world_group):\n",
" y = parallel_transformer(x, attention_mask=None)\n",
"y.backward(dy)\n",
"\n",
......@@ -179,10 +179,10 @@
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
" fp8_autocast_kwargs = {\n",
" autocast_kwargs = {\n",
" \"enabled\": True,\n",
" \"fp8_recipe\": fp8_recipe,\n",
" \"fp8_group\": world_group,\n",
" \"recipe\": fp8_recipe,\n",
" \"amax_reduction_group\": world_group,\n",
" },\n",
")"
]
......@@ -234,7 +234,7 @@
" param.main_grad = torch.zeros_like(param, dtype=torch.float32)\n",
"\n",
"# Training step\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
"with te.autocast(enabled=True, recipe=fp8_recipe):\n",
" y = wgrad_transformer(x, attention_mask=None)\n",
"y.backward(dy)\n",
"for param in wgrad_transformer.parameters():\n",
......@@ -248,7 +248,7 @@
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
" fp8_autocast_kwargs = { \"enabled\": True, \"fp8_recipe\": fp8_recipe },\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe },\n",
")"
]
},
......@@ -268,7 +268,7 @@
"\n",
"</div>\n",
"\n",
"Since weights are typically trained in FP32, a type conversion is required before we can perform compute in FP8. By default, the [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager will handle this internally by casting non-FP8 tensors to FP8 as they are encountered. However, we can improve upon this in some cases. In particular, if our training iteration is split into multiple gradient accumulation steps, each micro-batch will encounter the same weight tensors. Thus, we only need to cast the weights to FP8 in the first gradient accumulation step and we can cache the resulting FP8 weights for the remaining gradient accumulation steps.\n",
"Since weights are typically trained in FP32, a type conversion is required before we can perform compute in FP8. By default, the [autocast](../api/pytorch.rst#transformer_engine.pytorch.autocast) context manager will handle this internally by casting non-FP8 tensors to FP8 as they are encountered. However, we can improve upon this in some cases. In particular, if our training iteration is split into multiple gradient accumulation steps, each micro-batch will encounter the same weight tensors. Thus, we only need to cast the weights to FP8 in the first gradient accumulation step and we can cache the resulting FP8 weights for the remaining gradient accumulation steps.\n",
"\n",
"<div class=\"alert alert-warning\">\n",
"\n",
......@@ -303,12 +303,12 @@
"weight_caching_transformer.to(dtype=dtype).cuda()\n",
"\n",
"# Cast weights in first gradient accumulation step\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
"with te.autocast(enabled=True, recipe=fp8_recipe):\n",
" y = weight_caching_transformer(x, attention_mask=None, is_first_microbatch=True)\n",
"y.backward(dy)\n",
"\n",
"# Reuse FP8 weights in subsequent gradient accumulation steps\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
"with te.autocast(enabled=True, recipe=fp8_recipe):\n",
" y = weight_caching_transformer(x, attention_mask=None, is_first_microbatch=False)\n",
"y.backward(dy)\n",
"\n",
......@@ -318,7 +318,7 @@
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None, \"is_first_microbatch\": False },\n",
" fp8_autocast_kwargs = { \"enabled\": True, \"fp8_recipe\": fp8_recipe },\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe },\n",
")"
]
}
......
......@@ -5,9 +5,9 @@
"id": "7b3e6954",
"metadata": {},
"source": [
"# Using FP8 with Transformer Engine\n",
"# Using FP8 and FP4 with Transformer Engine\n",
"\n",
"H100 GPU introduced support for a new datatype, FP8 (8-bit floating point), enabling higher throughput of matrix multiplies and convolutions. In this example we will introduce the FP8 datatype and show how to use it with Transformer Engine.\n",
"H100 GPU introduced support for a new datatype, FP8 (8-bit floating point), enabling higher throughput of matrix multiplies and convolutions. Blackwell added support for NVFP4 and MXFP8 datatypes. In this example we will introduce these low precision datatypes and show how to use them with Transformer Engine.\n",
"\n",
"## Introduction to FP8\n",
"\n",
......@@ -100,19 +100,66 @@
"</figure>"
]
},
{
"cell_type": "markdown",
"id": "fd7b4f37-50a2-4d41-9067-cf0c471cb2d7",
"metadata": {},
"source": [
"## Beyond FP8 - training with NVFP4\n",
"\n",
"In addition to MXFP8, NVIDIA Blackwell introduced support for an even smaller, 4-bit format called NVFP4. The values are represented there in E2M1 format, able to represent values of magnitude up to +/-6.\n",
"\n",
"<figure align=\"center\" id=\"fig_8\">\n",
"<img src=\"FP4_format.png\" width=\"50%\">\n",
"<figcaption> Figure 8: FP4 E2M1 format can represent values between +/-6.</figcaption>\n",
"</figure>\n",
"\n",
"### NVFP4 Format\n",
"\n",
"NVFP4 format is similar to MXFP8 - it also uses granular scaling to preserve the dynamic range. The differences are:\n",
"\n",
" - Granularity of the scaling factors: in NVFP4 format a single scaling factor is used per block of 16 elements, whereas MXFP8 uses 1 scaling factor per block of 32 elements\n",
" - Datatype of the scaling factors: NVFP4 uses FP8 E4M3 as the scaling factor per block, whereas MXFP8 uses E8M0 as the scaling factor datatype. Choice of E4M3 for the scaling factor enables preservation of more information about mantissa, but does not enable the full dynamic range of FP32. Therefore, NVFP4 uses an additional single per-tensor FP32 scaling factor to avoid overflows.\n",
"\n",
"In the NVFP4 training recipe for weight tensors we use a different variant of the NVFP4 quantization, where a single scaling factor is shared by a 2D block of 16x16 elements. This is similar to the weight quantization scheme employed in [DeepSeek-v3 training](https://arxiv.org/abs/2412.19437v1), but with a much finer granularity.\n",
"\n",
"### NVFP4 training recipe\n",
"\n",
"The NVFP4 training recipe implemented in Transformer Engine is described in [Pretraining Large Language Models with NVFP4](https://arxiv.org/abs/2509.25149v1) paper. The main elements of the recipe are:\n",
"\n",
" - Stochastic Rounding. When quantizing gradients to NVFP4, we use stochastic rounding to avoid the bias introduced by quantization. With stochastic rounding values are rounded probabilistically to one of their two nearest representable numbers, with probabilities inversely\n",
"proportional to their distances.\n",
" - 2D Scaling. The non-square size of the quantization blocks, while increasing granularity, has a property that the quantized tensor and its transpose no longer hold the same values. This is important since the transposed tensors are used when calculating gradients of the linear layers. While most tensors are not sensitive to this issue during training, it does affect the training accuracy when applied to the weight tensors. Therefore, the weights of the linear layers are quantized using a 2D scheme, where a single scaling factor is shared by a 2D block of 16x16 elements.\n",
" - Random Hadamard Transforms. While microscaling reduces the dynamic range needed to represent tensor values, outliers can still have a\n",
"disproportionate impact on FP4 formats, degrading model accuracy. Random Hadamard transforms address this by reshaping the tensor distribution to be more Gaussian-like, which smooths outliers and makes tensors easier to represent accurately in NVFP4. In Transformer Engine, we use a 16x16 Hadamard matrix for activations and gradients when performing weight gradient computation.\n",
" - Last few layers in higher precision. The last few layers of the LLM are more sensitive to the quantization and so we recommend running them in higher precision (for example MXFP8). This is not done automatically in Transformer Engine, since TE does not have the full information about the structure of the network being trained. This can be easily achieved though by modifying the model training code to run the last few layers under a different `autocast` (or nesting 2 autocasts in order to override the recipe for a part of the network).\n",
"\n",
"The full linear layer utilizing NVFP4 is presented in Figure 9.\n",
"\n",
"<figure align=\"center\" id=\"fig_9\">\n",
"<img src=\"FP4_linear.png\" width=\"80%\">\n",
"<figcaption> Figure 9: Linear layer utilizing NVFP4</figcaption>\n",
"</figure>"
]
},
{
"cell_type": "markdown",
"id": "cf5e0b0d",
"metadata": {},
"source": [
"## Using FP8 with Transformer Engine\n",
"## Using FP8 and FP4 with Transformer Engine\n",
"\n",
"Transformer Engine library provides tools enabling easy to use training with FP8 datatype using FP8 delayed scaling and MXFP8 strategies.\n",
"Transformer Engine library provides tools enabling easy to use training with FP8 and FP4 datatypes using different strategies.\n",
"\n",
"### FP8 recipe\n",
"\n",
"The [DelayedScaling](../api/common.rst#transformer_engine.common.recipe.DelayedScaling) recipe from the `transformer_engine.common.recipe` module stores all of the required options for training with FP8 delayed scaling: length of the amax history to use for scaling factor computation, FP8 data format, etc.\n",
"Similarly, [MXFP8BlockScaling](../api/common.rst#transformer_engine.common.recipe.MXFP8BlockScaling) from the same module may be used to enable MXFP8 training."
"Transformer Engine defines a range of different low precision recipes to choose from in the `transformer_engine.common.recipe` module.\n",
"\n",
" - The [DelayedScaling](../api/common.rst#transformer_engine.common.recipe.DelayedScaling) recipe stores all of the required options for training with FP8 delayed scaling: length of the amax history to use for scaling factor computation, FP8 data format, etc.\n",
" - [Float8CurrentScaling](../api/common.rst#transformer_engine.common.recipe.Float8CurrentScaling) recipe enables current per-tensor scaling with FP8.\n",
" - [Float8BlockScaling](../api/common.rst#transformer_engine.common.recipe.Float8BlockScaling) recipe enables block scaling with FP8 as described in [DeepSeek-v3 paper](https://arxiv.org/abs/2412.19437v1).\n",
" - [MXFP8BlockScaling](../api/common.rst#transformer_engine.common.recipe.MXFP8BlockScaling) recipe enables MXFP8 training.\n",
" - [NVFP4BlockScaling](../api/common.rst#transformer_engine.common.recipe.NVFP4BlockScaling) recipe enables NVFP4 training."
]
},
{
......@@ -122,12 +169,13 @@
"metadata": {},
"outputs": [],
"source": [
"from transformer_engine.common.recipe import Format, DelayedScaling, MXFP8BlockScaling\n",
"from transformer_engine.common.recipe import Format, DelayedScaling, MXFP8BlockScaling, NVFP4BlockScaling\n",
"\n",
"fp8_format = Format.HYBRID # E4M3 during forward pass, E5M2 during backward pass\n",
"fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")\n",
"mxfp8_format = Format.E4M3 # E4M3 used everywhere\n",
"mxfp8_recipe = MXFP8BlockScaling(fp8_format=mxfp8_format)"
"mxfp8_recipe = MXFP8BlockScaling(fp8_format=mxfp8_format)\n",
"nvfp4_recipe = NVFP4BlockScaling()"
]
},
{
......@@ -135,7 +183,7 @@
"id": "f9591eb5",
"metadata": {},
"source": [
"This recipe is then used to configure the FP8 training."
"This recipe is then used to configure the low precision training."
]
},
{
......@@ -145,7 +193,7 @@
"source": [
"### FP8 autocasting\n",
"\n",
"Not every operation is safe to be performed using FP8. All of the modules provided by Transformer Engine library were designed to provide maximum performance benefit from FP8 datatype while maintaining accuracy. In order to enable FP8 operations, TE modules need to be wrapped inside the [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager."
"Not every operation is safe to be performed using FP8. All of the modules provided by Transformer Engine library were designed to provide maximum performance benefit from FP8 datatype while maintaining accuracy. In order to enable FP8 operations, TE modules need to be wrapped inside the [autocast](../api/pytorch.rst#transformer_engine.pytorch.autocast) context manager."
]
},
{
......@@ -164,7 +212,7 @@
"\n",
"inp = torch.rand((1024, 768)).cuda()\n",
"\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
"with te.autocast(enabled=True, recipe=fp8_recipe):\n",
" out_fp8 = my_linear(inp)"
]
},
......@@ -173,7 +221,7 @@
"id": "e41161f1",
"metadata": {},
"source": [
"The `fp8_autocast` context manager hides the complexity of handling FP8:\n",
"The `autocast` context manager hides the complexity of handling FP8:\n",
"\n",
"- All FP8-safe operations have their inputs cast to FP8\n",
"- Amax history is updated\n",
......@@ -195,9 +243,9 @@
"source": [
"### Handling backward pass\n",
"\n",
"When a model is run inside the `fp8_autocast` region, especially in multi-GPU training, some communication is required in order to synchronize the scaling factors and amax history. In order to perform that communication without introducing much overhead, `fp8_autocast` context manager aggregates the tensors before performing the communication.\n",
"When a model is run inside the `autocast` region, especially in multi-GPU training, some communication is required in order to synchronize the scaling factors and amax history. In order to perform that communication without introducing much overhead, `autocast` context manager aggregates the tensors before performing the communication.\n",
"\n",
"Due to this aggregation the backward call needs to happen outside of the `fp8_autocast` context manager. It has no impact on the computation precision - the precision of the backward pass is determined by the precision of the forward pass."
"Due to this aggregation the backward call needs to happen outside of the `autocast` context manager. It has no impact on the computation precision - the precision of the backward pass is determined by the precision of the forward pass."
]
},
{
......@@ -209,11 +257,11 @@
"source": [
"loss_fp8 = out_fp8.mean()\n",
"\n",
"loss_fp8.backward() # This backward pass uses FP8, since out_fp8 was calculated inside fp8_autocast\n",
"loss_fp8.backward() # This backward pass uses FP8, since out_fp8 was calculated inside autocast\n",
"\n",
"out_fp32 = my_linear(inp)\n",
"loss_fp32 = out_fp32.mean()\n",
"loss_fp32.backward() # This backward pass does not use FP8, since out_fp32 was calculated outside fp8_autocast"
"loss_fp32.backward() # This backward pass does not use FP8, since out_fp32 was calculated outside autocast"
]
},
{
......@@ -235,13 +283,13 @@
{
"data": {
"text/plain": [
"tensor([[ 0.2276, 0.2627, 0.3001, ..., 0.0346, 0.2211, 0.1188],\n",
" [-0.0963, -0.3725, 0.1717, ..., 0.0901, 0.0522, -0.3472],\n",
" [ 0.4526, 0.3482, 0.5976, ..., -0.0687, -0.0382, 0.1566],\n",
"tensor([[ 0.2276, 0.2629, 0.3000, ..., 0.1297, -0.3702, 0.1807],\n",
" [-0.0963, -0.3724, 0.1717, ..., -0.1250, -0.8501, -0.1669],\n",
" [ 0.4526, 0.3479, 0.5976, ..., 0.1685, -0.8864, -0.1977],\n",
" ...,\n",
" [ 0.1698, 0.6061, 0.0385, ..., -0.2875, -0.1152, -0.0260],\n",
" [ 0.0679, 0.2946, 0.2751, ..., -0.2284, 0.0517, -0.1441],\n",
" [ 0.1865, 0.2353, 0.9172, ..., 0.1085, 0.1135, 0.1438]],\n",
" [ 0.1698, 0.6062, 0.0385, ..., 0.4038, -0.4564, 0.0143],\n",
" [ 0.0679, 0.2947, 0.2750, ..., -0.3271, -0.4990, 0.1198],\n",
" [ 0.1865, 0.2353, 0.9170, ..., 0.0673, -0.5567, 0.1246]],\n",
" device='cuda:0', grad_fn=<_LinearBackward>)"
]
},
......@@ -263,13 +311,13 @@
{
"data": {
"text/plain": [
"tensor([[ 0.2373, 0.2674, 0.2980, ..., 0.0233, 0.2498, 0.1131],\n",
" [-0.0767, -0.3778, 0.1862, ..., 0.0858, 0.0676, -0.3369],\n",
" [ 0.4615, 0.3593, 0.5813, ..., -0.0779, -0.0349, 0.1422],\n",
"tensor([[ 0.2373, 0.2674, 0.2980, ..., 0.1134, -0.3661, 0.1650],\n",
" [-0.0767, -0.3778, 0.1862, ..., -0.1370, -0.8448, -0.1770],\n",
" [ 0.4615, 0.3593, 0.5813, ..., 0.1696, -0.8826, -0.1826],\n",
" ...,\n",
" [ 0.1914, 0.6038, 0.0382, ..., -0.2847, -0.0991, -0.0423],\n",
" [ 0.0864, 0.2895, 0.2719, ..., -0.2388, 0.0772, -0.1541],\n",
" [ 0.2019, 0.2275, 0.9027, ..., 0.1022, 0.1300, 0.1444]],\n",
" [ 0.1914, 0.6038, 0.0382, ..., 0.4049, -0.4729, 0.0118],\n",
" [ 0.0864, 0.2895, 0.2719, ..., -0.3337, -0.4922, 0.1240],\n",
" [ 0.2019, 0.2275, 0.9027, ..., 0.0706, -0.5481, 0.1356]],\n",
" device='cuda:0', grad_fn=<_LinearBackward>)"
]
},
......@@ -300,13 +348,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[ 0.2276, 0.2629, 0.3000, ..., 0.0346, 0.2211, 0.1188],\n",
" [-0.0963, -0.3724, 0.1717, ..., 0.0901, 0.0522, -0.3470],\n",
" [ 0.4526, 0.3479, 0.5976, ..., -0.0686, -0.0382, 0.1566],\n",
"tensor([[ 0.2276, 0.2629, 0.3000, ..., 0.1297, -0.3702, 0.1807],\n",
" [-0.0963, -0.3724, 0.1717, ..., -0.1250, -0.8501, -0.1669],\n",
" [ 0.4526, 0.3479, 0.5976, ..., 0.1685, -0.8864, -0.1977],\n",
" ...,\n",
" [ 0.1698, 0.6062, 0.0385, ..., -0.2876, -0.1152, -0.0260],\n",
" [ 0.0679, 0.2947, 0.2750, ..., -0.2284, 0.0516, -0.1441],\n",
" [ 0.1865, 0.2353, 0.9170, ..., 0.1085, 0.1135, 0.1438]],\n",
" [ 0.1698, 0.6062, 0.0385, ..., 0.4038, -0.4564, 0.0143],\n",
" [ 0.0679, 0.2947, 0.2750, ..., -0.3271, -0.4990, 0.1198],\n",
" [ 0.1865, 0.2353, 0.9170, ..., 0.0673, -0.5567, 0.1246]],\n",
" device='cuda:0', grad_fn=<_LinearBackward>)\n"
]
}
......@@ -339,19 +387,14 @@
{
"data": {
"text/plain": [
"tensor([[ 4.9591e-05, -1.9073e-04, 9.5367e-05, ..., -3.8147e-06,\n",
" 4.1962e-05, 2.2888e-05],\n",
" [ 2.2888e-05, -3.4332e-05, 2.2888e-05, ..., 2.6703e-05,\n",
" 5.3406e-05, -1.4114e-04],\n",
" [-3.8147e-05, 2.6703e-04, -3.8147e-06, ..., -5.7220e-05,\n",
" 4.1962e-05, -1.9073e-05],\n",
"tensor([[0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" ...,\n",
" [ 1.1444e-05, -7.2479e-05, -3.8147e-06, ..., 5.3406e-05,\n",
" -1.5259e-05, 2.2888e-05],\n",
" [ 4.9591e-05, -9.5367e-05, 6.8665e-05, ..., -1.5259e-05,\n",
" 7.6294e-05, 4.5776e-05],\n",
" [-1.5259e-05, -7.6294e-06, 1.8692e-04, ..., -3.0518e-05,\n",
" -4.5776e-05, 7.6294e-06]], device='cuda:0', grad_fn=<SubBackward0>)"
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.]], device='cuda:0',\n",
" grad_fn=<SubBackward0>)"
]
},
"execution_count": 7,
......@@ -370,6 +413,53 @@
"source": [
"The differences in result coming from FP8 execution do not matter during the training process, but it is good to understand them, e.g. during debugging the model."
]
},
{
"cell_type": "markdown",
"id": "d45e8b6c-803b-4a4f-8835-c19b0a94bc6a",
"metadata": {},
"source": [
"### Using multiple recipes in the same training run\n",
"\n",
"Sometimes it is desirable to use multiple recipes in the same training run. An example of this is the NVFP4 training, where a few layers at the end of the training should be run in higher precision. This can be achieved by using multiple autocasts, either completely separately or in a nested way (this could be useful when e.g. we want to have a configurable overarching recipe but still hardcode a different recipe for some pieces of the network)."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "c663f694-41d6-47c0-a397-5fc56e692542",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[ 0.0547, 0.0039, -0.0664, ..., -0.2061, 0.2344, -0.3223],\n",
" [ 0.0131, -0.1436, 0.0168, ..., -0.4258, 0.1562, -0.0371],\n",
" [ 0.1074, -0.2773, 0.0576, ..., -0.2070, 0.0640, -0.1611],\n",
" ...,\n",
" [ 0.0825, -0.0630, 0.0571, ..., -0.3711, 0.1562, -0.4062],\n",
" [-0.1729, -0.1138, -0.0620, ..., -0.4238, 0.0703, -0.2070],\n",
" [-0.0908, -0.2148, 0.2676, ..., -0.4551, 0.1836, -0.4551]],\n",
" device='cuda:0', dtype=torch.bfloat16, grad_fn=<_LinearBackward>)\n"
]
}
],
"source": [
"my_linear1 = te.Linear(768, 768).bfloat16() # The first linear - we want to run it in FP4\n",
"my_linear2 = te.Linear(768, 768).bfloat16() # The second linear - we want to run it in MXFP8\n",
"\n",
"inp = inp.bfloat16()\n",
"\n",
"with te.autocast(recipe=nvfp4_recipe):\n",
" y = my_linear1(inp)\n",
" with te.autocast(recipe=mxfp8_recipe):\n",
" out = my_linear2(y)\n",
"\n",
"print(out)\n",
"\n",
"out.mean().backward()"
]
}
],
"metadata": {
......
......@@ -80,7 +80,7 @@
"model = Model().eval().cuda()\n",
"inps = (torch.randn([S, B, H], device=\"cuda\"),)\n",
"def _inference(fp8_enabled):\n",
" with torch.no_grad(), te.pytorch.fp8_autocast(enabled=fp8_enabled):\n",
" with torch.no_grad(), te.pytorch.autocast(enabled=fp8_enabled):\n",
" model(*inps)\n",
"\n",
"te_fp32_time = _measure_time(lambda: _inference(fp8_enabled=False))\n",
......@@ -138,7 +138,7 @@
"from transformer_engine.pytorch.export import te_translation_table\n",
"\n",
"def export(model, fname, inputs, fp8=True):\n",
" with torch.no_grad(), te.pytorch.fp8_autocast(enabled=fp8):\n",
" with torch.no_grad(), te.pytorch.autocast(enabled=fp8):\n",
" # ! IMPORTANT !\n",
" # Transformer Engine models must have warm-up run\n",
" # before export. FP8 recipe during warm-up should \n",
......
......@@ -548,7 +548,7 @@
"\n",
"</div>\n",
"\n",
"Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager. Note that fp8_autocast should only be used to wrap the forward pass and must exit before starting a backward pass. See the [FP8 tutorial](fp8_primer.ipynb) for a detailed explanation of FP8 recipes and the supported options."
"Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [autocast](../api/pytorch.rst#transformer_engine.pytorch.autocast) context manager. Note that autocast should only be used to wrap the forward pass and must exit before starting a backward pass. See the [FP8 tutorial](fp8_primer.ipynb) for a detailed explanation of FP8 recipes and the supported options."
]
},
{
......@@ -567,7 +567,7 @@
"fp8_format = Format.HYBRID\n",
"fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")\n",
"torch.manual_seed(1234)\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
"with te.autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
" y = te_transformer(x, attention_mask=None)"
]
},
......@@ -591,7 +591,7 @@
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
" fp8_autocast_kwargs = { \"enabled\": True, \"fp8_recipe\": fp8_recipe },\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe },\n",
")"
]
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment