diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 964e71fa8c0f526ab4cb15e302ddc54170312cc5..4be7a30a86a6e8138527ae35d784d660480a6926 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -73,23 +73,3 @@ jobs:
MAX_JOBS: 1
- name: 'Sanity check'
run: python tests/jax/test_sanity_import.py
- paddle:
- name: 'PaddlePaddle'
- runs-on: ubuntu-latest
- container:
- image: nvcr.io/nvidia/paddlepaddle:24.10-py3
- options: --user root
- steps:
- - name: 'Checkout'
- uses: actions/checkout@v3
- with:
- submodules: recursive
- - name: 'Build'
- run: |
- apt-get update
- apt-get install -y libgoogle-glog-dev
- pip install . -v
- env:
- NVTE_FRAMEWORK: paddle
- - name: 'Sanity check'
- run: python tests/paddle/test_sanity_import.py
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index f98fc9aa3a56f47fbce7f4f995bf6e6d6ca9f720..ee6433d484bb74778198582b4638ad0cfb8d1104 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -61,30 +61,3 @@ jobs:
export PYTHON_ONLY=1
export TE_PATH=.
bash ./qa/L0_jax_lint/test.sh
- paddle_cpplint:
- name: 'PaddlePaddle C++'
- runs-on: ubuntu-latest
- steps:
- - name: Checkout
- uses: actions/checkout@v3
- - name: 'Lint'
- run: |
- sudo apt-get update
- sudo apt-get install pip -y
- export CPP_ONLY=1
- export TE_PATH=.
- bash ./qa/L0_paddle_lint/test.sh
- paddle_pylint:
- name: 'PaddlePaddle Python'
- runs-on: ubuntu-latest
- steps:
- - name: 'Checkout'
- uses: actions/checkout@v3
- - name: 'Lint'
- run: |
- sudo apt-get update
- sudo apt-get install pip -y
- pip install paddlepaddle-gpu
- export PYTHON_ONLY=1
- export TE_PATH=.
- bash ./qa/L0_paddle_lint/test.sh
diff --git a/.gitignore b/.gitignore
index 9b61454e216f1c5b44f254011493de616e8d1e8d..f491b21f43bcbcde70f7e38536477a420c5aec32 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,7 +8,6 @@
*.nsys-rep
*.ncu-rep
*.sqlite
-*.onnx
*.eggs
build/
*.so
diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend
index cc5632eda70bbdac34455c2d94066d27d10e2699..91b7532f3386768bba4f444ee7672b497f34da8a 160000
--- a/3rdparty/cudnn-frontend
+++ b/3rdparty/cudnn-frontend
@@ -1 +1 @@
-Subproject commit cc5632eda70bbdac34455c2d94066d27d10e2699
+Subproject commit 91b7532f3386768bba4f444ee7672b497f34da8a
diff --git a/README.rst b/README.rst
index fbcf05f3c9b6c647d0d66dfb870404516802756b..8fea8c9d94916ee0fe257ed72f3ee8d72f1f3f94 100644
--- a/README.rst
+++ b/README.rst
@@ -174,7 +174,7 @@ To install the latest stable version of Transformer Engine,
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
-This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch,paddle).
+This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch).
Alternatively, the package can be directly installed from `Transformer Engine's PyPI `_, e.g.
@@ -182,7 +182,7 @@ Alternatively, the package can be directly installed from `Transformer Engine's
pip install transformer_engine[pytorch]
-To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch,paddle]). Transformer Engine ships wheels for the core library as well as the PaddlePaddle extensions. Source distributions are shipped for the JAX and PyTorch extensions.
+To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch]). Transformer Engine ships wheels for the core library. Source distributions are shipped for the JAX and PyTorch extensions.
From source
^^^^^^^^^^^
diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt
index 809a0327d8dcbf6c8b23f314b8e8b58627d2f5e7..eb5820cd2d6fda4283484b8ba084fffd7356d81c 100644
--- a/build_tools/VERSION.txt
+++ b/build_tools/VERSION.txt
@@ -1 +1 @@
-1.14.0.dev0
+2.1.0.dev0
diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py
index 5744439c1b77fe99b3ced78681457fca4199bce4..a3243d087bfac234507345c02667c9f95fb52006 100644
--- a/build_tools/build_ext.py
+++ b/build_tools/build_ext.py
@@ -129,63 +129,6 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]):
super().run()
self.extensions = all_extensions
- paddle_ext = None
- if "paddle" in get_frameworks():
- for ext in self.extensions:
- if "paddle" in ext.name:
- paddle_ext = ext
- break
-
- # Manually write stub file for Paddle extension
- if paddle_ext is not None:
- # Load libtransformer_engine.so to avoid linker errors
- if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
- # Source compilation from top-level (--editable)
- search_paths = list(Path(__file__).resolve().parent.parent.iterdir())
- # Source compilation from top-level
- search_paths.extend(list(Path(self.build_lib).iterdir()))
-
- # Dynamically load required_libs.
- from transformer_engine.common import _load_cudnn, _load_nvrtc
-
- _load_cudnn()
- _load_nvrtc()
- else:
- # Only during release bdist build for paddlepaddle.
- import transformer_engine
-
- search_paths = list(Path(transformer_engine.__path__[0]).iterdir())
- del transformer_engine
-
- common_so_path = ""
- for path in search_paths:
- if path.name.startswith("libtransformer_engine."):
- common_so_path = str(path)
- assert common_so_path, "Could not find libtransformer_engine"
- ctypes.CDLL(common_so_path, mode=ctypes.RTLD_GLOBAL)
-
- # Figure out stub file path
- module_name = paddle_ext.name
- assert module_name.endswith(
- "_pd_"
- ), "Expected Paddle extension module to end with '_pd_'"
- stub_name = module_name[:-4] # remove '_pd_'
- stub_path = os.path.join(self.build_lib, "transformer_engine", stub_name + ".py")
- Path(stub_path).parent.mkdir(exist_ok=True, parents=True)
-
- # Figure out library name
- # Note: This library doesn't actually exist. Paddle
- # internally reinserts the '_pd_' suffix.
- so_path = self.get_ext_fullpath(module_name)
- _, so_ext = os.path.splitext(so_path)
- lib_name = stub_name + so_ext
-
- # Write stub file
- print(f"Writing Paddle stub for {lib_name} into file {stub_path}")
- from paddle.utils.cpp_extension.extension_utils import custom_write_stub
-
- custom_write_stub(lib_name, stub_path)
-
# Ensure that binaries are not in global package space.
target_dir = install_dir / "transformer_engine"
target_dir.mkdir(exist_ok=True, parents=True)
@@ -194,16 +137,10 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]):
self.copy_file(ext, target_dir)
os.remove(ext)
- # For paddle, the stub file needs to be copied to the install location.
- if paddle_ext is not None:
- stub_path = Path(self.build_lib) / "transformer_engine"
- for stub in stub_path.glob("transformer_engine_paddle.py"):
- self.copy_file(stub, target_dir)
-
def build_extensions(self):
- # BuildExtensions from PyTorch and PaddlePaddle already handle CUDA files correctly
+ # BuildExtensions from PyTorch already handle CUDA files correctly
# so we don't need to modify their compiler. Only the pybind11 build_ext needs to be fixed.
- if "pytorch" not in get_frameworks() and "paddle" not in get_frameworks():
+ if "pytorch" not in get_frameworks():
# Ensure at least an empty list of flags for 'cxx' and 'nvcc' when
# extra_compile_args is a dict.
for ext in self.extensions:
diff --git a/build_tools/paddle.py b/build_tools/paddle.py
deleted file mode 100644
index f0fcdb8f250ec474321e237bdeb8997ff72f65fa..0000000000000000000000000000000000000000
--- a/build_tools/paddle.py
+++ /dev/null
@@ -1,92 +0,0 @@
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-#
-# See LICENSE for license information.
-
-"""Paddle-paddle related extensions."""
-from pathlib import Path
-
-import setuptools
-import os
-
-from .utils import cuda_version
-
-import paddle
-
-paddle_version = paddle.__version__.replace(".", "")
-
-
-def setup_paddle_extension(
- csrc_source_files,
- csrc_header_files,
- common_header_files,
-) -> setuptools.Extension:
- """Setup CUDA extension for Paddle support"""
-
- # Source files
- csrc_source_files = Path(csrc_source_files)
- sources = [
- csrc_source_files / "extensions.cpp",
- csrc_source_files / "common.cpp",
- csrc_source_files / "custom_ops.cu",
- ]
-
- # Header files
- include_dirs = [
- common_header_files,
- common_header_files / "common",
- common_header_files / "common" / "include",
- csrc_header_files,
- ]
-
- # Compiler flags
- cxx_flags = ["-O3"]
- nvcc_flags = [
- "-O3",
- "-gencode",
- "arch=compute_70,code=sm_70",
- "-U__CUDA_NO_HALF_OPERATORS__",
- "-U__CUDA_NO_HALF_CONVERSIONS__",
- "-U__CUDA_NO_BFLOAT16_OPERATORS__",
- "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
- "-U__CUDA_NO_BFLOAT162_OPERATORS__",
- "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
- f"-DPADDLE_VERSION={paddle_version}",
- "--expt-relaxed-constexpr",
- "--expt-extended-lambda",
- "--use_fast_math",
- ]
-
- # Version-dependent CUDA options
- try:
- version = cuda_version()
- except FileNotFoundError:
- print("Could not determine CUDA Toolkit version")
- else:
- if version < (12, 0):
- raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")
- nvcc_flags.extend(
- (
- "--threads",
- os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"),
- "-gencode",
- "arch=compute_80,code=sm_80",
- "-gencode",
- "arch=compute_90,code=sm_90",
- )
- )
-
- # Construct Paddle CUDA extension
- sources = [str(path) for path in sources]
- include_dirs = [str(path) for path in include_dirs]
- from paddle.utils.cpp_extension import CUDAExtension
-
- ext = CUDAExtension(
- sources=sources,
- include_dirs=include_dirs,
- extra_compile_args={
- "cxx": cxx_flags,
- "nvcc": nvcc_flags,
- },
- )
- ext.name = "transformer_engine_paddle_pd_"
- return ext
diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py
index f060e99dff75abb9cbcb1f807b3e2516923ab5f1..b8501e1008ddd13f7fcbc90e48ed489a15b9667c 100644
--- a/build_tools/pytorch.py
+++ b/build_tools/pytorch.py
@@ -27,7 +27,6 @@ def setup_pytorch_extension(
extensions_dir = csrc_source_files / "extensions"
sources = [
csrc_source_files / "common.cpp",
- csrc_source_files / "ts_fp8_op.cpp",
] + all_files_in_dir(extensions_dir)
# Header files
diff --git a/build_tools/utils.py b/build_tools/utils.py
index f2a420068535cb01e21f24ad4501395ff0b56e89..723f2f200cff977f5446b083c7b3b36eb938c393 100644
--- a/build_tools/utils.py
+++ b/build_tools/utils.py
@@ -190,7 +190,12 @@ def cuda_path() -> Tuple[str, str]:
@functools.lru_cache(maxsize=None)
def cuda_archs() -> str:
- return os.getenv("NVTE_CUDA_ARCHS", "70;80;89;90")
+ version = cuda_version()
+ if os.getenv("NVTE_CUDA_ARCHS") is None:
+ os.environ["NVTE_CUDA_ARCHS"] = (
+ "70;80;89;90;100;120" if version >= (12, 8) else "70;80;89;90"
+ )
+ return os.getenv("NVTE_CUDA_ARCHS")
def cuda_version() -> Tuple[int, ...]:
@@ -211,7 +216,7 @@ def cuda_version() -> Tuple[int, ...]:
def get_frameworks() -> List[str]:
"""DL frameworks to build support for"""
_frameworks: List[str] = []
- supported_frameworks = ["pytorch", "jax", "paddle"]
+ supported_frameworks = ["pytorch", "jax"]
# Check environment variable
if os.getenv("NVTE_FRAMEWORK"):
@@ -237,12 +242,6 @@ def get_frameworks() -> List[str]:
pass
else:
_frameworks.append("jax")
- try:
- import paddle
- except ImportError:
- pass
- else:
- _frameworks.append("paddle")
# Special framework names
if "all" in _frameworks:
@@ -311,7 +310,6 @@ def uninstall_te_wheel_packages():
"-y",
"transformer_engine_cu12",
"transformer_engine_torch",
- "transformer_engine_paddle",
"transformer_engine_jax",
]
)
diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh
index ceebe626f4fb1e9f9e1c04e0adcdfb3485b7de84..9acb22aee659e09c58f582c6f12035c320fa09d2 100644
--- a/build_tools/wheel_utils/build_wheels.sh
+++ b/build_tools/wheel_utils/build_wheels.sh
@@ -9,7 +9,6 @@ BUILD_METAPACKAGE=${2:-true}
BUILD_COMMON=${3:-true}
BUILD_PYTORCH=${4:-true}
BUILD_JAX=${5:-true}
-BUILD_PADDLE=${6:-true}
export NVTE_RELEASE_BUILD=1
export TARGET_BRANCH=${TARGET_BRANCH:-}
@@ -63,38 +62,3 @@ if $BUILD_JAX ; then
/opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt
cp dist/* /wheelhouse/
fi
-
-if $BUILD_PADDLE ; then
- if [ "$PLATFORM" == "manylinux_2_28_x86_64" ] ; then
- dnf -y remove --allowerasing cudnn9-cuda-12
- dnf -y install libcudnn8-devel.x86_64 libcudnn8.x86_64
- cd /TransformerEngine/transformer_engine/paddle
-
- /opt/python/cp38-cp38/bin/pip install /wheelhouse/*.whl --no-deps
- /opt/python/cp38-cp38/bin/pip install paddlepaddle-gpu==2.6.1
- /opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp38.txt
- /opt/python/cp38-cp38/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu
-
- /opt/python/cp39-cp39/bin/pip install /wheelhouse/*.whl --no-deps
- /opt/python/cp39-cp39/bin/pip install paddlepaddle-gpu==2.6.1
- /opt/python/cp39-cp39/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp39.txt
- /opt/python/cp39-cp39/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu
-
- /opt/python/cp310-cp310/bin/pip install /wheelhouse/*.whl --no-deps
- /opt/python/cp310-cp310/bin/pip install paddlepaddle-gpu==2.6.1
- /opt/python/cp310-cp310/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp310.txt
- /opt/python/cp310-cp310/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu
-
- /opt/python/cp311-cp311/bin/pip install /wheelhouse/*.whl --no-deps
- /opt/python/cp311-cp311/bin/pip install paddlepaddle-gpu==2.6.1
- /opt/python/cp311-cp311/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp311.txt
- /opt/python/cp311-cp311/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu
-
- /opt/python/cp312-cp312/bin/pip install /wheelhouse/*.whl --no-deps
- /opt/python/cp312-cp312/bin/pip install paddlepaddle-gpu==2.6.1
- /opt/python/cp312-cp312/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp312.txt
- /opt/python/cp312-cp312/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu
-
- mv dist/* /wheelhouse/
- fi
-fi
diff --git a/docs/api/common.rst b/docs/api/common.rst
index 85201aee5d7ba4ca3205ffd29467dd905a4a5d3f..5e0a660ae647eeb68aa2a7d341cefaeeab826e3a 100644
--- a/docs/api/common.rst
+++ b/docs/api/common.rst
@@ -8,4 +8,4 @@ Common API
.. autoapiclass:: transformer_engine.common.recipe.Format
-.. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None, override_linear_precision=(False, False, False))
+.. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None)
diff --git a/docs/api/framework.rst b/docs/api/framework.rst
index acd54fe3b125330b330b3a4a133e01be9f98bcd6..0ac1a0e34e7ee553258151b664514eac45741da9 100644
--- a/docs/api/framework.rst
+++ b/docs/api/framework.rst
@@ -10,4 +10,3 @@ Framework-specific API
pytorch
jax
- paddle
diff --git a/docs/api/paddle.rst b/docs/api/paddle.rst
deleted file mode 100644
index 3b3ecf55c6dfecedb8bbdc57eb69515f69cdb526..0000000000000000000000000000000000000000
--- a/docs/api/paddle.rst
+++ /dev/null
@@ -1,34 +0,0 @@
-..
- Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-
- See LICENSE for license information.
-
-paddle
-======
-
-.. autoapiclass:: transformer_engine.paddle.Linear(in_features, out_features, **kwargs)
- :members: forward
-
-.. autoapiclass:: transformer_engine.paddle.LayerNorm(hidden_size, eps=1e-5, **kwargs)
-
-.. autoapiclass:: transformer_engine.paddle.LayerNormLinear(in_features, out_features, eps=1e-5, **kwargs)
- :members: forward
-
-.. autoapiclass:: transformer_engine.paddle.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, **kwargs)
- :members: forward
-
-.. autoapiclass:: transformer_engine.paddle.FusedScaleMaskSoftmax(attn_mask_type, mask_func, **kwargs)
- :members: forward
-
-.. autoapiclass:: transformer_engine.paddle.DotProductAttention(num_attention_heads, kv_channels, **kwargs)
- :members: forward
-
-.. autoapiclass:: transformer_engine.paddle.MultiHeadAttention(hidden_size, num_attention_heads, **kwargs)
- :members: forward
-
-.. autoapiclass:: transformer_engine.paddle.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)
- :members: forward
-
-.. autoapifunction:: transformer_engine.paddle.fp8_autocast
-
-.. autoapifunction:: transformer_engine.paddle.recompute
diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst
index 43001feeb3990379d15c74ed474d014e005ef765..6d5fe6761d4e1fe580afa5aeee66e6ceabd17c79 100644
--- a/docs/api/pytorch.rst
+++ b/docs/api/pytorch.rst
@@ -42,8 +42,6 @@ pyTorch
.. autoapifunction:: transformer_engine.pytorch.checkpoint
-.. autoapifunction:: transformer_engine.pytorch.onnx_export
-
.. autoapifunction:: transformer_engine.pytorch.make_graphed_callables
.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context
diff --git a/docs/examples/attention/attention.ipynb b/docs/examples/attention/attention.ipynb
index 27017b4773fc88b3eba4ffe1718402e90b4e96c0..16a3b05466ed417e8c277ab786c868a2a10389a1 100644
--- a/docs/examples/attention/attention.ipynb
+++ b/docs/examples/attention/attention.ipynb
@@ -14,11 +14,10 @@
" Figure 1: Dot product attention. \n",
"\n",
"\n",
- "[Transformer Engine](https://github.com/NVIDIA/TransformerEngine.git) supports the calculation of dot product attention in three frameworks, [PyTorch](https://github.com/pytorch/pytorch), [JAX](https://github.com/google/jax) and [PaddlePaddle](https://github.com/PaddlePaddle/Paddle). The API for each framework is\n",
+ "[Transformer Engine](https://github.com/NVIDIA/TransformerEngine.git) supports the calculation of dot product attention in two frameworks, [PyTorch](https://github.com/pytorch/pytorch) and [JAX](https://github.com/google/jax). The API for each framework is\n",
"\n",
"- [transformer_engine.pytorch.DotProductAttention](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention)\n",
- "- [transformer_engine.jax.flax.DotProductAttention](../../api/jax.rst#transformer_engine.jax.flax.DotProductAttention)\n",
- "- [transformer_engine.paddle.DotProductAttention](../../api/paddle.rst#transformer_engine.paddle.DotProductAttention)"
+ "- [transformer_engine.jax.flax.DotProductAttention](../../api/jax.rst#transformer_engine.jax.flax.DotProductAttention)"
]
},
{
@@ -56,15 +55,6 @@
"
\n",
" | JAX-native attention (`_UnfusedDotProductAttention`) | \n",
"
\n",
- " \n",
- " | PaddlePaddle | \n",
- " cuDNN attention (`_te_forward`) | \n",
- " [transformer_engine.paddle.layer.attention](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/paddle/layer/attention.py)\n",
- " | \n",
- "
\n",
- " \n",
- " | PaddlePaddle-native attention (`_pd_forward`) | \n",
- "
\n",
" \n",
""
]
@@ -87,7 +77,7 @@
"\n",
"Note: \n",
" \n",
- "Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), available in PyTorch, JAX and PaddlePaddle, are both based on the flash algorithm.\n",
+ "Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), available in PyTorch and JAX, are both based on the flash algorithm.\n",
"
\n"
]
},
@@ -102,13 +92,13 @@
"\n",
"The flash-attention backend supports `flash-attn`'s features as well as a few extra functionalities to facilitate the use of `flash-attn`, such as converting the `attention_mask` to cumulative sequence lengths `cu_seqlens` for `padding` mask use cases. Please see `transformer_engine.pytorch.attention.FlashAttention` for details.\n",
"\n",
- "The `flash-attn` dependency is regularly updated in Transformer Engine. As of v1.10, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](https://github.com/NVIDIA/TransformerEngine/blob/main/setup.py)).\n",
+ "The `flash-attn` dependency is regularly updated in Transformer Engine. As of v2.0, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](https://github.com/NVIDIA/TransformerEngine/blob/main/setup.py)).\n",
"\n",
"To understand `flash-attn`'s performance, please refer to their benchmarks [here](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#performance).\n",
"\n",
"### 1.3 cuDNN Attention\n",
"\n",
- "The cuDNN attention backend, available in PyTorch, JAX and PaddlePaddle, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) to run, and has several sub-backends to support the different precisions and sequence lengths.\n",
+ "The cuDNN attention backend, available in PyTorch and JAX, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) to run, and has several sub-backends to support the different precisions and sequence lengths.\n",
"\n",
"\n",
" \n",
@@ -153,9 +143,9 @@
"
\n",
"
\n",
"\n",
- "The cuDNN attention backend and flash-attention backend have several notable differences. As of Transformer Engine 1.10, cuDNN 9.3 and `flash-attn` 2.4.2,\n",
+ "The cuDNN attention backend and flash-attention backend have several notable differences. As of Transformer Engine 2.0, cuDNN 9.3 and `flash-attn` 2.4.2,\n",
"\n",
- "- flash-attention only supports the PyTorch framework while cuDNN attention supports PyTorch, JAX and PaddlePaddle.\n",
+ "- flash-attention only supports the PyTorch framework while cuDNN attention supports PyTorch and JAX.\n",
"- flash-attention supports BF16, FP16 precisions while cuDNN attention also supports FP8 (through its sub-backend 2).\n",
"- flash-attention supports `bshd`, `thd` input formats, without any transposes, and `sbhd` format, with transposes, while cuDNN attention supports all three formats without transposes (see Section 3.1 for more details).\n",
"- flash-attention does not support `post_scale_bias`, and cuDNN attention does.\n",
@@ -244,10 +234,6 @@
" JAX | \n",
" cuDNN attention > JAX-native attention | \n",
" \n",
- " \n",
- " | PaddlePaddle | \n",
- " cuDNN attention > PaddlePaddle-native attention | \n",
- "
\n",
""
]
},
@@ -266,7 +252,7 @@
"\n",
"Note:\n",
" \n",
- "These flags are supported in PyTorch only as of Transformer Engine 1.10. JAX and PaddlePaddle support is expected to be added in the future.\n",
+ "These flags are supported in PyTorch only as of Transformer Engine 2.0. JAX support is expected to be added in the future.\n",
"
"
]
},
@@ -382,7 +368,7 @@
"\n",
"Note\n",
" \n",
- "Environment variables NVTE_FLASH_ATTN, NVTE_FUSED_ATTN, NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT and NVTE_ALLOW_NONDETERMINISTIC_ALGO are only supported in PyTorch, and will be added to JAX and PaddlePaddle in the future.\n",
+ "Environment variables NVTE_FLASH_ATTN, NVTE_FUSED_ATTN, NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT and NVTE_ALLOW_NONDETERMINISTIC_ALGO are only supported in PyTorch, and will be added to JAX in the future.\n",
"
\n",
"\n",
"### 2.3 Example Tests\n",
@@ -399,7 +385,7 @@
"source": [
"## 3. Backend Support\n",
"\n",
- "Transformer Engine supports commonly-used features such as self and cross attention, FP16/BF16 precisions, dropout, and checkpointing. But it also offers a range of other features. As of v1.10, Transformer Engine's attention backends have the following support matrix.\n",
+ "Transformer Engine supports commonly-used features such as self and cross attention, FP16/BF16 precisions, dropout, and checkpointing. But it also offers a range of other features. As of v2.0, Transformer Engine's attention backends have the following support matrix.\n",
"\n",
"| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Multi-Latent Attention | Context Parallelism | Determinism Possible |\n",
"| :---------------- | :-------- | :----------- | :----------------------- | :------ | :--------------------- | :------------------ | :------------ |\n",
@@ -442,7 +428,7 @@
"**qkv_layout=thd_thd_thd:**\n",
"`q`, `k`, `v` have variable sequence lengths in a batch. They are all contiguous and have no interleaving.\n",
"\n",
- "As of v1.10, Transformer Engine has the following support matrix.\n",
+ "As of v2.0, Transformer Engine has the following support matrix.\n",
"\n",
"\n",
" \n",
@@ -462,13 +448,13 @@
"
\n",
" \n",
" | \n",
- " JAX, PaddlePaddle: `bs3hd`, `bshd_bs2hd`, `bshd_bshd_bshd` layouts\n",
+ " JAX: `bs3hd`, `bshd_bs2hd`, `bshd_bshd_bshd` layouts\n",
" | \n",
"
\n",
" \n",
" | Framework-native attention | \n",
" `bshd`, `sbhd` | \n",
- " PyTorch, JAX, PaddlePaddle: 2 formats, i.e. 10 layouts | \n",
+ " PyTorch, JAX: 2 formats, i.e. 10 layouts | \n",
"
\n",
"
\n",
"\n",
@@ -492,7 +478,7 @@
"\n",
"- `no_mask`, `padding`, `causal`, `causal_bottom_right`, `padding_causal`, `padding_causal_bottom_right`, `arbitrary`\n",
"\n",
- "Different backends offer different support for attention mask. As of Transformer Engine 1.10,\n",
+ "Different backends offer different support for attention mask. As of Transformer Engine 2.0,\n",
"\n",
"\n",
" \n",
@@ -512,21 +498,21 @@
"
\n",
" \n",
" | Framework-native attention | \n",
- " All (PyTorch)`no_mask`, `causal`, `padding` (Jax, PaddlePaddle) | \n",
+ " All (PyTorch)`no_mask`, `causal`, `padding` (Jax) | \n",
"
\n",
" \n",
" | \n",
"
\n",
"
\n",
"\n",
- "**Padding masks:** For `padding`, `padding_causal`, `padding_causal_bottom_right` mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 1.10, there are two options to do so in PyTorch and one in JAX and PaddlePaddle.\n",
+ "**Padding masks:** For `padding`, `padding_causal`, `padding_causal_bottom_right` mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 2.0, there are two options to do so in PyTorch and one in JAX.\n",
"\n",
"* PyTorch: When both options are provided by the user, `cu_seqlens` is preferred as there is no extra conversion needed.\n",
" - `cu_seqlens`: Users can provide cumulative sequence length tensors `cu_seqlens_q` and `cu_seqlens_kv` for `q` and `k`/`v` to the flash-attention or cuDNN attention backend. An example of `cu_seqlens` is `[0, 2, 6, 7]` for a batch of 3 `[aa000, bbbb0, c0000]`.\n",
" - `attention_mask`: Users can also provide `attention_mask` as an alternative, which will then be converted to `cu_seqlens`. For self-attention, `attention_mask` should be one single tensor in shape `[batch_size, 1, 1, seqlen_q]`, and for cross-attention, `attention_mask` should be a list of two tensors in shapes `[batch_size, 1, 1, seqlen_q]` and `[batch_size, 1, 1, seqlen_kv]`, respectively.\n",
"\n",
"\n",
- "* JAX and PaddlePaddle: Users should provide the `attention_mask` tensor in shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n",
+ "* JAX: Users should provide the `attention_mask` tensor in shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n",
"\n",
"**qkv_format=thd:** Transformer Engine extracts the max sequence length information from `q`, `k`, `v` if `max_seqlen_q` and `max_seqlen_kv` are not provided. This requires GPU-CPU copy and synchronization operations. For performance reasons, please set `max_seqlen_q` and `max_seqlen_kv` to their appropriate values for `thd` QKV format.\n",
"\n",
@@ -566,7 +552,7 @@
"\n",
"### 3.3 Attention Bias\n",
"\n",
- "Transformer Engine supports 4 attention bias types, `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes). As of Transformer Engine 1.10, their support matrix is as follows.\n",
+ "Transformer Engine supports 4 attention bias types, `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes). As of Transformer Engine 2.0, their support matrix is as follows.\n",
"\n",
"\n",
" \n",
@@ -591,7 +577,7 @@
" | cuDNN 8.9.6+: sm90 | \n",
"
\n",
" \n",
- " | JAX, PaddlePaddle: `no_bias`, `post_scale_bias` | \n",
+ " JAX: `no_bias`, `post_scale_bias` | \n",
" ALiBi slopes: FP32 | \n",
" cuDNN 9.0+: sm80+ | \n",
"
\n",
@@ -620,7 +606,7 @@
"\n",
"A unique feature of Transformer Engine is its FP8 support, not only for the `Linear` layers but also for dot product attention. Transformer Engine's FP8 attention support is through its cuDNN attention sub-backend 2. Recall Figure 1: the two `MatMul` operations are performed in FP8 for computational efficiency, and the `SoftMax` operation is performed in FP32 for numerical accuracy.\n",
"\n",
- "Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst), and [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention), as of v1.10. Its PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n",
+ "Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst), and [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention), as of v2.0. Its PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n",
"\n",
"- `DelayedScaling.fp8_dpa=True (default=False)`: This enables the use of cuDNN attention sub-backend 2, when it does support the provided user inputs. The `FusedAttention` module for cuDNN attention takes FP16 or BF16 tensors as inputs, performs dot product attention in FP8, and returns attention logits in FP16 or BF16 (same as the input type). Casting operations are required to cast tensors to FP8 at the beginning, and back to FP16/BF16 at the end of the module.\n",
"\n",
diff --git a/docs/installation.rst b/docs/installation.rst
index fae01c64fae8443ae90e6dbf8a2d42fe7c39226d..ee7afa9006bd66709849f76d167cfc66276259b0 100644
--- a/docs/installation.rst
+++ b/docs/installation.rst
@@ -37,7 +37,7 @@ Transformer Engine can be directly installed from `our PyPI desired_test_accuracy
-
- @unittest.skipIf(
- paddle.device.cuda.get_device_capability() < (8, 0),
- "BF16 MNIST example requires Ampere+ GPU",
- )
- def test_te_bf16(self):
- """Test Transformer Engine with BF16"""
- self.args.use_te = True
- self.args.use_fp8 = False
- self.args.save_model = True
- actual = train_and_evaluate(self.args)
- if os.path.exists("mnist_cnn.pdparams"):
- os.remove("mnist_cnn.pdparams")
- self.verify(actual)
-
- @unittest.skipIf(not gpu_has_fp8, reason)
- def test_te_fp8(self):
- """Test Transformer Engine with FP8"""
- self.args.use_te = True
- self.args.use_fp8 = True
- self.args.save_model = True
- actual = train_and_evaluate(self.args)
- if os.path.exists("mnist_cnn.pdparams"):
- os.remove("mnist_cnn.pdparams")
- self.verify(actual)
-
- @unittest.skipIf(not gpu_has_fp8, reason)
- def test_te_fp8_calibration(self):
- """Test Transformer Engine with FP8 calibration"""
- self.args.use_te = True
- self.args.use_fp8 = False
- self.args.use_fp8_infer = True
- actual = train_and_evaluate(self.args)
- if os.path.exists("mnist_cnn.pdparams"):
- os.remove("mnist_cnn.pdparams")
- self.verify(actual)
-
-
-if __name__ == "__main__":
- train_and_evaluate(mnist_parser(None))
diff --git a/pylintrc b/pylintrc
index b80679d72c6c0ef61182685c16320a786a0e136c..4af0c6b4271d6077cbb23505f048317884e8bf08 100644
--- a/pylintrc
+++ b/pylintrc
@@ -2,7 +2,6 @@
extension-pkg-whitelist=flash_attn_2_cuda,
torch,
transformer_engine_torch,
- transformer_engine_paddle,
transformer_engine_jax
extension-pkg-allow-list=transformer_engine.transformer_engine_jax
diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh
index 6eff0477212592ffc1726889820df94f3f197c36..8e2e540293673b896f9e81b4a305e2b87358a74e 100644
--- a/qa/L0_jax_unittest/test.sh
+++ b/qa/L0_jax_unittest/test.sh
@@ -8,7 +8,7 @@ pip install "nltk>=3.8.2"
pip install pytest==8.2.1
: ${TE_PATH:=/opt/transformerengine}
-pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed'
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py
# Test without custom calls
NVTE_CUSTOM_CALLS_RE="" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py
diff --git a/qa/L0_paddle_lint/test.sh b/qa/L0_paddle_lint/test.sh
deleted file mode 100644
index 1c26bd265bae30d0480a7d15962a657159e9c05a..0000000000000000000000000000000000000000
--- a/qa/L0_paddle_lint/test.sh
+++ /dev/null
@@ -1,24 +0,0 @@
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-#
-# See LICENSE for license information.
-
-set -e
-
-: "${TE_PATH:=/opt/transformerengine}"
-
-pip install cpplint==1.6.0 pylint==3.3.1
-if [ -z "${PYTHON_ONLY}" ]
-then
- cd $TE_PATH
- echo "Checking common API headers"
- cpplint --root transformer_engine/common/include --recursive transformer_engine/common/include
- echo "Checking C++ files"
- cpplint --recursive --exclude=transformer_engine/common/include --exclude=transformer_engine/build_tools/build transformer_engine/common
- cpplint --recursive transformer_engine/paddle
-fi
-if [ -z "${CPP_ONLY}" ]
-then
- cd $TE_PATH
- echo "Checking Python files"
- pylint --recursive=y transformer_engine/common transformer_engine/paddle
-fi
diff --git a/qa/L0_paddle_unittest/test.sh b/qa/L0_paddle_unittest/test.sh
deleted file mode 100644
index 9312f22ba4dc0f9a67ef4f48fcb5f0f0ca283867..0000000000000000000000000000000000000000
--- a/qa/L0_paddle_unittest/test.sh
+++ /dev/null
@@ -1,10 +0,0 @@
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-#
-# See LICENSE for license information.
-
-set -xe
-
-pip install pytest==8.2.1
-: ${TE_PATH:=/opt/transformerengine}
-pytest -Wignore -v $TE_PATH/tests/paddle
-pytest -Wignore -v $TE_PATH/examples/paddle/mnist
diff --git a/qa/L0_paddle_wheel/test.sh b/qa/L0_paddle_wheel/test.sh
deleted file mode 100644
index 5116bdb5cfc41a2bf392c191b80f56407477e1b9..0000000000000000000000000000000000000000
--- a/qa/L0_paddle_wheel/test.sh
+++ /dev/null
@@ -1,37 +0,0 @@
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-#
-# See LICENSE for license information.
-
-set -e
-
-: "${TE_PATH:=/opt/transformerengine}"
-
-# Install dependencies
-# Note: Need to install wheel locally since PaddlePaddle container
-# already contains APT install.
-pip install pydantic
-pip install --user wheel==0.44.0
-
-cd $TE_PATH
-pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-paddle
-
-VERSION=`cat $TE_PATH/build_tools/VERSION.txt`
-WHL_BASE="transformer_engine-${VERSION}"
-
-# Core wheel.
-NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel
-python -m wheel unpack dist/*
-sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
-sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
-mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
-python -m wheel pack ${WHL_BASE}
-rm dist/*.whl
-mv *.whl dist/
-NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel
-pip install dist/*.whl --no-deps
-
-cd transformer_engine/paddle
-NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel
-pip install dist/*
-
-python $TE_PATH/tests/paddle/test_sanity_import.py
diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh
index 793fa472598d8b279a12ff4c6dc2485222b23221..dd7f95bce0aeb9826f6a7800c51a81c014f2500e 100644
--- a/qa/L0_pytorch_unittest/test.sh
+++ b/qa/L0_pytorch_unittest/test.sh
@@ -11,11 +11,10 @@ pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py
pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
-PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
+NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
-pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py
pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py
pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh
index ee7c28ca5f1f380b002ff5d96f94729271664130..8ee0be1af5b063f7a553406a45d8a902e1cb496f 100644
--- a/qa/L1_pytorch_distributed_unittest/test.sh
+++ b/qa/L1_pytorch_distributed_unittest/test.sh
@@ -8,8 +8,8 @@ set -e
pip install pytest==8.2.1
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py
-pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py
-pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py
+pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
+# pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py ### TODO Debug UB support with te.Sequential
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
diff --git a/qa/L1_pytorch_onnx_test/test.sh b/qa/L1_pytorch_onnx_test/test.sh
deleted file mode 100644
index 8e4ef03b8e8252179240e6cf7ef3a3f8e71af0bb..0000000000000000000000000000000000000000
--- a/qa/L1_pytorch_onnx_test/test.sh
+++ /dev/null
@@ -1,16 +0,0 @@
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-#
-# See LICENSE for license information.
-
-set -e
-
-: ${TE_PATH:=/opt/transformerengine}
-
-pip install pytest==8.2.1 onnxruntime==1.19.2
-
-# Build custom ONNX Runtime operators
-export CUSTOM_ORT_OPS_PATH=$TE_PATH/tests/pytorch/custom_ort_ops
-bash $CUSTOM_ORT_OPS_PATH/build.sh
-
-# Run tests
-NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh
index e63ba358a533c9a70b35a560d462d6dc938ff75e..8ed300221428616cb02e629cca747e441ca29e6b 100644
--- a/qa/L3_pytorch_FA_versions_test/test.sh
+++ b/qa/L3_pytorch_FA_versions_test/test.sh
@@ -12,7 +12,14 @@ pip install pytest==8.2.1
export MAX_JOBS=4
# Iterate over Flash Attention versions
-FA_versions=(2.1.1 2.3.0 2.4.1 2.5.7 2.6.3 3.0.0b1)
+sm_arch=`python -c "import torch; sm = torch.cuda.get_device_capability(0); print(sm[0]*10+sm[1])"`
+if [ $sm_arch -gt 90 ]
+then
+ FA_versions=(2.7.3)
+else
+ FA_versions=(2.1.1 2.3.0 2.4.1 2.5.7 2.7.3 3.0.0b1)
+fi
+
for fa_version in "${FA_versions[@]}"
do
@@ -21,10 +28,10 @@ do
then
pip install flash-attn==${fa_version}
else
- pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper"
+ pip install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper"
python_path=`python -c "import site; print(site.getsitepackages()[0])"`
mkdir -p $python_path/flashattn_hopper
- wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py
+ wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py
fi
# Run tests
diff --git a/setup.py b/setup.py
index 643dd7a9085407a267c91bf0b461a73037f03668..1d9818458e7d8127bf3e48949ad5380ee64a238d 100644
--- a/setup.py
+++ b/setup.py
@@ -5,6 +5,7 @@
"""Installation script."""
import os
+import sys
import time
from pathlib import Path
from typing import List, Tuple
@@ -35,14 +36,13 @@ os.environ["NVTE_PROJECT_BUILDING"] = "1"
if "pytorch" in frameworks:
from torch.utils.cpp_extension import BuildExtension
-elif "paddle" in frameworks:
- from paddle.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
install_and_import("pybind11[global]")
from pybind11.setup_helpers import build_ext as BuildExtension
CMakeBuildExtension = get_build_ext(BuildExtension)
+archs = cuda_archs()
class TimedBdist(bdist_wheel):
@@ -57,7 +57,7 @@ class TimedBdist(bdist_wheel):
def setup_common_extension() -> CMakeExtension:
"""Setup CMake extension for common library"""
- cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())]
+ cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)]
if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
assert (
os.getenv("MPI_HOME") is not None
@@ -104,13 +104,11 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks:
install_reqs.extend(["torch"])
- test_reqs.extend(["numpy", "onnxruntime", "torchvision", "prettytable"])
+ test_reqs.extend(["numpy", "torchvision", "prettytable"])
if "jax" in frameworks:
install_reqs.extend(["jax", "flax>=0.7.1"])
- test_reqs.extend(["numpy", "praxis"])
- if "paddle" in frameworks:
- install_reqs.append("paddlepaddle-gpu")
- test_reqs.append("numpy")
+ # test_reqs.extend(["numpy", "praxis"])
+ test_reqs.extend(["numpy"])
return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]]
@@ -135,7 +133,6 @@ if __name__ == "__main__":
extras_require = {
"pytorch": [f"transformer_engine_torch=={__version__}"],
"jax": [f"transformer_engine_jax=={__version__}"],
- "paddle": [f"transformer_engine_paddle=={__version__}"],
}
else:
setup_requires, install_requires, test_requires = setup_requirements()
@@ -169,16 +166,6 @@ if __name__ == "__main__":
current_file_path / "transformer_engine",
)
)
- if "paddle" in frameworks:
- from build_tools.paddle import setup_paddle_extension
-
- ext_modules.append(
- setup_paddle_extension(
- "transformer_engine/paddle/csrc",
- current_file_path / "transformer_engine" / "paddle" / "csrc",
- current_file_path / "transformer_engine",
- )
- )
# Configure package
setuptools.setup(
diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt
index d8c8d99fac72568fac2e0de3171c4e2d82cd9ad7..081cd14eb4309ea9349095133d031f3ff10d8fbf 100644
--- a/tests/cpp/CMakeLists.txt
+++ b/tests/cpp/CMakeLists.txt
@@ -5,7 +5,11 @@
cmake_minimum_required(VERSION 3.18)
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
- set(CMAKE_CUDA_ARCHITECTURES 70 80 90)
+ if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
+ set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
+ else ()
+ set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90)
+ endif()
endif()
diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt
index 178dc5e8ddb795ef0d5d67f68031482915901c00..ce78fcaae245cbc4718129f0d737c94dd62feb61 100644
--- a/tests/cpp/operator/CMakeLists.txt
+++ b/tests/cpp/operator/CMakeLists.txt
@@ -3,23 +3,33 @@
# See LICENSE for license information.
add_executable(test_operator
+ test_cast.cu
+ test_cast_dbias.cu
+ test_cast_dbias_dgelu.cu
+ test_cast_gated_swiglu.cu
+ test_cast_mxfp8_gated_swiglu.cu
test_qdq.cu
- test_cast_transpose.cu
+ test_cast_mxfp8.cu
+ test_dequantize_mxfp8.cu
test_transpose.cu
+ test_cast_transpose.cu
test_cast_transpose_dbias.cu
test_cast_transpose_dbias_dgelu.cu
test_cast_transpose_dgeglu.cu
test_act.cu
test_normalization.cu
+ test_normalization_mxfp8.cu
test_multi_cast_transpose.cu
test_multi_padding.cu
test_causal_softmax.cu
+ test_swizzle.cu
../test_common.cu)
+find_package(OpenMP REQUIRED)
list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn)
-target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS})
-target_compile_options(test_operator PRIVATE -O2)
+target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS} OpenMP::OpenMP_CXX)
+target_compile_options(test_operator PRIVATE -O2 -fopenmp)
include(GoogleTest)
-gtest_discover_tests(test_operator)
+gtest_discover_tests(test_operator DISCOVERY_TIMEOUT 600)
diff --git a/tests/cpp/operator/test_act.cu b/tests/cpp/operator/test_act.cu
index cec997d0781c63e7db49fcb57e8cd539bf613901..4224f199f45ce8ae3e80f0924f4c15465317af2e 100644
--- a/tests/cpp/operator/test_act.cu
+++ b/tests/cpp/operator/test_act.cu
@@ -21,58 +21,6 @@
using namespace transformer_engine;
-namespace {
-
-// forward
-
-float gelu(const float x) {
- return 0.5f * x * (1.0f + tanhf(0.79788456F * x * (1.0f + 0.044715f * x * x)));
-}
-
-float silu(const float x) {
- return x / (1 + expf(-x));
-}
-
-float relu(const float x) {
- return x > 0 ? x : 0;
-}
-
-float srelu(const float x) {
- return x > 0 ? x * x : 0;
-}
-
-float qgelu(const float x) {
- return x / (1 + expf(-1.702f * x));
-}
-
-// backward
-
-float dgelu(const float x) {
- const float tanh_out = tanhf(0.79788456f * x * (1.f + 0.044715f * x * x));
- return 0.5f * x * ((1.f - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * x * x)) +
- 0.5f * (1.f + tanh_out);
-}
-
-float dsilu(const float x) {
- const float sigmoid = 1.f / (1 + expf(-x));
- return x * sigmoid * (1.f - sigmoid) + sigmoid;
-}
-
-float drelu(const float x) {
- return x > 0.f ? 1.f : 0.f;
-}
-
-float dsrelu(const float x) {
- return fmaxf(2.f * x, 0.f);
-}
-
-float dqgelu(const float x) {
- const float sigmoid = 1.f / (1 + expf(-1.702f * x));
- return 1.702f * x * sigmoid * (1.f - sigmoid) + sigmoid;
-}
-
-} // namespace
-
template
void compute_ref_act_cast(const IT *input_h,
OT *output_h,
@@ -82,6 +30,7 @@ void compute_ref_act_cast(const IT *input_h,
const size_t H) {
CT amax = 0.;
+ #pragma omp parallel for schedule(static) reduction(max: amax) proc_bind(spread)
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
CT elt = static_cast(input_h[i * H + j]);
@@ -101,6 +50,7 @@ void compute_ref_dact_cast(const IT *input_h,
const size_t N,
const size_t H) {
using CT = float;
+ #pragma omp parallel for schedule(static) proc_bind(spread)
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
CT elt = static_cast(input_h[i * H + j]);
@@ -118,6 +68,7 @@ void compute_ref_glu_act_cast(const IT *input_h, OT *output_h, const CT scale, C
const int col = H * 2;
+ #pragma omp parallel for schedule(static) reduction(max: amax) proc_bind(spread)
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
CT gelu_elt = static_cast(input_h[i * col + j]);
@@ -139,6 +90,7 @@ void compute_ref_dglu_act_cast(const IT *input_h, const IT *grad_h, OT *output_h
const int col = H * 2;
using CT = float;
+ #pragma omp parallel for schedule(static) proc_bind(spread)
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
CT grad = static_cast(grad_h[i * H + j]);
@@ -164,10 +116,10 @@ void performTest(const size_t N, const size_t H) {
DType itype = TypeInfo::dtype;
DType otype = TypeInfo::dtype;
- Tensor input({ N, H }, itype);
- Tensor output({ N, H }, otype);
- Tensor igrad({ N, H }, itype);
- Tensor ograd({ N, H }, itype);
+ Tensor input("input", { N, H }, itype);
+ Tensor output("output", { N, H }, otype);
+ Tensor igrad("igrad", { N, H }, itype);
+ Tensor ograd("ograd", { N, H }, itype);
fillUniform(&input);
fillUniform(&ograd);
@@ -179,7 +131,7 @@ void performTest(const size_t N, const size_t H) {
nvte_act(input.data(), output.data(), 0);
float ref_amax;
- compute_ref_act_cast(input.cpu_dptr(), ref_output.get(),
+ compute_ref_act_cast(input.rowwise_cpu_dptr(), ref_output.get(),
output.scale(), &ref_amax, N, H);
cudaDeviceSynchronize();
@@ -195,7 +147,7 @@ void performTest(const size_t N, const size_t H) {
nvte_dact(ograd.data(), input.data(), igrad.data(), 0);
- compute_ref_dact_cast(input.cpu_dptr(), ograd.cpu_dptr(),
+ compute_ref_dact_cast(input.rowwise_cpu_dptr(), ograd.rowwise_cpu_dptr(),
ref_igrad.get(), N, H);
cudaDeviceSynchronize();
@@ -219,10 +171,10 @@ void performTestGLU(const size_t N, const size_t H) {
DType itype = TypeInfo::dtype;
DType otype = TypeInfo::dtype;
- Tensor input({N, H * 2}, itype);
- Tensor output({N, H}, otype);
- Tensor igrad({ N, H * 2 }, itype);
- Tensor ograd({ N, H }, itype);
+ Tensor input("input", {N, H * 2}, itype);
+ Tensor output("output", {N, H}, otype);
+ Tensor igrad("igrad", { N, H * 2 }, itype);
+ Tensor ograd("ograd", { N, H }, itype);
fillUniform(&input);
fillUniform(&ograd);
@@ -234,7 +186,7 @@ void performTestGLU(const size_t N, const size_t H) {
nvte_act(input.data(), output.data(), 0);
float ref_amax;
- compute_ref_glu_act_cast(input.cpu_dptr(), ref_output.get(),
+ compute_ref_glu_act_cast(input.rowwise_cpu_dptr(), ref_output.get(),
output.scale(), &ref_amax, N, H);
cudaDeviceSynchronize();
@@ -242,15 +194,19 @@ void performTestGLU(const size_t N, const size_t H) {
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) {
- auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
- compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
+ auto [atol, rtol] = getTolerances(DType::kFloat32);
+ compareResults("amax", output.amax(), ref_amax, atol, rtol);
+ if (output.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
+ const float ref_scale = 1.f / output.scale();
+ compareResults("scale_inv", *output.rowwise_cpu_scale_inv_ptr(), ref_scale, atol, rtol);
+ }
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_gelu", output, ref_output.get(), atol, rtol);
nvte_dact(ograd.data(), input.data(), igrad.data(), 0);
- compute_ref_dglu_act_cast(input.cpu_dptr(), ograd.cpu_dptr(),
+ compute_ref_dglu_act_cast(input.rowwise_cpu_dptr(), ograd.rowwise_cpu_dptr(),
ref_igrad.get(), N, H);
cudaDeviceSynchronize();
diff --git a/tests/cpp/operator/test_cast.cu b/tests/cpp/operator/test_cast.cu
new file mode 100644
index 0000000000000000000000000000000000000000..f57d1f035df6487965a395efb2aa6264b10fd7ec
--- /dev/null
+++ b/tests/cpp/operator/test_cast.cu
@@ -0,0 +1,130 @@
+/*************************************************************************
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+#include
+#include "../test_common.h"
+
+using namespace transformer_engine;
+
+namespace {
+
+template
+void compute_ref(const InputType *data, OutputType *output_c,
+ const size_t size,
+ float *amax, float scale) {
+ using compute_t = float;
+ compute_t current_max = -1e100;
+ for (size_t i = 0; i < size; ++i) {
+ compute_t current = static_cast(data[i]);
+ current_max = fmaxf(current_max, fabsf(current));
+ output_c[i] = OutputType(scale * current);
+ }
+ *amax = current_max;
+}
+
+template
+void performTest(const std::vector& shape) {
+ using namespace test;
+
+ const size_t full_size = product(shape);
+
+ DType itype = TypeInfo::dtype;
+ DType otype = TypeInfo::dtype;
+
+ Tensor input("input", shape, itype);
+ Tensor output_c("output_c", shape, otype);
+
+ std::unique_ptr ref_output_c = std::make_unique(full_size);
+
+ fillUniform(&input);
+ setRandomScale(&output_c);
+
+ nvte_quantize(input.data(), output_c.data(), 0);
+
+ float ref_amax;
+ compute_ref(input.rowwise_cpu_dptr(), ref_output_c.get(),
+ full_size, &ref_amax, output_c.scale());
+
+ cudaDeviceSynchronize();
+ auto err = cudaGetLastError();
+ ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
+ if (isFp8Type(otype)) {
+ auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
+ compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
+ float ref_scale_inv = 1.f / output_c.scale();
+ compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
+ }
+ auto [atol, rtol] = getTolerances(otype);
+ compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol);
+}
+
+std::vector> test_cases = {
+ {16},
+ {16000},
+ {128, 128},
+ {256, 256},
+ {768, 1024},
+ {256, 65536},
+ {2048, 12288},
+ {65536, 128},
+ {65536, 160},
+ {16384, 1616},
+ {1, 128},
+ {1, 1296},
+ {1, 16},
+ {5, 160},
+ {5, 4, 3, 160},
+ {217, 256},
+};
+} // namespace
+
+class CastTestSuite : public ::testing::TestWithParam>> {};
+
+TEST_P(CastTestSuite, TestCast) {
+ using namespace transformer_engine;
+ using namespace test;
+
+ const DType input_type = std::get<0>(GetParam());
+ const DType output_type = std::get<1>(GetParam());
+ const auto size = std::get<2>(GetParam());
+
+ TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
+ TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
+ performTest(size);
+ );
+ );
+}
+
+
+
+INSTANTIATE_TEST_SUITE_P(
+ OperatorTest,
+ CastTestSuite,
+ ::testing::Combine(
+ ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
+ ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
+ ::testing::ValuesIn(test_cases)),
+ [](const testing::TestParamInfo& info) {
+ std::string name = test::typeName(std::get<0>(info.param)) + "X" +
+ test::typeName(std::get<1>(info.param));
+ const auto& shape = std::get<2>(info.param);
+ for ( const auto& s: shape) {
+ name += "X" + std::to_string(s);
+ }
+ return name;
+ });
diff --git a/tests/cpp/operator/test_cast_dbias.cu b/tests/cpp/operator/test_cast_dbias.cu
new file mode 100644
index 0000000000000000000000000000000000000000..1f0a9305d834c0f5b83a6ee542d904b5af0dbc4e
--- /dev/null
+++ b/tests/cpp/operator/test_cast_dbias.cu
@@ -0,0 +1,181 @@
+/*************************************************************************
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+#include
+#include "../test_common.h"
+
+using namespace transformer_engine;
+
+namespace {
+
+template
+void compute_ref_cast_dbias(const IT *input_h,
+ const CT scale,
+ OT *output_c_h,
+ CT *amax_h,
+ IT *dbias_h,
+ const size_t N,
+ const size_t H) {
+ CT amax = 0.;
+
+ std::vector acc_dbias(H, 0.);
+
+ for (size_t i = 0; i < N; i++) {
+ for (size_t j = 0; j < H; j++) {
+ CT elt = static_cast(input_h[i * H + j]);
+
+ // update amax
+ amax = std::abs(elt) > amax ? std::abs(elt) : amax;
+
+ output_c_h[i * H + j] = static_cast(scale * elt);
+
+ // dbias
+ acc_dbias[j] += elt;
+ }
+ }
+
+ *amax_h = amax;
+
+ for (size_t i = 0; i < H; i++) {
+ dbias_h[i] = static_cast(acc_dbias[i]);
+ }
+}
+
+template
+void performTest(const std::vector& shape) {
+ using namespace test;
+ using CType = fp32;
+
+ DType itype = TypeInfo::dtype;
+ DType otype = TypeInfo::dtype;
+
+ const size_t N = first_dimension(shape);
+ const size_t H = last_dimension(shape);
+
+ Tensor input("input", shape, itype);
+
+ Tensor output_c("output_c", shape, otype);
+ // dbias has the same data type with "output grad"
+ Tensor dbias("dbias", {H}, itype);
+
+ fillUniform(&input);
+ setRandomScale(&output_c);
+
+ std::unique_ptr ref_output_c = std::make_unique(N*H);
+ std::unique_ptr ref_output_dbias = std::make_unique(H);
+
+ CType ref_amax;
+ compute_ref_cast_dbias(input.rowwise_cpu_dptr(),
+ output_c.scale(),
+ ref_output_c.get(),
+ &ref_amax,
+ ref_output_dbias.get(),
+ N, H);
+
+ Tensor workspace;
+
+ nvte_quantize_dbias(input.data(),
+ output_c.data(),
+ dbias.data(),
+ workspace.data(),
+ 0);
+
+ workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
+
+ nvte_quantize_dbias(input.data(),
+ output_c.data(),
+ dbias.data(),
+ workspace.data(),
+ 0);
+
+ cudaDeviceSynchronize();
+ auto err = cudaGetLastError();
+ ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
+
+ if (isFp8Type(otype)) {
+ auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
+ compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
+ float ref_scale_inv = 1.f / output_c.scale();
+ compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
+ }
+ auto [atol, rtol] = getTolerances(otype);
+ compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol);
+
+ auto [atol_dbias, rtol_dbias] = getTolerances(itype);
+ rtol_dbias *= 4;
+ compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
+}
+
+std::vector> test_cases = {
+ {128, 128},
+ {256, 256},
+ {768, 1024},
+ {256, 65536},
+ {2048, 12288},
+ {65536, 128},
+ {65536, 160},
+ {16384, 1616},
+ {1, 128},
+ {1, 1296},
+ {1, 16},
+ {5, 160},
+ {5, 4, 3, 160},
+ {217, 256},
+};
+
+} // namespace;
+
+
+class CastDBiasTestSuite : public ::testing::TestWithParam>> {};
+
+TEST_P(CastDBiasTestSuite, TestCastDBias) {
+ using namespace transformer_engine;
+ using namespace test;
+ // Skip tests for pre-Blackwell architectures
+ if (getDeviceComputeCapability() < blackwellComputeCapability) {
+ GTEST_SKIP();
+ }
+
+ const DType input_type = std::get<0>(GetParam());
+ const DType output_type = std::get<1>(GetParam());
+ const auto size = std::get<2>(GetParam());
+
+ TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
+ TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
+ performTest(size);
+ );
+ );
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ OperatorTest,
+ CastDBiasTestSuite,
+ ::testing::Combine(
+ ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
+ ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
+ ::testing::ValuesIn(test_cases)),
+ [](const testing::TestParamInfo& info) {
+ std::string name = test::typeName(std::get<0>(info.param)) + "X" +
+ test::typeName(std::get<1>(info.param));
+ const auto& shape = std::get<2>(info.param);
+ for ( const auto& s: shape) {
+ name += "X" + std::to_string(s);
+ }
+ return name;
+ });
diff --git a/tests/cpp/operator/test_cast_dbias_dgelu.cu b/tests/cpp/operator/test_cast_dbias_dgelu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..20ea5c31f181e252fd413f69bb1a021cb10fae05
--- /dev/null
+++ b/tests/cpp/operator/test_cast_dbias_dgelu.cu
@@ -0,0 +1,196 @@
+/*************************************************************************
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+#include
+#include "../test_common.h"
+
+using namespace transformer_engine;
+using namespace test;
+
+namespace {
+
+template
+void compute_ref_cast_dbias_dgelu(const IT *input,
+ const IT *grad,
+ const CT scale,
+ OT *output_c,
+ CT *amax_h,
+ IT *dbias,
+ const size_t N,
+ const size_t H) {
+ CT amax = 0.;
+
+ std::vector acc_dbias(H, 0.);
+
+ for (size_t i = 0; i < N; i++) {
+ for (size_t j = 0; j < H; j++) {
+ CT in_elt = static_cast(input[i * H + j]);
+ const CT in_grad = static_cast(grad[i * H + j]);
+
+ const CT elt = in_grad * static_cast(dgelu(static_cast(in_elt)));
+ const CT elt_abs = std::abs(elt);
+
+ // update amax
+ if (elt_abs > amax) {
+ amax = elt_abs;
+ }
+
+ output_c[i * H + j] = static_cast(scale * elt);
+
+ // dbias
+ acc_dbias[j] += elt;
+ }
+ }
+
+ *amax_h = amax;
+
+ for (size_t i = 0; i < H; i++) {
+ dbias[i] = static_cast(acc_dbias[i]);
+ }
+}
+
+template
+void performTest(const std::vector& shape) {
+ using namespace test;
+ using CType = fp32;
+
+ DType itype = TypeInfo::dtype;
+ DType otype = TypeInfo::dtype;
+
+ const size_t N = first_dimension(shape);
+ const size_t H = last_dimension(shape);
+
+ Tensor input("input", shape, itype);
+ Tensor grad("grad", shape, itype);
+
+ Tensor output_c("output_c", shape, otype);
+ // dbias has the same data type with "output grad"
+ Tensor dbias("dbias", {H}, itype);
+
+ fillUniform(&input);
+ fillUniform(&grad);
+ setRandomScale(&output_c);
+
+ std::unique_ptr ref_output_c = std::make_unique(N*H);
+ std::unique_ptr ref_output_dbias = std::make_unique(H);
+
+ CType ref_amax;
+ compute_ref_cast_dbias_dgelu(input.rowwise_cpu_dptr(),
+ grad.rowwise_cpu_dptr(),
+ output_c.scale(),
+ ref_output_c.get(),
+ &ref_amax,
+ ref_output_dbias.get(),
+ N, H);
+
+ Tensor workspace;
+
+ nvte_quantize_dbias_dgelu(grad.data(),
+ input.data(),
+ output_c.data(),
+ dbias.data(),
+ workspace.data(),
+ 0);
+
+ workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
+
+
+ nvte_quantize_dbias_dgelu(grad.data(),
+ input.data(),
+ output_c.data(),
+ dbias.data(),
+ workspace.data(),
+ 0);
+
+ cudaDeviceSynchronize();
+ auto err = cudaGetLastError();
+ ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
+
+ if (isFp8Type(otype)) {
+ auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
+ compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
+ float ref_scale_inv = 1.f / output_c.scale();
+ compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
+ }
+
+ auto [atol, rtol] = getTolerances(otype);
+ compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol);
+
+ auto [atol_dbias, rtol_dbias] = getTolerances(itype);
+ rtol_dbias *= 4;
+ compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
+}
+
+std::vector> test_cases = {
+ {128, 128},
+ {256, 256},
+ {768, 1024},
+ {256, 65536},
+ {2048, 12288},
+ {65536, 128},
+ {65536, 160},
+ {16384, 1616},
+ {1, 128},
+ {1, 1296},
+ {1, 16},
+ {5, 160},
+ {5, 4, 3, 160},
+ {217, 256},
+};
+
+} // namespace;
+
+
+class CastDBiasDGeluTestSuite : public ::testing::TestWithParam>> {};
+
+TEST_P(CastDBiasDGeluTestSuite, TestCastDBiasDgelu) {
+ using namespace transformer_engine;
+ using namespace test;
+ // Skip tests for pre-Blackwell architectures
+ if (getDeviceComputeCapability() < blackwellComputeCapability) {
+ GTEST_SKIP();
+ }
+
+ const DType input_type = std::get<0>(GetParam());
+ const DType output_type = std::get<1>(GetParam());
+ const auto size = std::get<2>(GetParam());
+
+ TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
+ TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
+ performTest(size);
+ );
+ );
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ OperatorTest,
+ CastDBiasDGeluTestSuite,
+ ::testing::Combine(
+ ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
+ ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
+ ::testing::ValuesIn(test_cases)),
+ [](const testing::TestParamInfo& info) {
+ std::string name = test::typeName(std::get<0>(info.param)) + "X" +
+ test::typeName(std::get<1>(info.param));
+ const auto& shape = std::get<2>(info.param);
+ for ( const auto& s: shape) {
+ name += "X" + std::to_string(s);
+ }
+ return name;
+ });
diff --git a/tests/cpp/operator/test_cast_gated_swiglu.cu b/tests/cpp/operator/test_cast_gated_swiglu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..35ae4621062e6d68df25a1cda418560d6b34b5f4
--- /dev/null
+++ b/tests/cpp/operator/test_cast_gated_swiglu.cu
@@ -0,0 +1,165 @@
+/*************************************************************************
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include "../test_common.h"
+
+using namespace transformer_engine;
+using namespace test;
+
+namespace {
+
+template
+void compute_ref_cast_dgated_swiglu(const IType * const grad,
+ const IType * const input,
+ const float scale,
+ OType * const output,
+ float * const amax_ptr,
+ const size_t rows,
+ const size_t cols) {
+ float amax = 0;
+ const size_t stride = cols * 2;
+
+ #pragma omp parallel for reduction(max: amax) proc_bind(spread)
+ for (size_t i = 0; i < rows; i++) {
+ for (size_t j = 0; j < cols; j++) {
+ float grad_elt = static_cast(grad[i * cols + j]);
+ float silu_elt = static_cast(input[i * stride + j]);
+ float gate_elt = static_cast(input[i * stride + cols + j]);
+
+ float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt;
+ float after_dgate = grad_elt * silu(silu_elt);
+
+ if (abs(after_dsilu) > amax) { amax = abs(after_dsilu); }
+ if (abs(after_dgate) > amax) { amax = abs(after_dgate); }
+
+ output[i * stride + j] = static_cast(scale * after_dsilu);
+ output[i * stride + cols + j] = static_cast(scale * after_dgate);
+ }
+ }
+
+ *amax_ptr = amax;
+}
+
+template
+void performTest(const std::vector& shape) {
+ using namespace test;
+
+ DType itype = TypeInfo::dtype;
+ DType otype = TypeInfo::dtype;
+
+ std::vector input_shape = shape;
+ input_shape[input_shape.size() - 1] *= 2;
+
+ const size_t input_size = product(input_shape);
+
+ const size_t rows = first_dimension(shape);
+ const size_t cols = last_dimension(shape);
+
+ Tensor grad("grad", shape, itype);
+ Tensor input("input", input_shape, itype);
+ Tensor output_c("output_c", input_shape, otype);
+
+ fillUniform(&grad);
+ fillUniform(&input);
+ setRandomScale(&output_c);
+
+ std::unique_ptr ref_output_c = std::make_unique(input_size);
+
+ nvte_dswiglu(grad.data(), input.data(), output_c.data(), 0);
+ cudaDeviceSynchronize();
+
+ auto err = cudaGetLastError();
+ ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
+
+ float ref_amax;
+ compute_ref_cast_dgated_swiglu(grad.rowwise_cpu_dptr(),
+ input.rowwise_cpu_dptr(),
+ output_c.scale(),
+ ref_output_c.get(),
+ &ref_amax,
+ rows,
+ cols);
+
+ if (isFp8Type(otype)) {
+ auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
+ compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
+ float ref_scale_inv = 1.f / output_c.scale();
+ compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
+ }
+
+ auto [atol, rtol] = getTolerances(otype);
+ compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol);
+}
+
+std::vector> test_cases = {
+ {128, 128},
+ {256, 256},
+ {768, 1024},
+ {256, 65536},
+ {2048, 12288},
+ {65536, 128},
+ {217, 256},
+ {1296},
+ {5, 4, 3, 160},
+};
+
+} // namespace
+
+class CastSwiGLUTestSuite
+ : public ::testing::TestWithParam>> {};
+
+TEST_P(CastSwiGLUTestSuite, TestCastSwiGLU) {
+ using namespace transformer_engine;
+ using namespace test;
+ // Skip tests for pre-Blackwell architectures
+ if (getDeviceComputeCapability() < blackwellComputeCapability) {
+ GTEST_SKIP();
+ }
+
+ const DType input_type = std::get<0>(GetParam());
+ const DType output_type = std::get<1>(GetParam());
+ const auto size = std::get<2>(GetParam());
+
+ if (size.back() % 32 != 0) {
+ GTEST_SKIP();
+ }
+
+ TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
+ input_type, InputType,
+ TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
+ output_type, OutputType, performTest(size);););
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ OperatorTest, CastSwiGLUTestSuite,
+ ::testing::Combine(
+ ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
+ ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
+ ::testing::ValuesIn(test_cases)),
+ [](const testing::TestParamInfo &info) {
+ std::string name = test::typeName(std::get<0>(info.param)) + "X" +
+ test::typeName(std::get<1>(info.param));
+ const auto& shape = std::get<2>(info.param);
+ for ( const auto& s: shape) {
+ name += "X" + std::to_string(s);
+ }
+ return name;
+ });
diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu
new file mode 100644
index 0000000000000000000000000000000000000000..cb38a5a74afc1dc9aeeee64d5fa6147f190b7956
--- /dev/null
+++ b/tests/cpp/operator/test_cast_mxfp8.cu
@@ -0,0 +1,636 @@
+/*************************************************************************
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include "../test_common.h"
+#include "transformer_engine/transformer_engine.h"
+
+using namespace transformer_engine;
+using namespace test;
+
+namespace {
+
+enum ProcessingMethod {
+ CAST_ONLY,
+ CAST_DBIAS,
+ CAST_DBIAS_DACT,
+ CAST_DACT,
+ CAST_ACT
+};
+
+enum ActivationType {
+ Identity,
+ GeLU,
+ SiLU,
+ ReLU,
+ QGeLU,
+ SReLU
+};
+
+template
+void scale_block(const ProcessingMethod processing_method,
+ const InputType* input,
+ const InputType* grad,
+ OutputType* output_c,
+ float* dbias,
+ fp8e8m0* output_scales,
+ const size_t scale_idx,
+ const size_t i_min,
+ const size_t i_max,
+ const size_t j_min,
+ const size_t j_max,
+ const size_t cols) {
+ float amax = 0.0f;
+
+ // Find the absolute maximum value in the block
+ for (size_t i = i_min; i < i_max; ++i) {
+ for (size_t j = j_min; j < j_max; ++j) {
+ const size_t idx = i * cols + j;
+ float elt = static_cast(input[idx]);
+ if (processing_method == ProcessingMethod::CAST_DBIAS) {
+ // grad is the input
+ elt = static_cast(grad[idx]);
+ }
+ if (processing_method != ProcessingMethod::CAST_ONLY
+ && processing_method != ProcessingMethod::CAST_DBIAS) {
+ elt = OP(elt);
+ }
+ if (processing_method == ProcessingMethod::CAST_DACT ||
+ processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
+ elt *= static_cast(grad[idx]);
+ }
+ dbias[j] += elt;
+ if (isinf(elt) || isnan(elt)) {
+ continue;
+ }
+ amax = std::max(amax, std::abs(elt));
+ }
+ }
+
+ const fp8e8m0 biased_exponent = float_to_e8m0(amax * Quantized_Limits::max_reciprocal());
+ const float scale_reciprocal = exp2f_rcp(biased_exponent);
+ output_scales[scale_idx] = biased_exponent;
+
+ // Quantize elements in the block
+ for (size_t i = i_min; i < i_max; ++i) {
+ for (size_t j = j_min; j < j_max; ++j) {
+ const size_t idx = i * cols + j;
+ float elt = static_cast(input[idx]);
+ if (processing_method == ProcessingMethod::CAST_DBIAS) {
+ // grad is the input
+ elt = static_cast(grad[idx]);
+ }
+ if (processing_method != ProcessingMethod::CAST_ONLY
+ && processing_method != ProcessingMethod::CAST_DBIAS) {
+ elt = OP(elt);
+ }
+ if (processing_method == ProcessingMethod::CAST_DACT ||
+ processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
+ elt *= static_cast(grad[idx]);
+ }
+ output_c[idx] = static_cast(elt * scale_reciprocal);
+ }
+ }
+}
+
+template
+void compute_ref_x1(const ProcessingMethod processing_method,
+ const InputType* input,
+ const InputType* grad,
+ OutputType* output_c,
+ fp8e8m0* output_scales,
+ InputType* output_dbias,
+ const size_t rows,
+ const size_t cols,
+ const size_t block_size_Y,
+ const size_t block_size_X,
+ const size_t scales_stride)
+{
+ std::vector output_dbias_fp32(cols, 0);
+
+ const size_t blocks_Y = (rows + block_size_Y - 1) / block_size_Y;
+ const size_t blocks_X = (cols + block_size_X - 1) / block_size_X;
+
+ for (size_t ii = 0; ii < blocks_Y; ++ii) {
+ const size_t i_min = ii * block_size_Y;
+ const size_t i_max = std::min((ii + 1) * block_size_Y, rows);
+ for (size_t jj = 0; jj < blocks_X; ++jj) {
+ const size_t j_min = jj * block_size_X;
+ const size_t j_max = std::min((jj + 1) * block_size_X, cols);
+ const size_t scale_idx = ii * scales_stride + jj;
+ scale_block(
+ processing_method, input, grad, output_c, output_dbias_fp32.data(),
+ output_scales, scale_idx, i_min, i_max, j_min, j_max, cols);
+ }
+ }
+ for (size_t j = 0; j < cols; ++j) {
+ output_dbias[j] = static_cast(output_dbias_fp32[j]);
+ }
+}
+
+template
+void compute_ref_x2(const ProcessingMethod processing_method,
+ const InputType* input,
+ const InputType* grad,
+ OutputType* output_rowwise,
+ OutputType* output_colwise,
+ fp8e8m0* scales_rowwise,
+ fp8e8m0* scales_colwise,
+ InputType* output_dbias,
+ const size_t rows,
+ const size_t cols,
+ const size_t block_size_Y,
+ const size_t block_size_X,
+ const size_t scales_stride_rowwise,
+ const size_t scales_stride_colwise) {
+ compute_ref_x1(
+ processing_method, input, grad, output_rowwise, scales_rowwise, output_dbias,
+ rows, cols, 1, block_size_X, scales_stride_rowwise);
+ compute_ref_x1(
+ processing_method, input, grad, output_colwise, scales_colwise, output_dbias,
+ rows, cols, block_size_Y, 1, scales_stride_colwise);
+}
+
+/**
+ * Scaling along single dimension (either rows or columns)
+ * Produces one set of output data and the corresponding data of the fused operation (dbias):
+ * 1) Scaled rows + row-wise scaling factors
+ * OR
+ * 2) Scaled columns + column-wise scaling factors
+ */
+
+template
+void performTest_x1(const ProcessingMethod processing_method,
+ const std::vector& shape,
+ const bool rowwise,
+ const bool colwise,
+ InputsFillCase fill_case) {
+ using namespace test;
+ using EncodingType = fp32;
+ DType itype = TypeInfo::dtype;
+ DType otype = TypeInfo::dtype;
+
+ const size_t rows = first_dimension(shape);
+ const size_t cols = last_dimension(shape);
+
+ if (shape.size() < 2 && colwise) {
+ GTEST_SKIP();
+ }
+
+ const size_t block_size_rows = rowwise ? 1 : 32;
+ const size_t block_size_cols = colwise ? 1 : 32;
+
+ const std::array scale_dims = get_scale_tensor_dims(rows, cols, block_size_rows,
+ block_size_cols);
+
+ const size_t unpadded_blocks_Y = scale_dims[0];
+ const size_t unpadded_blocks_X = scale_dims[1];
+ const size_t blocks_Y = scale_dims[2];
+ const size_t blocks_X = scale_dims[3];
+ const size_t scales_stride = blocks_X;
+
+ Tensor input("input", shape, itype);
+ Tensor grad("grad", shape, itype);
+ Tensor output_c("output_c", shape, otype, rowwise, colwise, NVTE_MXFP8_1D_SCALING);
+ Tensor output_dbias("output_dbias", { cols }, itype);
+
+ std::unique_ptr ref_output_c = std::make_unique(rows * cols);
+ std::unique_ptr ref_output_dbias = std::make_unique(cols);
+ std::unique_ptr ref_output_scales = std::make_unique(blocks_Y * blocks_X);
+
+ fillCase(&input, fill_case);
+ fillUniform(&grad);
+
+ Tensor workspace;
+ switch (processing_method) {
+ case ProcessingMethod::CAST_ONLY: {
+ nvte_quantize(input.data(), output_c.data(), 0);
+ break;
+ }
+ case ProcessingMethod::CAST_DBIAS: {
+ nvte_quantize_dbias(grad.data(),
+ output_c.data(),
+ output_dbias.data(),
+ workspace.data(),
+ 0);
+ workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
+
+ nvte_quantize_dbias(grad.data(),
+ output_c.data(),
+ output_dbias.data(),
+ workspace.data(),
+ 0);
+ break;
+ }
+ case ProcessingMethod::CAST_DBIAS_DACT: {
+ nvte_quantize_dbias_dgelu(grad.data(),
+ input.data(),
+ output_c.data(),
+ output_dbias.data(),
+ workspace.data(),
+ 0);
+ workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
+
+ nvte_quantize_dbias_dgelu(grad.data(),
+ input.data(),
+ output_c.data(),
+ output_dbias.data(),
+ workspace.data(),
+ 0);
+ break;
+ }
+ case ProcessingMethod::CAST_DACT: {
+ nvte_dgelu(grad.data(), input.data(), output_c.data(), 0);
+ break;
+ }
+ case ProcessingMethod::CAST_ACT: {
+ nvte_gelu(input.data(), output_c.data(), 0);
+ break;
+ }
+ }
+
+ cudaDeviceSynchronize();
+ auto err = cudaGetLastError();
+ ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
+
+ compute_ref_x1(processing_method,
+ input.rowwise_cpu_dptr(),
+ grad.rowwise_cpu_dptr(),
+ ref_output_c.get(),
+ ref_output_scales.get(),
+ ref_output_dbias.get(),
+ rows,
+ cols,
+ block_size_rows,
+ block_size_cols,
+ scales_stride);
+
+ auto [atol, rtol] = getTolerances(otype);
+ compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol);
+
+ const uint8_t * const gpu_scales_ptr = rowwise
+ ? output_c.rowwise_cpu_scale_inv_ptr()
+ : output_c.columnwise_cpu_scale_inv_ptr();
+
+ compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
+ unpadded_blocks_Y, unpadded_blocks_X, scales_stride);
+
+ if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
+ auto [atol_dbias, rtol_dbias] = getTolerances(itype);
+ if (itype == DType::kFloat32) {
+ atol_dbias = 1e-4;
+ rtol_dbias *= sqrt(static_cast(rows)) ;
+ } else {
+ rtol_dbias *= 4;
+ }
+ compareResults("output_dbias", output_dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
+ }
+}
+
+/**
+ * Scaling along both dimensions (rows and columns)
+ * Produces two sets of scaled output data and the corresponding data of the fused operation (dbias):
+ * 1) Scaled rows + row-wise scaling factors
+ * AND
+ * 2) Scaled columns + column-wise scaling factors
+ */
+template
+void performTest_x2(const ProcessingMethod processing_method,
+ const std::vector& shape,
+ const size_t block_size_rows,
+ const size_t block_size_cols,
+ InputsFillCase fill_case) {
+ using namespace test;
+ using EncodingType = fp32;
+ DType itype = TypeInfo::dtype;
+ DType otype = TypeInfo::dtype;
+
+ if (shape.size() < 2) {
+ GTEST_SKIP();
+ }
+
+ const size_t rows = first_dimension(shape);
+ const size_t cols = last_dimension(shape);
+
+ const std::array scale_dims_rowwise = get_scale_tensor_dims(rows, cols, 1, 32);
+ const std::array scale_dims_colwise = get_scale_tensor_dims(rows, cols, 32, 1);
+
+ const size_t unpadded_blocks_Y_rowwise = scale_dims_rowwise[0];
+ const size_t unpadded_blocks_X_rowwise = scale_dims_rowwise[1];
+ const size_t blocks_Y_rowwise = scale_dims_rowwise[2];
+ const size_t blocks_X_rowwise = scale_dims_rowwise[3];
+ const size_t scales_stride_rowwise = blocks_X_rowwise;
+
+ const size_t unpadded_blocks_Y_colwise = scale_dims_colwise[0];
+ const size_t unpadded_blocks_X_colwise = scale_dims_colwise[1];
+ const size_t blocks_Y_colwise = scale_dims_colwise[2];
+ const size_t blocks_X_colwise = scale_dims_colwise[3];
+ const size_t scales_stride_colwise = blocks_X_colwise;
+
+ Tensor input("input", shape, itype);
+ Tensor grad("grad", shape, itype);
+ Tensor output("output", shape, otype, true, true, NVTE_MXFP8_1D_SCALING);
+ Tensor output_dbias("output_dbias", { cols }, itype);
+
+ std::unique_ptr ref_output_c_rowwise = std::make_unique(rows * cols);
+ std::unique_ptr ref_output_c_colwise = std::make_unique(rows * cols);
+ std::unique_ptr ref_scales_rowwise = std::make_unique(blocks_Y_rowwise * blocks_X_rowwise);
+ std::unique_ptr ref_scales_colwise = std::make_unique(blocks_Y_colwise * blocks_X_colwise);
+ std::unique_ptr ref_output_dbias = std::make_unique(cols);
+
+ fillCase(&input, fill_case);
+ fillUniform(&grad);
+
+ Tensor workspace;
+ switch (processing_method) {
+ case ProcessingMethod::CAST_ONLY: {
+ nvte_quantize(input.data(), output.data(), 0);
+ break;
+ }
+ case ProcessingMethod::CAST_DBIAS: {
+ nvte_quantize_dbias(grad.data(),
+ output.data(),
+ output_dbias.data(),
+ workspace.data(),
+ 0);
+ workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
+
+ nvte_quantize_dbias(grad.data(),
+ output.data(),
+ output_dbias.data(),
+ workspace.data(),
+ 0);
+ break;
+ }
+ case ProcessingMethod::CAST_DBIAS_DACT: {
+ nvte_quantize_dbias_dgelu(grad.data(),
+ input.data(),
+ output.data(),
+ output_dbias.data(),
+ workspace.data(),
+ 0);
+ workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
+
+ nvte_quantize_dbias_dgelu(grad.data(),
+ input.data(),
+ output.data(),
+ output_dbias.data(),
+ workspace.data(),
+ 0);
+ break;
+ }
+ case ProcessingMethod::CAST_DACT: {
+ nvte_dgelu(grad.data(), input.data(), output.data(), 0);
+ break;
+ }
+ case ProcessingMethod::CAST_ACT: {
+ nvte_gelu(input.data(), output.data(), 0);
+ break;
+ }
+ }
+
+ cudaDeviceSynchronize();
+ auto err = cudaGetLastError();
+ ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
+
+ compute_ref_x2(processing_method,
+ input.rowwise_cpu_dptr(),
+ grad.rowwise_cpu_dptr(),
+ ref_output_c_rowwise.get(),
+ ref_output_c_colwise.get(),
+ ref_scales_rowwise.get(),
+ ref_scales_colwise.get(),
+ ref_output_dbias.get(),
+ rows,
+ cols,
+ block_size_rows,
+ block_size_cols,
+ scales_stride_rowwise,
+ scales_stride_colwise);
+
+ auto [atol, rtol] = getTolerances(otype);
+ compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol);
+ compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol);
+ compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(),
+ ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
+ unpadded_blocks_X_rowwise, scales_stride_rowwise);
+ compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(),
+ ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
+ unpadded_blocks_X_colwise, scales_stride_colwise);
+
+ if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
+ auto [atol_dbias, rtol_dbias] = getTolerances(itype);
+ if (itype == DType::kFloat32) {
+ atol_dbias = 1e-4;
+ rtol_dbias *= sqrt(static_cast(rows)) ;
+ } else {
+ rtol_dbias *= 4;
+ }
+ compareResults("output_dbias", output_dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
+ }
+}
+
+std::vector> matrix_sizes = {
+ {1, 16},
+ {16, 48},
+ {65, 96},
+ {128, 128},
+ {256, 256},
+ {993, 512},
+ {256, 65536},
+ {2048, 6144},
+ {16384, 128},
+ {32768, 160},
+ {4096, 1632},
+ {1024},
+ {8, 32, 1024},
+ {16, 8, 4, 512},
+};
+
+std::vector> block_sizes = {
+ {1, 32},
+ {32, 1},
+ {32, 32},
+};
+
+std::vector input_scenarios = {
+ InputsFillCase::uniform,
+ // InputsFillCase::zeros,
+ // InputsFillCase::zero_to_minNorm,
+ // InputsFillCase::minNorm_to_maxNorm,
+ // InputsFillCase::maxNorm_to_inf
+};
+
+std::vector processing_methods = {
+ ProcessingMethod::CAST_ONLY,
+ ProcessingMethod::CAST_DBIAS,
+ ProcessingMethod::CAST_DBIAS_DACT,
+ ProcessingMethod::CAST_DACT,
+ ProcessingMethod::CAST_ACT,
+};
+
+// Only GeLU activation tests are supported
+std::vector Activation_types = {
+ ActivationType::Identity,
+ ActivationType::GeLU,
+ // ActivationType::SiLU,
+ // ActivationType::ReLU,
+ // ActivationType::QGeLU,
+ // ActivationType::SReLU,
+};
+
+} // namespace
+
+class FusedCastMXFP8TestSuite : public ::testing::TestWithParam
+ ,
+ std::pair,
+ transformer_engine::DType,
+ transformer_engine::DType,
+ InputsFillCase>> {};
+
+#define DACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \
+switch (OP_FUNC_TYPE) { \
+ case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \
+ case ActivationType::GeLU: { constexpr auto OP = &dgelu; { __VA_ARGS__ } } break; \
+ case ActivationType::SiLU: { constexpr auto OP = &dsilu; { __VA_ARGS__ } } break; \
+ case ActivationType::ReLU: { constexpr auto OP = &drelu; { __VA_ARGS__ } } break; \
+ case ActivationType::QGeLU: { constexpr auto OP = &dqgelu; { __VA_ARGS__ } } break; \
+ case ActivationType::SReLU: { constexpr auto OP = &dsrelu; { __VA_ARGS__ } } break; \
+}
+
+#define ACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \
+switch (OP_FUNC_TYPE) { \
+ case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \
+ case ActivationType::GeLU: { constexpr auto OP = &gelu; { __VA_ARGS__ } } break; \
+ case ActivationType::SiLU: { constexpr auto OP = &silu; { __VA_ARGS__ } } break; \
+ case ActivationType::ReLU: { constexpr auto OP = &relu; { __VA_ARGS__ } } break; \
+ case ActivationType::QGeLU: { constexpr auto OP = &qgelu; { __VA_ARGS__ } } break; \
+ case ActivationType::SReLU: { constexpr auto OP = &srelu; { __VA_ARGS__ } } break; \
+}
+
+TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) {
+ // Skip tests for pre-Blackwell architectures
+ if (getDeviceComputeCapability() < blackwellComputeCapability) {
+ GTEST_SKIP();
+ }
+
+ using namespace transformer_engine;
+ using namespace test;
+
+ const ProcessingMethod processing_method = std::get<0>(GetParam());
+ const ActivationType Act_type = std::get<1>(GetParam());
+ const auto matrix_size = std::get<2>(GetParam());
+ const auto block_size = std::get<3>(GetParam());
+ const DType input_type = std::get<4>(GetParam());
+ const DType output_type = std::get<5>(GetParam());
+ const InputsFillCase fill_case = std::get<6>(GetParam());
+
+ // Skips non Act tests if the Activation type is not an identity
+ if ((processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS)
+ && Act_type != ActivationType::Identity) {
+ GTEST_SKIP();
+ }
+ // Skips Act tests if the Activation is an identity
+ if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT
+ || processing_method == ProcessingMethod::CAST_DACT
+ || processing_method == ProcessingMethod::CAST_ACT) && (Act_type == ActivationType::Identity)) {
+ GTEST_SKIP();
+ }
+
+ const bool rowwise = block_size.second != 1;
+ const bool colwise = block_size.first != 1;
+ if (processing_method == ProcessingMethod::CAST_ACT) {
+ // Forward activations
+ ACT_FUNC_SWITCH(Act_type, OP,
+ TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
+ TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
+ if (block_size.first == 1 || block_size.second == 1) {
+ performTest_x1(
+ processing_method, matrix_size,
+ rowwise, colwise, fill_case);
+ } else {
+ performTest_x2(
+ processing_method, matrix_size,
+ block_size.first, block_size.second, fill_case);
+ }
+ );
+ );
+ );
+ } else {
+ DACT_FUNC_SWITCH(Act_type, OP,
+ TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
+ TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
+ if (block_size.first == 1 || block_size.second == 1) {
+ performTest_x1(
+ processing_method, matrix_size,
+ rowwise, colwise, fill_case);
+ } else {
+ performTest_x2(
+ processing_method, matrix_size,
+ block_size.first, block_size.second, fill_case);
+ }
+ );
+ );
+ );
+ }
+}
+
+std::string to_string(const ProcessingMethod method) {
+ switch (method) {
+ case ProcessingMethod::CAST_ONLY: return "CAST_ONLY";
+ case ProcessingMethod::CAST_DBIAS: return "CAST_DBIAS";
+ case ProcessingMethod::CAST_DBIAS_DACT: return "CAST_DBIAS_DACT";
+ case ProcessingMethod::CAST_DACT: return "CAST_DACT";
+ case ProcessingMethod::CAST_ACT: return "CAST_ACT";
+ default: return "";
+ }
+}
+
+std::string to_string(const ActivationType Act_type) {
+ switch (Act_type) {
+ case ActivationType::Identity: return "Identity";
+ case ActivationType::GeLU: return "GeLU";
+ case ActivationType::SiLU: return "SiLU";
+ case ActivationType::ReLU: return "ReLU";
+ case ActivationType::QGeLU: return "QGeLU";
+ case ActivationType::SReLU: return "SReLU";
+ default: return "";
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ OperatorTest,
+ FusedCastMXFP8TestSuite,
+ ::testing::Combine(
+ ::testing::ValuesIn(processing_methods),
+ ::testing::ValuesIn(Activation_types),
+ ::testing::ValuesIn(matrix_sizes),
+ ::testing::ValuesIn(block_sizes),
+ ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
+ ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
+ ::testing::ValuesIn(input_scenarios)),
+ [](const testing::TestParamInfo& info) {
+ std::string name = to_string(std::get<0>(info.param)) + "X" +
+ to_string(std::get<1>(info.param));
+ const auto& shape = std::get<2>(info.param);
+ for ( const auto& s: shape) {
+ name += "X" + std::to_string(s);
+ }
+ name += "X" + std::to_string(std::get<3>(info.param).first) +
+ "X" + std::to_string(std::get<3>(info.param).second) +
+ "X" + test::typeName(std::get<4>(info.param)) +
+ "X" + test::typeName(std::get<5>(info.param)) +
+ "X" + test::caseName(std::get<6>(info.param));
+ return name;
+ });
diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..6acbdefeabfd2cbbb6ae8c3144d6a37beab5a865
--- /dev/null
+++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu
@@ -0,0 +1,470 @@
+/*************************************************************************
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include
+#include
+#include
+#include
+
+#include
+#include "../test_common.h"
+#include "transformer_engine/transformer_engine.h"
+
+using namespace transformer_engine;
+using namespace test;
+
+namespace {
+
+template
+void scale_block(const IType* grad,
+ const IType* input,
+ OType* output,
+ fp8e8m0* output_scales,
+ const size_t scale_idx,
+ const size_t scale_idx_gate,
+ float& thread_amax,
+ const size_t i_min,
+ const size_t i_max,
+ const size_t j_min,
+ const size_t j_max,
+ const size_t cols) {
+
+ float block_amax = 0.0f;
+ float block_amax_gate = 0.0f;
+ const size_t stride = cols * 2;
+
+ // Find the absolute maximum value in the block
+ for (size_t i = i_min; i < i_max; ++i) {
+ for (size_t j = j_min; j < j_max; ++j) {
+ float silu_elt = static_cast(input[i * stride + j]);
+ float gate_elt = static_cast(input[i * stride + cols + j]);
+ float gated_amax_act = 0;
+ float gated_amax_gate = 0;
+
+ if constexpr (IS_DGATED) {
+ const float grad_elt = static_cast(grad[i * cols + j]);
+ const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt;
+ const float after_dgate = silu(silu_elt) * grad_elt;
+ gated_amax_act = abs(after_dsilu);
+ gated_amax_gate = abs(after_dgate);
+ } else {
+ const float after_silu = silu(silu_elt) * gate_elt;
+ gated_amax_act = abs(after_silu);
+ }
+
+ if (gated_amax_act > block_amax) { block_amax = gated_amax_act; }
+ if (gated_amax_gate > block_amax_gate) { block_amax_gate = gated_amax_gate; }
+ }
+ }
+
+ const fp8e8m0 biased_exponent = float_to_e8m0(block_amax *
+ Quantized_Limits::max_reciprocal());
+ const float scale_reciprocal = exp2f_rcp(biased_exponent);
+ output_scales[scale_idx] = biased_exponent;
+ float scale_reciprocal_gate = 1;
+ if constexpr (IS_DGATED) {
+ const fp8e8m0 biased_exponent = float_to_e8m0(block_amax_gate *
+ Quantized_Limits::max_reciprocal());
+ scale_reciprocal_gate = exp2f_rcp(biased_exponent);
+ output_scales[scale_idx_gate] = biased_exponent;
+ }
+
+
+ // Quantize elements in the block
+ for (size_t i = i_min; i < i_max; ++i) {
+ for (size_t j = j_min; j < j_max; ++j) {
+ float silu_elt = static_cast(input[i * stride + j]);
+ float gate_elt = static_cast(input[i * stride + cols + j]);
+
+ if constexpr (IS_DGATED) {
+ const float grad_elt = static_cast(grad[i * cols + j]);
+ const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt;
+ const float after_dgate = silu(silu_elt) * grad_elt;
+ output[i * stride + j] = static_cast(after_dsilu * scale_reciprocal);
+ output[i * stride + cols + j] = static_cast(after_dgate *
+ scale_reciprocal_gate);
+ } else {
+ const float after_silu = silu(silu_elt) * gate_elt;
+ output[i * cols + j] = static_cast(after_silu * scale_reciprocal);
+ }
+
+ }
+ }
+ thread_amax = std::max(thread_amax, block_amax);
+ thread_amax = std::max(thread_amax, block_amax_gate);
+}
+
+template
+void compute_ref_x1(const IType* grad,
+ const IType* input,
+ OType* output,
+ fp8e8m0* output_scales,
+ float& ref_amax,
+ const size_t rows,
+ const size_t cols,
+ const size_t block_size_Y,
+ const size_t block_size_X,
+ const size_t scales_stride) {
+ const size_t tile_size_Y = std::max(32lu, block_size_Y);
+ const size_t tile_size_X = std::max(64lu, block_size_X);
+ const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y;
+ const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X;
+ const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y;
+ const size_t blocks_per_tile_X = tile_size_X / block_size_X;
+
+ float amax = 0;
+ #pragma omp parallel reduction(max: amax) proc_bind(spread)
+ {
+ float thread_amax = 0;
+ #pragma omp for schedule(static)
+ for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) {
+ const size_t tile_Y = t / tiles_num_X;
+ const size_t tile_X = t % tiles_num_X;
+ const size_t tile_offset_Y = tile_Y * tile_size_Y;
+ const size_t tile_offset_X = tile_X * tile_size_X;
+
+ for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) {
+ const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii;
+ const size_t block_offset_Y = ii * block_size_Y;
+ const size_t i_min = tile_offset_Y + block_offset_Y;
+ if (i_min >= rows) continue;
+ const size_t i_max = std::min(i_min + block_size_Y, rows);
+
+ for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) {
+ const size_t block_idx_X = tile_X * blocks_per_tile_X + jj;
+ const size_t block_offset_X = jj * block_size_X;
+ const size_t j_min = tile_offset_X + block_offset_X;
+ if (j_min >= cols) continue;
+ const size_t j_max = std::min(j_min + block_size_X, cols);
+
+ const size_t mx_scale_idx = block_idx_Y * scales_stride + block_idx_X;
+ const size_t mx_scale_idx_gate = block_idx_Y * scales_stride + block_idx_X +
+ cols / block_size_X;
+ scale_block(
+ grad, input, output, output_scales, mx_scale_idx, mx_scale_idx_gate,
+ thread_amax, i_min, i_max, j_min, j_max, cols);
+ }
+ }
+ }
+ if (thread_amax > amax) {
+ amax = thread_amax;
+ }
+ }
+ ref_amax = amax;
+}
+
+template
+void compute_ref_x2(const IType* grad,
+ const IType* input,
+ OType* output_rowwise,
+ OType* output_colwise,
+ fp8e8m0* scales_rowwise,
+ fp8e8m0* scales_colwise,
+ float& ref_amax,
+ const size_t rows,
+ const size_t cols,
+ const size_t block_size_Y,
+ const size_t block_size_X,
+ const size_t scales_stride_rowwise,
+ const size_t scales_stride_colwise) {
+ compute_ref_x1(
+ grad, input, output_rowwise, scales_rowwise, ref_amax, rows, cols, 1, block_size_X, scales_stride_rowwise);
+ compute_ref_x1(
+ grad, input, output_colwise, scales_colwise, ref_amax, rows, cols, block_size_Y, 1, scales_stride_colwise);
+}
+
+/**
+ * Scaling along single dimension (either rows or columns)
+ * Produces one set of output data and the corresponding data of the fused operation (dbias):
+ * 1) Scaled rows + row-wise scaling factors
+ * OR
+ * 2) Scaled columns + column-wise scaling factors
+ */
+template
+void performTest_x1(const size_t rows,
+ const size_t cols,
+ const size_t block_size_rows,
+ const size_t block_size_cols,
+ InputsFillCase fill_case) {
+ using namespace test;
+ using EncodingType = fp32;
+ DType itype = TypeInfo::dtype;
+ DType otype = TypeInfo::dtype;
+
+ const bool rowwise = (block_size_rows == 1) && (block_size_cols == 32);
+ const bool colwise = (block_size_rows == 32) && (block_size_cols == 1);
+ NVTE_CHECK(rowwise || colwise);
+
+ // std::cout << "unpadded_blocks_Y: " << unpadded_blocks_Y << std::endl;
+ // std::cout << "unpadded_blocks_X: " << unpadded_blocks_X << std::endl;
+ // std::cout << "blocks_Y: " << blocks_Y << std::endl;
+ // std::cout << "blocks_X: " << blocks_X << std::endl;
+ // std::cout << "scales_stride: " << scales_stride << std::endl;
+
+ Tensor grad("grad", { rows, cols }, itype);
+ Tensor input("input", { rows, cols * 2 }, itype);
+
+ const size_t output_cols = (IS_DGATED ? 2 : 1) * cols;
+
+ const std::array scale_dims = get_scale_tensor_dims(rows, output_cols, block_size_rows,
+ block_size_cols);
+
+ const size_t unpadded_blocks_Y = scale_dims[0];
+ const size_t unpadded_blocks_X = scale_dims[1];
+ const size_t blocks_Y = scale_dims[2];
+ const size_t blocks_X = scale_dims[3];
+ const size_t scales_stride = blocks_X;
+
+ Tensor output("output", std::vector{ rows, output_cols }, otype,
+ rowwise, colwise, NVTE_MXFP8_1D_SCALING);
+
+ std::unique_ptr ref_output = std::make_unique(rows * output_cols);
+ std::unique_ptr ref_output_scales = std::make_unique(blocks_Y * blocks_X);
+
+ for (size_t i = 0; i < blocks_Y * blocks_X; ++i) {
+ ref_output_scales[i] = 0;
+ }
+
+ // fillCase(&grad, fill_case);
+ if constexpr (IS_DGATED) {
+ fillUniform(&grad);
+ }
+ fillUniform(&input);
+
+ if constexpr (IS_DGATED) {
+ nvte_dswiglu(grad.data(), input.data(), output.data(), 0);
+ } else {
+ nvte_swiglu(input.data(), output.data(), 0);
+ }
+ cudaDeviceSynchronize();
+
+ auto err = cudaGetLastError();
+ ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
+
+ float ref_amax = 0;
+ compute_ref_x1