Commit 544dd14b authored by Przemek Tredak

Update main branch with TE 2.0 code, update version to 2.1.0.dev0


Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
parent e5369541
......@@ -73,23 +73,3 @@ jobs:
MAX_JOBS: 1
- name: 'Sanity check'
run: python tests/jax/test_sanity_import.py
paddle:
name: 'PaddlePaddle'
runs-on: ubuntu-latest
container:
image: nvcr.io/nvidia/paddlepaddle:24.10-py3
options: --user root
steps:
- name: 'Checkout'
uses: actions/checkout@v3
with:
submodules: recursive
- name: 'Build'
run: |
apt-get update
apt-get install -y libgoogle-glog-dev
pip install . -v
env:
NVTE_FRAMEWORK: paddle
- name: 'Sanity check'
run: python tests/paddle/test_sanity_import.py
......@@ -61,30 +61,3 @@ jobs:
export PYTHON_ONLY=1
export TE_PATH=.
bash ./qa/L0_jax_lint/test.sh
paddle_cpplint:
name: 'PaddlePaddle C++'
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v3
- name: 'Lint'
run: |
sudo apt-get update
sudo apt-get install pip -y
export CPP_ONLY=1
export TE_PATH=.
bash ./qa/L0_paddle_lint/test.sh
paddle_pylint:
name: 'PaddlePaddle Python'
runs-on: ubuntu-latest
steps:
- name: 'Checkout'
uses: actions/checkout@v3
- name: 'Lint'
run: |
sudo apt-get update
sudo apt-get install pip -y
pip install paddlepaddle-gpu
export PYTHON_ONLY=1
export TE_PATH=.
bash ./qa/L0_paddle_lint/test.sh
......@@ -8,7 +8,6 @@
*.nsys-rep
*.ncu-rep
*.sqlite
*.onnx
*.eggs
build/
*.so
......
Subproject commit cc5632eda70bbdac34455c2d94066d27d10e2699
Subproject commit 91b7532f3386768bba4f444ee7672b497f34da8a
......@@ -174,7 +174,7 @@ To install the latest stable version of Transformer Engine,
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch,paddle).
This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch).
Alternatively, the package can be directly installed from `Transformer Engine's PyPI <https://pypi.org/project/transformer-engine/>`_, e.g.
......@@ -182,7 +182,7 @@ Alternatively, the package can be directly installed from `Transformer Engine's
pip install transformer_engine[pytorch]
To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch,paddle]). Transformer Engine ships wheels for the core library as well as the PaddlePaddle extensions. Source distributions are shipped for the JAX and PyTorch extensions.
To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch]). Transformer Engine ships wheels for the core library. Source distributions are shipped for the JAX and PyTorch extensions.
From source
^^^^^^^^^^^
......
......@@ -129,63 +129,6 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]):
super().run()
self.extensions = all_extensions
paddle_ext = None
if "paddle" in get_frameworks():
for ext in self.extensions:
if "paddle" in ext.name:
paddle_ext = ext
break
# Manually write stub file for Paddle extension
if paddle_ext is not None:
# Load libtransformer_engine.so to avoid linker errors
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
# Source compilation from top-level (--editable)
search_paths = list(Path(__file__).resolve().parent.parent.iterdir())
# Source compilation from top-level
search_paths.extend(list(Path(self.build_lib).iterdir()))
# Dynamically load required_libs.
from transformer_engine.common import _load_cudnn, _load_nvrtc
_load_cudnn()
_load_nvrtc()
else:
# Only during release bdist build for paddlepaddle.
import transformer_engine
search_paths = list(Path(transformer_engine.__path__[0]).iterdir())
del transformer_engine
common_so_path = ""
for path in search_paths:
if path.name.startswith("libtransformer_engine."):
common_so_path = str(path)
assert common_so_path, "Could not find libtransformer_engine"
ctypes.CDLL(common_so_path, mode=ctypes.RTLD_GLOBAL)
# Figure out stub file path
module_name = paddle_ext.name
assert module_name.endswith(
"_pd_"
), "Expected Paddle extension module to end with '_pd_'"
stub_name = module_name[:-4] # remove '_pd_'
stub_path = os.path.join(self.build_lib, "transformer_engine", stub_name + ".py")
Path(stub_path).parent.mkdir(exist_ok=True, parents=True)
# Figure out library name
# Note: This library doesn't actually exist. Paddle
# internally reinserts the '_pd_' suffix.
so_path = self.get_ext_fullpath(module_name)
_, so_ext = os.path.splitext(so_path)
lib_name = stub_name + so_ext
# Write stub file
print(f"Writing Paddle stub for {lib_name} into file {stub_path}")
from paddle.utils.cpp_extension.extension_utils import custom_write_stub
custom_write_stub(lib_name, stub_path)
# Ensure that binaries are not in global package space.
target_dir = install_dir / "transformer_engine"
target_dir.mkdir(exist_ok=True, parents=True)
......@@ -194,16 +137,10 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]):
self.copy_file(ext, target_dir)
os.remove(ext)
# For paddle, the stub file needs to be copied to the install location.
if paddle_ext is not None:
stub_path = Path(self.build_lib) / "transformer_engine"
for stub in stub_path.glob("transformer_engine_paddle.py"):
self.copy_file(stub, target_dir)
def build_extensions(self):
# BuildExtensions from PyTorch and PaddlePaddle already handle CUDA files correctly
# BuildExtensions from PyTorch already handle CUDA files correctly
# so we don't need to modify their compiler. Only the pybind11 build_ext needs to be fixed.
if "pytorch" not in get_frameworks() and "paddle" not in get_frameworks():
if "pytorch" not in get_frameworks():
# Ensure at least an empty list of flags for 'cxx' and 'nvcc' when
# extra_compile_args is a dict.
for ext in self.extensions:
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Paddle-paddle related extensions."""
from pathlib import Path
import setuptools
import os
from .utils import cuda_version
import paddle
paddle_version = paddle.__version__.replace(".", "")
def setup_paddle_extension(
csrc_source_files,
csrc_header_files,
common_header_files,
) -> setuptools.Extension:
"""Setup CUDA extension for Paddle support"""
# Source files
csrc_source_files = Path(csrc_source_files)
sources = [
csrc_source_files / "extensions.cpp",
csrc_source_files / "common.cpp",
csrc_source_files / "custom_ops.cu",
]
# Header files
include_dirs = [
common_header_files,
common_header_files / "common",
common_header_files / "common" / "include",
csrc_header_files,
]
# Compiler flags
cxx_flags = ["-O3"]
nvcc_flags = [
"-O3",
"-gencode",
"arch=compute_70,code=sm_70",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
f"-DPADDLE_VERSION={paddle_version}",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
]
# Version-dependent CUDA options
try:
version = cuda_version()
except FileNotFoundError:
print("Could not determine CUDA Toolkit version")
else:
if version < (12, 0):
raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")
nvcc_flags.extend(
(
"--threads",
os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"),
"-gencode",
"arch=compute_80,code=sm_80",
"-gencode",
"arch=compute_90,code=sm_90",
)
)
# Construct Paddle CUDA extension
sources = [str(path) for path in sources]
include_dirs = [str(path) for path in include_dirs]
from paddle.utils.cpp_extension import CUDAExtension
ext = CUDAExtension(
sources=sources,
include_dirs=include_dirs,
extra_compile_args={
"cxx": cxx_flags,
"nvcc": nvcc_flags,
},
)
ext.name = "transformer_engine_paddle_pd_"
return ext
......@@ -27,7 +27,6 @@ def setup_pytorch_extension(
extensions_dir = csrc_source_files / "extensions"
sources = [
csrc_source_files / "common.cpp",
csrc_source_files / "ts_fp8_op.cpp",
] + all_files_in_dir(extensions_dir)
# Header files
......
......@@ -190,7 +190,12 @@ def cuda_path() -> Tuple[str, str]:
@functools.lru_cache(maxsize=None)
def cuda_archs() -> str:
return os.getenv("NVTE_CUDA_ARCHS", "70;80;89;90")
version = cuda_version()
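# Blackwell architectures (sm_100, sm_120) require CUDA 12.8 or newer to compile.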
if os.getenv("NVTE_CUDA_ARCHS") is None:
os.environ["NVTE_CUDA_ARCHS"] = (
"70;80;89;90;100;120" if version >= (12, 8) else "70;80;89;90"
)
return os.getenv("NVTE_CUDA_ARCHS")
def cuda_version() -> Tuple[int, ...]:
......@@ -211,7 +216,7 @@ def cuda_version() -> Tuple[int, ...]:
def get_frameworks() -> List[str]:
"""DL frameworks to build support for"""
_frameworks: List[str] = []
supported_frameworks = ["pytorch", "jax", "paddle"]
supported_frameworks = ["pytorch", "jax"]
# Check environment variable
if os.getenv("NVTE_FRAMEWORK"):
......@@ -237,12 +242,6 @@ def get_frameworks() -> List[str]:
pass
else:
_frameworks.append("jax")
try:
import paddle
except ImportError:
pass
else:
_frameworks.append("paddle")
# Special framework names
if "all" in _frameworks:
......@@ -311,7 +310,6 @@ def uninstall_te_wheel_packages():
"-y",
"transformer_engine_cu12",
"transformer_engine_torch",
"transformer_engine_paddle",
"transformer_engine_jax",
]
)
......@@ -9,7 +9,6 @@ BUILD_METAPACKAGE=${2:-true}
BUILD_COMMON=${3:-true}
BUILD_PYTORCH=${4:-true}
BUILD_JAX=${5:-true}
BUILD_PADDLE=${6:-true}
export NVTE_RELEASE_BUILD=1
export TARGET_BRANCH=${TARGET_BRANCH:-}
......@@ -63,38 +62,3 @@ if $BUILD_JAX ; then
/opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt
cp dist/* /wheelhouse/
fi
if $BUILD_PADDLE ; then
if [ "$PLATFORM" == "manylinux_2_28_x86_64" ] ; then
dnf -y remove --allowerasing cudnn9-cuda-12
dnf -y install libcudnn8-devel.x86_64 libcudnn8.x86_64
cd /TransformerEngine/transformer_engine/paddle
/opt/python/cp38-cp38/bin/pip install /wheelhouse/*.whl --no-deps
/opt/python/cp38-cp38/bin/pip install paddlepaddle-gpu==2.6.1
/opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp38.txt
/opt/python/cp38-cp38/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu
/opt/python/cp39-cp39/bin/pip install /wheelhouse/*.whl --no-deps
/opt/python/cp39-cp39/bin/pip install paddlepaddle-gpu==2.6.1
/opt/python/cp39-cp39/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp39.txt
/opt/python/cp39-cp39/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu
/opt/python/cp310-cp310/bin/pip install /wheelhouse/*.whl --no-deps
/opt/python/cp310-cp310/bin/pip install paddlepaddle-gpu==2.6.1
/opt/python/cp310-cp310/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp310.txt
/opt/python/cp310-cp310/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu
/opt/python/cp311-cp311/bin/pip install /wheelhouse/*.whl --no-deps
/opt/python/cp311-cp311/bin/pip install paddlepaddle-gpu==2.6.1
/opt/python/cp311-cp311/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp311.txt
/opt/python/cp311-cp311/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu
/opt/python/cp312-cp312/bin/pip install /wheelhouse/*.whl --no-deps
/opt/python/cp312-cp312/bin/pip install paddlepaddle-gpu==2.6.1
/opt/python/cp312-cp312/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp312.txt
/opt/python/cp312-cp312/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu
mv dist/* /wheelhouse/
fi
fi
......@@ -8,4 +8,4 @@ Common API
.. autoapiclass:: transformer_engine.common.recipe.Format
.. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None, override_linear_precision=(False, False, False))
.. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None)
......@@ -10,4 +10,3 @@ Framework-specific API
pytorch
jax
paddle
..
Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
paddle
======
.. autoapiclass:: transformer_engine.paddle.Linear(in_features, out_features, **kwargs)
:members: forward
.. autoapiclass:: transformer_engine.paddle.LayerNorm(hidden_size, eps=1e-5, **kwargs)
.. autoapiclass:: transformer_engine.paddle.LayerNormLinear(in_features, out_features, eps=1e-5, **kwargs)
:members: forward
.. autoapiclass:: transformer_engine.paddle.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, **kwargs)
:members: forward
.. autoapiclass:: transformer_engine.paddle.FusedScaleMaskSoftmax(attn_mask_type, mask_func, **kwargs)
:members: forward
.. autoapiclass:: transformer_engine.paddle.DotProductAttention(num_attention_heads, kv_channels, **kwargs)
:members: forward
.. autoapiclass:: transformer_engine.paddle.MultiHeadAttention(hidden_size, num_attention_heads, **kwargs)
:members: forward
.. autoapiclass:: transformer_engine.paddle.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)
:members: forward
.. autoapifunction:: transformer_engine.paddle.fp8_autocast
.. autoapifunction:: transformer_engine.paddle.recompute
......@@ -42,8 +42,6 @@ pyTorch
.. autoapifunction:: transformer_engine.pytorch.checkpoint
.. autoapifunction:: transformer_engine.pytorch.onnx_export
.. autoapifunction:: transformer_engine.pytorch.make_graphed_callables
.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context
......
......@@ -14,11 +14,10 @@
"<figcaption> Figure 1: Dot product attention. </figcaption>\n",
"</figure>\n",
"\n",
"[Transformer Engine](https://github.com/NVIDIA/TransformerEngine.git) supports the calculation of dot product attention in three frameworks, [PyTorch](https://github.com/pytorch/pytorch), [JAX](https://github.com/google/jax) and [PaddlePaddle](https://github.com/PaddlePaddle/Paddle). The API for each framework is\n",
"[Transformer Engine](https://github.com/NVIDIA/TransformerEngine.git) supports the calculation of dot product attention in two frameworks, [PyTorch](https://github.com/pytorch/pytorch) and [JAX](https://github.com/google/jax). The API for each framework is\n",
"\n",
"- [transformer_engine.pytorch.DotProductAttention](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention)\n",
"- [transformer_engine.jax.flax.DotProductAttention](../../api/jax.rst#transformer_engine.jax.flax.DotProductAttention)\n",
"- [transformer_engine.paddle.DotProductAttention](../../api/paddle.rst#transformer_engine.paddle.DotProductAttention)"
"- [transformer_engine.jax.flax.DotProductAttention](../../api/jax.rst#transformer_engine.jax.flax.DotProductAttention)"
]
},
{
......@@ -56,15 +55,6 @@
" <tr>\n",
" <td>JAX-native attention (`_UnfusedDotProductAttention`)</td>\n",
" </tr>\n",
" <tr>\n",
" <td rowspan=\"2\"> PaddlePaddle</td>\n",
" <td> cuDNN attention (`_te_forward`) </td>\n",
" <td rowspan=\"2\"> [transformer_engine.paddle.layer.attention](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/paddle/layer/attention.py)\n",
" </td> \n",
" </tr>\n",
" <tr>\n",
" <td>PaddlePaddle-native attention (`_pd_forward`)</td>\n",
" </tr>\n",
" \n",
"</table>"
]
......@@ -87,7 +77,7 @@
"<div class=\"alert alert-info\">\n",
"<b>Note:</b> \n",
" \n",
"Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), available in PyTorch, JAX and PaddlePaddle, are both based on the flash algorithm.\n",
"Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), available in PyTorch and JAX, are both based on the flash algorithm.\n",
"</div>\n"
]
},
......@@ -102,13 +92,13 @@
"\n",
"The flash-attention backend supports `flash-attn`'s features as well as a few extra functionalities to facilitate the use of `flash-attn`, such as converting the `attention_mask` to cumulative sequence lengths `cu_seqlens` for `padding` mask use cases. Please see `transformer_engine.pytorch.attention.FlashAttention` for details.\n",
"\n",
"The `flash-attn` dependency is regularly updated in Transformer Engine. As of v1.10, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](https://github.com/NVIDIA/TransformerEngine/blob/main/setup.py)).\n",
"The `flash-attn` dependency is regularly updated in Transformer Engine. As of v2.0, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](https://github.com/NVIDIA/TransformerEngine/blob/main/setup.py)).\n",
"\n",
"To understand `flash-attn`'s performance, please refer to their benchmarks [here](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#performance).\n",
"\n",
"### 1.3 cuDNN Attention\n",
"\n",
"The cuDNN attention backend, available in PyTorch, JAX and PaddlePaddle, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) to run, and has several sub-backends to support the different precisions and sequence lengths.\n",
"The cuDNN attention backend, available in PyTorch and JAX, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) to run, and has several sub-backends to support the different precisions and sequence lengths.\n",
"\n",
"<table class=\"docutils align-default\">\n",
" <tr>\n",
......@@ -153,9 +143,9 @@
" </tr>\n",
"</table>\n",
"\n",
"The cuDNN attention backend and flash-attention backend have several notable differences. As of Transformer Engine 1.10, cuDNN 9.3 and `flash-attn` 2.4.2,\n",
"The cuDNN attention backend and flash-attention backend have several notable differences. As of Transformer Engine 2.0, cuDNN 9.3 and `flash-attn` 2.4.2,\n",
"\n",
"- flash-attention only supports the PyTorch framework while cuDNN attention supports PyTorch, JAX and PaddlePaddle.\n",
"- flash-attention only supports the PyTorch framework while cuDNN attention supports PyTorch and JAX.\n",
"- flash-attention supports BF16, FP16 precisions while cuDNN attention also supports FP8 (through its sub-backend 2).\n",
"- flash-attention supports `bshd`, `thd` input formats, without any transposes, and `sbhd` format, with transposes, while cuDNN attention supports all three formats without transposes (see Section 3.1 for more details).\n",
"- flash-attention does not support `post_scale_bias`, and cuDNN attention does.\n",
......@@ -244,10 +234,6 @@
" <td>JAX</td>\n",
" <td>cuDNN attention > JAX-native attention</td>\n",
" </tr>\n",
" <tr>\n",
" <td> PaddlePaddle</td>\n",
" <td> cuDNN attention > PaddlePaddle-native attention </td>\n",
" </tr>\n",
"</table>"
]
},
......@@ -266,7 +252,7 @@
"<div class=\"alert alert-info\">\n",
"<b>Note:</b>\n",
" \n",
"These flags are supported in PyTorch only as of Transformer Engine 1.10. JAX and PaddlePaddle support is expected to be added in the future.\n",
"These flags are supported in PyTorch only as of Transformer Engine 2.0. JAX support is expected to be added in the future.\n",
"</div>"
]
},
......@@ -382,7 +368,7 @@
"<div class=\"alert alert-info\">\n",
"<b>Note</b>\n",
" \n",
"Environment variables <code>NVTE_FLASH_ATTN</code>, <code>NVTE_FUSED_ATTN</code>, <code>NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT</code> and <code>NVTE_ALLOW_NONDETERMINISTIC_ALGO</code> are only supported in PyTorch, and will be added to JAX and PaddlePaddle in the future.\n",
"Environment variables <code>NVTE_FLASH_ATTN</code>, <code>NVTE_FUSED_ATTN</code>, <code>NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT</code> and <code>NVTE_ALLOW_NONDETERMINISTIC_ALGO</code> are only supported in PyTorch, and will be added to JAX in the future.\n",
"</div>\n",
"\n",
"### 2.3 Example Tests\n",
......@@ -399,7 +385,7 @@
"source": [
"## 3. Backend Support\n",
"\n",
"Transformer Engine supports commonly-used features such as self and cross attention, FP16/BF16 precisions, dropout, and checkpointing. But it also offers a range of other features. As of v1.10, Transformer Engine's attention backends have the following support matrix.\n",
"Transformer Engine supports commonly-used features such as self and cross attention, FP16/BF16 precisions, dropout, and checkpointing. But it also offers a range of other features. As of v2.0, Transformer Engine's attention backends have the following support matrix.\n",
"\n",
"| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Multi-Latent Attention | Context Parallelism | Determinism Possible |\n",
"| :---------------- | :-------- | :----------- | :----------------------- | :------ | :--------------------- | :------------------ | :------------ |\n",
......@@ -442,7 +428,7 @@
"**qkv_layout=thd_thd_thd:**\n",
"`q`, `k`, `v` have variable sequence lengths in a batch. They are all contiguous and have no interleaving.\n",
"\n",
"As of v1.10, Transformer Engine has the following support matrix.\n",
"As of v2.0, Transformer Engine has the following support matrix.\n",
"\n",
"<table class=\"docutils align-default\">\n",
" <tr>\n",
......@@ -462,13 +448,13 @@
" </tr>\n",
" <tr>\n",
" <td>\n",
" JAX, PaddlePaddle: `bs3hd`, `bshd_bs2hd`, `bshd_bshd_bshd` layouts\n",
" JAX: `bs3hd`, `bshd_bs2hd`, `bshd_bshd_bshd` layouts\n",
" </td> \n",
" </tr>\n",
" <tr>\n",
" <td>Framework-native attention</td>\n",
" <td>`bshd`, `sbhd`</td>\n",
" <td>PyTorch, JAX, PaddlePaddle: 2 formats, i.e. 10 layouts</td>\n",
" <td>PyTorch, JAX: 2 formats, i.e. 10 layouts</td>\n",
" </tr>\n",
"</table>\n",
"\n",
......@@ -492,7 +478,7 @@
"\n",
"- `no_mask`, `padding`, `causal`, `causal_bottom_right`, `padding_causal`, `padding_causal_bottom_right`, `arbitrary`\n",
"\n",
"Different backends offer different support for attention mask. As of Transformer Engine 1.10,\n",
"Different backends offer different support for attention mask. As of Transformer Engine 2.0,\n",
"\n",
"<table class=\"docutils align-default\">\n",
" <tr>\n",
......@@ -512,21 +498,21 @@
" </tr>\n",
" <tr>\n",
" <td>Framework-native attention</td>\n",
" <td><li>All (PyTorch)</li><li>`no_mask`, `causal`, `padding` (Jax, PaddlePaddle)</li></td>\n",
" <td><li>All (PyTorch)</li><li>`no_mask`, `causal`, `padding` (Jax)</li></td>\n",
" </tr>\n",
" <tr>\n",
" <td></td>\n",
" </tr>\n",
"</table>\n",
"\n",
"**Padding masks:** For `padding`, `padding_causal`, `padding_causal_bottom_right` mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 1.10, there are two options to do so in PyTorch and one in JAX and PaddlePaddle.\n",
"**Padding masks:** For `padding`, `padding_causal`, `padding_causal_bottom_right` mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 2.0, there are two options to do so in PyTorch and one in JAX.\n",
"\n",
"* PyTorch: When both options are provided by the user, `cu_seqlens` is preferred as there is no extra conversion needed.\n",
" - `cu_seqlens`: Users can provide cumulative sequence length tensors `cu_seqlens_q` and `cu_seqlens_kv` for `q` and `k`/`v` to the flash-attention or cuDNN attention backend. An example of `cu_seqlens` is `[0, 2, 6, 7]` for a batch of 3 `[aa000, bbbb0, c0000]`.\n",
" - `attention_mask`: Users can also provide `attention_mask` as an alternative, which will then be converted to `cu_seqlens`. For self-attention, `attention_mask` should be one single tensor in shape `[batch_size, 1, 1, seqlen_q]`, and for cross-attention, `attention_mask` should be a list of two tensors in shapes `[batch_size, 1, 1, seqlen_q]` and `[batch_size, 1, 1, seqlen_kv]`, respectively.\n",
"\n",
"\n",
"* JAX and PaddlePaddle: Users should provide the `attention_mask` tensor in shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n",
"* JAX: Users should provide the `attention_mask` tensor in shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n",
"\n",
"**qkv_format=thd:** Transformer Engine extracts the max sequence length information from `q`, `k`, `v` if `max_seqlen_q` and `max_seqlen_kv` are not provided. This requires GPU-CPU copy and synchronization operations. For performance reasons, please set `max_seqlen_q` and `max_seqlen_kv` to their appropriate values for `thd` QKV format.\n",
"\n",
......@@ -566,7 +552,7 @@
"\n",
"### 3.3 Attention Bias\n",
"\n",
"Transformer Engine supports 4 attention bias types, `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes). As of Transformer Engine 1.10, their support matrix is as follows.\n",
"Transformer Engine supports 4 attention bias types, `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes). As of Transformer Engine 2.0, their support matrix is as follows.\n",
"\n",
"<table class=\"docutils align-default\">\n",
" <tr>\n",
......@@ -591,7 +577,7 @@
" <td>cuDNN 8.9.6+: sm90</td>\n",
" </tr>\n",
" <tr>\n",
" <td>JAX, PaddlePaddle: `no_bias`, `post_scale_bias`</td> \n",
" <td>JAX: `no_bias`, `post_scale_bias`</td> \n",
" <td>ALiBi slopes: FP32</td>\n",
" <td>cuDNN 9.0+: sm80+</td>\n",
" </tr>\n",
......@@ -620,7 +606,7 @@
"\n",
"A unique feature of Transformer Engine is its FP8 support, not only for the `Linear` layers but also for dot product attention. Transformer Engine's FP8 attention support is through its cuDNN attention sub-backend 2. Recall Figure 1: the two `MatMul` operations are performed in FP8 for computational efficiency, and the `SoftMax` operation is performed in FP32 for numerical accuracy.\n",
"\n",
"Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst), and [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention), as of v1.10. Its PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n",
"Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst), and [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention), as of v2.0. Its PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n",
"\n",
"- `DelayedScaling.fp8_dpa=True (default=False)`: This enables the use of cuDNN attention sub-backend 2, when it does support the provided user inputs. The `FusedAttention` module for cuDNN attention takes FP16 or BF16 tensors as inputs, performs dot product attention in FP8, and returns attention logits in FP16 or BF16 (same as the input type). Casting operations are required to cast tensors to FP8 at the beginning, and back to FP16/BF16 at the end of the module.\n",
"\n",
......
......@@ -37,7 +37,7 @@ Transformer Engine can be directly installed from `our PyPI <https://pypi.org/pr
pip install transformer_engine[pytorch]
To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch,paddle]). Transformer Engine ships wheels for the core library as well as the PaddlePaddle extensions. Source distributions are shipped for the JAX and PyTorch extensions.
To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch]). Transformer Engine ships wheels for the core library. Source distributions are shipped for the JAX and PyTorch extensions.
pip - from GitHub
-----------------------
......
# Examples
We provide a variety of examples for deep learning frameworks including [PyTorch](https://github.com/pytorch/pytorch), [JAX](https://github.com/jax-ml/jax), and [PaddlePaddle](https://github.com/PaddlePaddle/Paddle).
We provide a variety of examples for deep learning frameworks including [PyTorch](https://github.com/pytorch/pytorch) and [JAX](https://github.com/jax-ml/jax).
Additionally, we offer [Jupyter notebook tutorials](https://github.com/NVIDIA/TransformerEngine/tree/main/docs/examples) and a selection of [third-party examples](#third-party). Please be aware that these third-party examples might need specific, older versions of dependencies to function properly.
# PyTorch
......@@ -35,9 +35,6 @@ Additionally, we offer [Jupyter notebook tutorials](https://github.com/NVIDIA/Tr
- Multiprocessing with Model Parallelism: Multiprocessing for model parallelism, including multi-node support and hardware affinity setup.
- [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/jax/mnist)
# PaddlePaddle
- [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/paddle/mnist)
# Third party
- [Hugging Face Accelerate + TE](https://github.com/huggingface/accelerate/tree/main/benchmarks/fp8/transformer_engine)
- Scripts for training with Accelerate and TE. Supports single GPU, and multi-GPU via DDP, FSDP, and DeepSpeed ZeRO 1-3.
# Basic MNIST Example
```bash
python test_single_gpu_mnist.py
python test_single_gpu_mnist.py --use-te # Linear layers from TransformerEngine
python test_single_gpu_mnist.py --use-te --use-fp8 # FP8 + TransformerEngine for Linear layers
```
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""MNIST example of Transformer Engine Paddle"""
import argparse
import os
import unittest
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle.vision.transforms import Normalize
from paddle.io import DataLoader
from paddle.vision.datasets import MNIST
from paddle.metric import Accuracy
import transformer_engine.paddle as te
from transformer_engine.paddle.fp8 import is_fp8_available
class Net(nn.Layer):
"""Simple network used to train on MNIST"""
def __init__(self, use_te=False):
super().__init__()
self.conv1 = nn.Conv2D(1, 32, 3, 1)
self.conv2 = nn.Conv2D(32, 64, 3, 1)
self.dropout1 = nn.Dropout(0.25)
self.dropout2 = nn.Dropout(0.5)
if use_te:
self.fc1 = te.Linear(9216, 128)
self.fc2 = te.Linear(128, 16)
else:
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 16)
self.fc3 = nn.Linear(16, 10)
def forward(self, x):
"""FWD"""
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = paddle.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
x = self.fc3(x)
return x
def train(args, model, train_loader, optimizer, epoch, use_fp8):
"""Training function."""
model.train()
losses = []
for batch_id, (data, labels) in enumerate(train_loader):
with paddle.amp.auto_cast(
dtype="bfloat16", level="O2"
): # pylint: disable=not-context-manager
with te.fp8_autocast(enabled=use_fp8):
outputs = model(data)
loss = F.cross_entropy(outputs, labels)
losses.append(loss.item())
loss.backward()
optimizer.step()
optimizer.clear_gradients()
if batch_id % args.log_interval == 0:
print(
f"Train Epoch: {epoch} "
f"[{batch_id * len(data)}/{len(train_loader.dataset)} "
f"({100. * batch_id / len(train_loader):.0f}%)]\t"
f"Loss: {loss.item():.6f}"
)
if args.dry_run:
return loss.item()
avg_loss = sum(losses) / len(losses)
print(f"Train Epoch: {epoch}, Average Loss: {avg_loss}")
return avg_loss
def evaluate(model, test_loader, epoch, use_fp8):
"""Testing function."""
model.eval()
metric = Accuracy()
metric.reset()
with paddle.no_grad():
for data, labels in test_loader:
with paddle.amp.auto_cast(
dtype="bfloat16", level="O2"
): # pylint: disable=not-context-manager
with te.fp8_autocast(enabled=use_fp8):
outputs = model(data)
acc = metric.compute(outputs, labels)
metric.update(acc)
print(f"Epoch[{epoch}] - accuracy: {metric.accumulate():.6f}")
return metric.accumulate()
def calibrate(model, test_loader):
"""Calibration function."""
model.eval()
with paddle.no_grad():
for data, _ in test_loader:
with paddle.amp.auto_cast(
dtype="bfloat16", level="O2"
): # pylint: disable=not-context-manager
with te.fp8_autocast(enabled=False, calibrating=True):
_ = model(data)
def mnist_parser(args):
"""Parse training settings"""
parser = argparse.ArgumentParser(description="Paddle MNIST Example")
parser.add_argument(
"--batch-size",
type=int,
default=64,
metavar="N",
help="input batch size for training (default: 64)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=1000,
metavar="N",
help="input batch size for testing (default: 1000)",
)
parser.add_argument(
"--epochs",
type=int,
default=14,
metavar="N",
help="number of epochs to train (default: 14)",
)
parser.add_argument(
"--lr",
type=float,
default=0.001,
metavar="LR",
help="learning rate (default: 0.001)",
)
parser.add_argument(
"--dry-run",
action="store_true",
default=False,
help="quickly check a single pass",
)
parser.add_argument(
"--save-model",
action="store_true",
default=False,
help="For Saving the current Model",
)
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument(
"--log-interval",
type=int,
default=10,
metavar="N",
help="how many batches to wait before logging training status",
)
parser.add_argument(
"--use-fp8",
action="store_true",
default=False,
help=(
"Use FP8 for inference and training without recalibration. "
"It also enables Transformer Engine implicitly."
),
)
parser.add_argument(
"--use-fp8-infer",
action="store_true",
default=False,
help=(
"Use FP8 for inference only. If not using FP8 for training, "
"calibration is performed for FP8 infernece."
),
)
parser.add_argument(
"--use-te", action="store_true", default=False, help="Use Transformer Engine"
)
args = parser.parse_args(args)
return args
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
paddle.seed(args.seed)
# Load MNIST dataset
transform = Normalize(mean=[127.5], std=[127.5], data_format="CHW")
train_dataset = MNIST(mode="train", transform=transform)
val_dataset = MNIST(mode="test", transform=transform)
# Define data loaders
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=args.test_batch_size)
# Define model and optimizer
model = Net(use_te=args.use_te)
optimizer = paddle.optimizer.Adam(learning_rate=args.lr, parameters=model.parameters())
# Cast model to BF16
model = paddle.amp.decorate(models=model, level="O2", dtype="bfloat16")
for epoch in range(1, args.epochs + 1):
loss = train(args, model, train_loader, optimizer, epoch, args.use_fp8)
acc = evaluate(model, val_loader, epoch, args.use_fp8)
if args.use_fp8_infer and not args.use_fp8:
calibrate(model, val_loader)
if args.save_model or args.use_fp8_infer:
paddle.save(model.state_dict(), "mnist_cnn.pdparams")
print("Eval with reloaded checkpoint : fp8=" + str(args.use_fp8))
weights = paddle.load("mnist_cnn.pdparams")
model.set_state_dict(weights)
acc = evaluate(model, val_loader, 0, args.use_fp8)
return loss, acc
class TestMNIST(unittest.TestCase):
"""MNIST unittests"""
gpu_has_fp8, reason = is_fp8_available()
@classmethod
def setUpClass(cls):
"""Run MNIST without Transformer Engine"""
cls.args = mnist_parser(["--epochs", "5"])
@staticmethod
def verify(actual):
"""Check If loss and accuracy match target"""
desired_traing_loss = 0.1
desired_test_accuracy = 0.98
assert actual[0] < desired_traing_loss
assert actual[1] > desired_test_accuracy
@unittest.skipIf(
paddle.device.cuda.get_device_capability() < (8, 0),
"BF16 MNIST example requires Ampere+ GPU",
)
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
self.args.use_te = True
self.args.use_fp8 = False
self.args.save_model = True
actual = train_and_evaluate(self.args)
if os.path.exists("mnist_cnn.pdparams"):
os.remove("mnist_cnn.pdparams")
self.verify(actual)
@unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
self.args.use_te = True
self.args.use_fp8 = True
self.args.save_model = True
actual = train_and_evaluate(self.args)
if os.path.exists("mnist_cnn.pdparams"):
os.remove("mnist_cnn.pdparams")
self.verify(actual)
@unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8_calibration(self):
"""Test Transformer Engine with FP8 calibration"""
self.args.use_te = True
self.args.use_fp8 = False
self.args.use_fp8_infer = True
actual = train_and_evaluate(self.args)
if os.path.exists("mnist_cnn.pdparams"):
os.remove("mnist_cnn.pdparams")
self.verify(actual)
if __name__ == "__main__":
train_and_evaluate(mnist_parser(None))