"tests/pytorch/test_quantized_tensor.py" did not exist on "1e7809460157f5d641fbd7ac1543d68648a57558"
Unverified commit ee4a17de authored by Przemyslaw Tredak, committed by GitHub

Update documentation for 2.0 release (#1479)



* Updated docs for TE 2.0
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Do not expose comm_gemm_overlap and cast_transpose_noop
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Made the figures larger
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Apply suggestions from code review
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Przemyslaw Tredak <ptrendx@gmail.com>

* Update quickstart_utils.py
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Change from review
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Przemyslaw Tredak <ptrendx@gmail.com>

---------
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemyslaw Tredak <ptrendx@gmail.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 49a4535d
@@ -33,11 +33,12 @@ What is Transformer Engine?

.. overview-begin-marker-do-not-remove

Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, including
using 8-bit floating point (FP8) precision on Hopper, Ada, and Blackwell GPUs, to provide better
performance with lower memory utilization in both training and inference. TE provides a collection
of highly optimized building blocks for popular Transformer architectures and an automatic mixed
precision-like API that can be used seamlessly with your framework-specific code. TE also includes a
framework agnostic C++ API that can be integrated with other deep learning libraries to enable FP8
support for Transformers.

As the number of parameters in Transformer models continues to grow, training and inference for
architectures such as BERT, GPT and T5 become very memory and compute-intensive. Most deep learning
@@ -51,16 +52,16 @@ not available natively in frameworks today.

TE addresses the problem of FP8 support by providing APIs that integrate with popular Large Language
Model (LLM) libraries. It provides a Python API consisting of modules to easily build a Transformer
layer as well as a framework-agnostic library in C++ including structs and kernels needed for FP8
support. Modules provided by TE internally maintain scaling factors and other values needed for FP8
training, greatly simplifying mixed precision training for users.
Highlights
==========

* Easy-to-use modules for building Transformer layers with FP8 support
* Optimizations (e.g. fused kernels) for Transformer models
* Support for FP8 on NVIDIA Hopper, Ada, and Blackwell GPUs
* Support for optimizations across all precisions (FP16, BF16) on NVIDIA Ampere GPU architecture generations and later

Examples
@@ -149,22 +150,22 @@ Installation

Pre-requisites
^^^^^^^^^^^^^^^^^^^^

* Linux x86_64
* CUDA 12.1+ (CUDA 12.8+ for Blackwell)
* NVIDIA Driver supporting CUDA 12.1 or later
* cuDNN 9.3 or later
Docker
^^^^^^^^^^^^^^^^^^^^

The quickest way to get started with Transformer Engine is by using Docker images on
`NVIDIA GPU Cloud (NGC) Catalog <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch>`_.
For example to use the NGC PyTorch container interactively,

.. code-block:: bash

    docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:25.01-py3

Where 25.01 (corresponding to the January 2025 release) is the container version.
pip
^^^^^^^^^^^^^^^^^^^^
@@ -174,15 +175,21 @@ To install the latest stable version of Transformer Engine,

.. code-block:: bash

    pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable

This will automatically detect if any supported deep learning frameworks are installed and build
Transformer Engine support for them. To explicitly specify frameworks, set the environment variable
NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch).

Alternatively, the package can be directly installed from
`Transformer Engine's PyPI <https://pypi.org/project/transformer-engine/>`_, e.g.

.. code-block:: bash

    pip install transformer_engine[pytorch]

To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be
explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch]).
Transformer Engine ships wheels for the core library. Source distributions are shipped for the JAX
and PyTorch extensions.
From source
^^^^^^^^^^^

@@ -190,7 +197,7 @@ From source

Compiling with FlashAttention-2
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Transformer Engine release v0.11.0 added support for FlashAttention-2 in PyTorch for improved performance.

It is a known issue that FlashAttention-2 compilation is resource-intensive and requires a large amount of RAM (see `bug <https://github.com/Dao-AILab/flash-attention/issues/358>`_), which may lead to out of memory errors during the installation of Transformer Engine. Please try setting **MAX_JOBS=1** in the environment to circumvent the issue.

...
@@ -3,7 +3,8 @@

See LICENSE for license information.

fused_rope.h
============

.. doxygenfile:: fused_rope.h
@@ -12,12 +12,16 @@ directly from C/C++, without Python.

.. toctree::
   :caption: Headers

   transformer_engine.h <transformer_engine>
   activation.h <activation>
   cast.h <cast>
   fused_attn.h <fused_attn>
   fused_rope.h <fused_rope>
   gemm.h <gemm>
   normalization.h <normalization>
   padding.h <padding>
   permutation.h <permutation>
   recipe.h <recipe>
   softmax.h <softmax>
   swizzle.h <swizzle>
   transpose.h <transpose>
..
Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
normalization.h
===============
.. doxygenfile:: normalization.h
@@ -3,7 +3,8 @@

See LICENSE for license information.

padding.h
=========

.. doxygenfile:: padding.h
..
Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
permutation.h
=============
.. doxygenfile:: permutation.h
..
Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
recipe.h
========
.. doxygenfile:: recipe.h
..
Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
swizzle.h
=========
.. doxygenfile:: swizzle.h
@@ -9,3 +9,5 @@ Common API

.. autoapiclass:: transformer_engine.common.recipe.Format

.. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None)
.. autoapiclass:: transformer_engine.common.recipe.MXFP8BlockScaling(fp8_format=Format.E4M3)
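For intuition, the scaling-factor update that `DelayedScaling` performs with `amax_compute_algo="max"` can be sketched in a few lines of plain Python. This is an illustrative sketch of the power-of-2 heuristic, not Transformer Engine's actual implementation; the `fp8_max` value of 448 corresponds to E4M3.

```python
import math

def delayed_scaling_factor(amax_history, fp8_max=448.0, margin=0):
    # Sketch of the "max" algorithm (NOT Transformer Engine's code):
    # take the largest amax seen in the history window...
    amax = max(amax_history)
    # ...and pick the largest power of 2 that still maps amax
    # inside the FP8 representable range, minus an optional margin.
    exp = math.floor(math.log2(fp8_max / amax)) - margin
    return 2.0 ** exp

# History [1.0, 2.0, 0.5]: amax = 2.0 and 448 / 2 = 224, so scale = 2**7
print(delayed_scaling_factor([1.0, 2.0, 0.5]))  # 128.0
```

The scaling factor is deliberately a power of 2 so that scaling and unscaling are exact in floating point.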
@@ -18,7 +18,7 @@

"* E4M3 - it consists of 1 sign bit, 4 exponent bits and 3 bits of mantissa. It can store values up to +/-448 and `nan`.\n",
"* E5M2 - it consists of 1 sign bit, 5 exponent bits and 2 bits of mantissa. It can store values up to +/-57344, +/- `inf` and `nan`. The tradeoff of the increased dynamic range is lower precision of the stored values.\n",
"\n",
"<figure align=\"center\" id=\"fig_1\">\n",
"<img src=\"fp8_formats.png\" width=\"60%\">\n",
"<figcaption> Figure 1: Structure of the floating point datatypes. All of the values shown (in FP16, BF16, FP8 E4M3 and FP8 E5M2) are the closest representations of value 0.3952.</figcaption>\n",
"</figure>\n",
@@ -56,6 +56,50 @@

"As one can see in Figure 3, the delayed scaling strategy requires both storing the history of amaxes and choosing a recipe for converting that history into the scaling factor used in the next iteration."
]
},
{
"cell_type": "markdown",
"id": "f03b58ed-71e8-422a-95be-35c1cc60c4e2",
"metadata": {},
"source": [
"## MXFP8 and block scaling\n",
"\n",
"NVIDIA Blackwell architecture introduced support for a new variant of the FP8 format: [MXFP8](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). \n",
"\n",
"### MXFP8 vs FP8\n",
"\n",
"The main difference between \"regular\" FP8 and MXFP8 lies in the granularity of the scaling. In FP8, each tensor has a single FP32 scaling factor, so all values in the tensor need to \"fit\" within the dynamic range of the FP8 datatype. This requires using the less precise E5M2 format to represent some tensors in the network (like gradients).\n",
"\n",
"MXFP8 addresses this by assigning a different scaling factor to each block of 32 [consecutive](#handling-transposes) values. This allows all values to be represented with the E4M3 datatype.\n",
"\n",
"<figure align=\"center\" id=\"fig_4\">\n",
"<img src=\"MXFP8_FP8_comparison_1.png\" width=\"100%\">\n",
"<figcaption> Figure 4: MXFP8 uses multiple scaling factors for a single tensor. The picture shows only 4 values per block for simplicity, but real MXFP8 has 32 values per block.</figcaption>\n",
"</figure>\n",
"\n",
"<figure align=\"center\" id=\"fig_5\">\n",
"<img src=\"MXFP8_FP8_comparison_2.png\" width=\"100%\">\n",
"<figcaption> Figure 5: Due to multiple scaling factors, tensor's dynamic range requirements are reduced and so E4M3 format can be used as far fewer elements get saturated to 0.</figcaption>\n",
"</figure>\n",
"\n",
"The second difference is the datatype used to store the scaling factors. FP8 uses FP32 (E8M23) while MXFP8 uses an 8-bit representation of a power of 2 (E8M0).\n",
"\n",
"<figure align=\"center\" id=\"fig_6\">\n",
"<img src=\"E8M0.png\" width=\"100%\">\n",
"<figcaption> Figure 6: Structure of the E8M0 datatype used for storing scaling factors in MXFP8.</figcaption>\n",
"</figure>\n",
"\n",
"### Handling transposes\n",
"\n",
"The forward and backward passes of linear layers involve multiple matrix multiplications with different reduction dimensions. Blackwell Tensor Cores require MXFP8 data to be \"consecutive\" over the reduction dimension, so MXFP8 training uses non-transposed and transposed MXFP8 tensors at different points. However, while transposing FP8 data is numerically trivial, transposing MXFP8 data requires requantization.\n",
"\n",
"To avoid loss of precision connected with this double quantization, Transformer Engine creates both regular and transposed copies of the tensor from the original high precision input.\n",
"\n",
"<figure align=\"center\" id=\"fig_7\">\n",
"<img src=\"linear_mxfp8.png\" width=\"80%\">\n",
"<figcaption> Figure 7: Linear layer in MXFP8. Calculating both forward and backward pass requires tensors quantized in both directions.</figcaption>\n",
"</figure>"
]
},
{
"cell_type": "markdown",
"id": "cf5e0b0d",
@@ -63,11 +107,12 @@

"source": [
"## Using FP8 with Transformer Engine\n",
"\n",
"The Transformer Engine library provides tools enabling easy-to-use training with the FP8 datatype using the FP8 delayed scaling and MXFP8 strategies.\n",
"\n",
"### FP8 recipe\n",
"\n",
"The [DelayedScaling](../api/common.rst#transformer_engine.common.recipe.DelayedScaling) recipe from the `transformer_engine.common.recipe` module stores all of the required options for training with FP8 delayed scaling: length of the amax history to use for scaling factor computation, FP8 data format, etc.\n",
"Similarly, [MXFP8BlockScaling](../api/common.rst#transformer_engine.common.recipe.MXFP8BlockScaling) from the same module may be used to enable MXFP8 training."
]
},
{
@@ -77,10 +122,12 @@

"metadata": {},
"outputs": [],
"source": [
"from transformer_engine.common.recipe import Format, DelayedScaling, MXFP8BlockScaling\n",
"\n",
"fp8_format = Format.HYBRID  # E4M3 during forward pass, E5M2 during backward pass\n",
"fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")\n",
"mxfp8_format = Format.E4M3  # E4M3 used everywhere\n",
"mxfp8_recipe = MXFP8BlockScaling(fp8_format=mxfp8_format)"
]
},
{
@@ -341,7 +388,7 @@

"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,

...
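The per-block scaling that the primer describes can be sketched in plain Python. This is an illustrative sketch only, not Transformer Engine's kernel: it picks one E8M0-style power-of-2 scale per block of 32 consecutive values so that every scaled value in the block fits within the E4M3 range (maximum magnitude 448).

```python
import math

E4M3_MAX = 448.0  # largest finite E4M3 magnitude

def mxfp8_scales(values, block=32):
    """One power-of-2 (E8M0-style) scale per block of `block` values (sketch)."""
    scales = []
    for i in range(0, len(values), block):
        amax = max(abs(v) for v in values[i:i + block]) or 1.0  # avoid log2(0)
        # smallest power of 2 that maps the block's amax into the E4M3 range
        exp = math.ceil(math.log2(amax / E4M3_MAX))
        scales.append(2.0 ** exp)
    return scales

data = [0.001 * i for i in range(64)]  # 64 values -> 2 blocks of 32
for s, start in zip(mxfp8_scales(data), range(0, 64, 32)):
    block_amax = max(abs(v) for v in data[start:start + 32])
    assert block_amax / s <= E4M3_MAX  # every scaled value fits in E4M3
```

Because each block gets its own scale, a block of small gradients no longer has to share a scale with a block of large activations, which is why E4M3 suffices where single-scale FP8 needed E5M2.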
@@ -3,10 +3,9 @@

# See LICENSE for license information.

import math
from typing import Optional

import torch
import transformer_engine.pytorch as te


def speedometer(
@@ -204,16 +203,13 @@ def share_parameters_with_transformerlayer_te_model(te_model, basic_model):

def cast_to_representable(inp, scale=1.0, fp8_format="e4m3"):
    from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
    import transformer_engine_torch as tex

    fp8_type = tex.DType.kFloat8E4M3 if fp8_format == "e4m3" else tex.DType.kFloat8E5M2
    scale = torch.ones(1, dtype=torch.float32, device="cuda") * scale
    amax_history = torch.zeros(1, 1, dtype=torch.float32, device="cuda")
    quantizer = Float8Quantizer(scale=scale, amax=amax_history, fp8_dtype=fp8_type)
    ret = quantizer(inp)
    ret = ret.dequantize()
    return ret
@@ -12,10 +12,9 @@ Prerequisites

.. _driver link: https://www.nvidia.com/drivers

1. Linux x86_64
2. `CUDA 12.1+ (12.8+ for Blackwell support) <https://developer.nvidia.com/cuda-downloads>`__
3. |driver link|_ supporting CUDA 12.1 or later.
4. `cuDNN 9.3 <https://developer.nvidia.com/cudnn>`__ or later.
If the CUDA Toolkit headers are not available at runtime in a standard
installation path, e.g. within `CUDA_HOME`, set

@@ -76,7 +75,7 @@ Execute the following command to install the latest development build of Transformer Engine,

This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable `NVTE_FRAMEWORK` to a comma-separated list (e.g. `NVTE_FRAMEWORK=jax,pytorch`). To only build the framework-agnostic C++ API, set `NVTE_FRAMEWORK=none`.

In order to install a specific PR, execute (after changing NNN to the PR number):

.. code-block:: bash

...
@@ -164,13 +164,24 @@ class DelayedScaling(Recipe):

@dataclass()
class MXFP8BlockScaling(Recipe):
    """
    Use the MXFP8 scaling factor strategy.

    In this strategy, tensors are scaled in blockwise fashion. Each group
    of 32 consecutive values is scaled together using their own scaling
    factor. The type of the scaling factor is E8M0 (8 bits of exponent,
    0 bits of mantissa), equivalent to scaling by a power of 2.

    Since the scaling happens in a particular direction (either rowwise
    or columnwise), in this recipe the quantized tensor and its transpose
    are not numerically equivalent. Due to this, when Transformer Engine
    needs both the MXFP8 tensor and its transpose (e.g. to calculate both
    the forward and backward pass), during quantization both versions are
    computed from the high precision input to avoid double quantization
    errors.

    Parameters
    ----------
    fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3
        Controls the FP8 data format used during the forward and backward
        pass.
    """

...