Commit 5b6ef054 authored by yuguo
Parents: 76060570 a7eeb28b
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
: "${TE_PATH:=/opt/transformerengine}"
# Ensure the 'wheel' CLI (unpack/pack) is available.
pip3 install wheel
cd $TE_PATH
# Remove previously installed TE packages so the freshly built wheels are
# the ones actually under test.
pip3 uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-jax
VERSION=`cat $TE_PATH/build_tools/VERSION.txt`
WHL_BASE="transformer_engine-${VERSION}"
# Core wheel.
NVTE_RELEASE_BUILD=1 python3 setup.py bdist_wheel
# Rename the wheel from 'transformer-engine' to the CUDA-12-specific name
# 'transformer-engine-cu12' by patching the dist-info metadata and repacking.
wheel unpack dist/*
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
wheel pack ${WHL_BASE}
# Replace the original wheel with the repacked (renamed) one.
rm dist/*.whl
mv *.whl dist/
# Metapackage wheel ('transformer_engine') that depends on the cu12 wheel.
NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python3 setup.py bdist_wheel
# JAX framework extension is distributed as an sdist.
cd transformer_engine/jax
NVTE_RELEASE_BUILD=1 python3 setup.py sdist
pip3 install dist/*
cd $TE_PATH
pip3 install dist/*.whl --no-deps
# Smoke test: importing TE from JAX must succeed.
python3 $TE_PATH/tests/jax/test_sanity_import.py
{
"initial_year": 2022,
"copyright": "Copyright (c) <YEAR>, NVIDIA CORPORATION & AFFILIATES. All rights reserved.",
"license": "See LICENSE for license information.",
"exclude": ["3rdparty",
"Dockerfile",
"Dockerfile.base",
"Dockerfile.qa",
"Dockerfile.devel",
"Dockerfile.docs",
"docker-build.sh",
".png",
".ipynb",
"docs/Makefile",
"layout.html",
"LICENSE",
"VERSION",
"Doxyfile",
"pylintrc",
".json",
".md",
".txt"
],
"exclude_copyright": [],
"copyright_only": false
}
#!/usr/bin/env python3
# coding: utf-8
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import sys
import json
import datetime
# Require a target path argument. BUG FIX: previously the usage message was
# printed but execution continued, crashing with an IndexError on the next
# line; now we exit with a non-zero status instead.
if len(sys.argv) < 2:
    print("Usage: python3 copyright_checker.py <path>")
    sys.exit(1)
path = sys.argv[1]
# The checker configuration lives next to this script.
config_path = os.path.dirname(os.path.realpath(__file__)) + "/config.json"
class bcolors:
    """ANSI escape sequences used to colorize checker output on a terminal."""

    OKGREEN = "\033[92m"  # bright green: check passed
    WARNING = "\033[93m"  # bright yellow: non-fatal warning
    FAIL = "\033[91m"     # bright red: check failed
    ENDC = "\033[0m"      # reset attributes
def print_ok(msg):
    """Print *msg* to stdout in green."""
    colored = f"{bcolors.OKGREEN}{msg}{bcolors.ENDC}"
    print(colored)
def print_fail(msg):
    """Print *msg* to stdout in red."""
    colored = f"{bcolors.FAIL}{msg}{bcolors.ENDC}"
    print(colored)
def print_warn(msg):
    """Print *msg* to stdout in yellow."""
    colored = f"{bcolors.WARNING}{msg}{bcolors.ENDC}"
    print(colored)
# Load the checker configuration (copyright template, license text,
# exclusion lists) from config.json next to this script.
with open(config_path, "r") as f:
    c = json.load(f)
current_year = datetime.date.today().year
# Build the expected year range, e.g. "2022-2025" ("2022" during the
# initial year itself).
if c["initial_year"] == current_year:
    year_string = str(current_year)
else:
    year_string = str(c["initial_year"]) + "-" + str(current_year)
copyright_string = c["copyright"].replace("<YEAR>", year_string)
# The license text may span multiple lines; it is matched line by line.
license = c["license"].split("\n")
excludes = c["exclude"]
root_path = os.path.abspath(path)
copyright_only = c["copyright_only"]
exclude_copyright = c["exclude_copyright"]
has_gitignore = os.path.exists(root_path + "/.gitignore")
def strip_star_slash(s):
    """Normalize a .gitignore pattern: drop a leading '*' and a trailing '/'."""
    result = s[1:] if s.startswith("*") else s
    if result.endswith("/"):
        result = result[:-1]
    return result
# Anything git ignores is also excluded from the license check.
if has_gitignore:
    with open(root_path + "/.gitignore", "r") as f:
        for line in f.readlines():
            excludes.append(strip_star_slash(line.strip()))
def get_file_type(path):
    """Classify a file by its extension.

    Returns one of the type keys below ("c", "py", ...) or "unknown" when
    the extension is not recognized (including files without a dot).
    """
    ext_to_type = {
        "c": ["c", "cpp", "cu", "h", "cuh"],
        "py": ["py"],
        "rst": ["rst"],
        "txt": ["txt"],
        "cfg": ["cfg"],
        "sh": ["sh"],
        "md": ["md"],
    }
    suffix = path.split(".")[-1]
    for file_type, extensions in ext_to_type.items():
        if suffix in extensions:
            return file_type
    return "unknown"
# Global pass/fail flag for the whole run.
success = True


def check_file(path):
    """Check one file for the expected copyright and license lines.

    The copyright string may appear anywhere in the first N lines; unless
    `copyright_only` is configured, every line of the license text must then
    be found (the first license line within the next N lines, subsequent
    lines consecutively). Sets the module-level `success` flag to False on
    any miss.
    """
    global success
    N = 10  # how many leading lines may precede each expected line
    ftype = get_file_type(path)
    if ftype == "unknown":
        print_warn("Unknown filetype")
        return
    check_copyright = True
    for e in exclude_copyright:
        if path.endswith(e):
            check_copyright = False
    with open(path, "r") as f:
        # BUG FIX: files matched by `exclude_copyright` previously still
        # ended with copyright_found == False and were reported as failures;
        # treat them as satisfying the copyright requirement instead.
        copyright_found = not check_copyright
        license_found = True
        try:
            if check_copyright:
                for _ in range(N):
                    line = f.readline()
                    if line.find(copyright_string) != -1:
                        copyright_found = True
                        break
            if not copyright_only:
                first_license_line = True
                for l in license:
                    if first_license_line:
                        # may skip some lines before the license starts
                        first_license_line = False
                        for _ in range(N):
                            line = f.readline()
                            if line.find(l) != -1:
                                break
                    else:
                        line = f.readline()
                        if line.find(l) == -1:
                            license_found = False
                            break
        except Exception:
            # Unreadable content (e.g. binary data, bad encoding) counts as
            # "not found" rather than crashing the whole run. Was a bare
            # `except:`, which also swallowed KeyboardInterrupt/SystemExit.
            pass
        finally:
            if not copyright_found:
                print_fail("No copyright found!")
                success = False
            if not license_found:
                print_fail("No license found!")
                success = False
            if copyright_found and license_found:
                print_ok("OK")
# Walk the tree, pruning excluded/hidden entries, and check every file.
for root, dirs, files in os.walk(root_path):
    print(f"Entering {root}")
    # Hidden files and directories are implicitly excluded.
    hidden = [d for d in dirs if d.startswith(".")] + [f for f in files if f.startswith(".")]
    all_excludes = excludes + hidden
    to_remove = []
    for d in dirs:
        d_path = root + "/" + d
        for e in all_excludes:
            if d_path.endswith(e):
                to_remove.append(d)
    for f in files:
        f_path = root + "/" + f
        for e in all_excludes:
            if f_path.endswith(e):
                to_remove.append(f)
    # Prune dirs in place so os.walk does not descend into excluded trees.
    for d in to_remove:
        if d in dirs:
            dirs.remove(d)
        if d in files:
            files.remove(d)
    for filename in files:
        # BUG FIX: the message was an f-string with no placeholder and never
        # showed which file was being checked.
        print(f"Checking {filename}")
        check_file(os.path.abspath(root + "/" + filename))

if not success:
    raise Exception("Some copyrights/licenses are missing!")
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
: "${TE_PATH:=/opt/transformerengine}"
# Check that every source file in the repo carries the copyright/license.
python3 $TE_PATH/qa/L0_license/copyright_checker.py $TE_PATH
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
: "${TE_PATH:=/opt/transformerengine}"
# Pinned linter versions for reproducible CI results.
pip3 install cpplint==1.6.0 pylint==3.3.1
# C++ linting (skipped when PYTHON_ONLY is set).
if [ -z "${PYTHON_ONLY}" ]
then
    cd $TE_PATH
    echo "Checking common API headers"
    python3 -m cpplint --root transformer_engine/common/include --recursive transformer_engine/common/include
    echo "Checking C++ files"
    python3 -m cpplint --recursive --exclude=transformer_engine/common/include --exclude=transformer_engine/build_tools/build transformer_engine/common
    python3 -m cpplint --recursive transformer_engine/pytorch
fi
# Python linting (skipped when CPP_ONLY is set).
if [ -z "${CPP_ONLY}" ]
then
    cd $TE_PATH
    echo "Checking Python files"
    python3 -m pylint --recursive=y transformer_engine/common transformer_engine/pytorch
fi
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Trace commands; deliberately no `set -e` -- failures are accumulated in
# FAIL so that every suite still runs.
set -x
: ${TE_PATH:=/opt/transformerengine}
pip3 install pytest==8.2.1
FAIL=0
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || FAIL=1
# Numerics/graphs tests disable JIT, torch.compile and nondeterministic
# algorithms so results are bit-reproducible.
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1
NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || FAIL=1
# Fused attention tests emit debug logs at INFO level for triage.
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1
# Non-zero exit if any suite above failed.
exit $FAIL
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
: "${TE_PATH:=/opt/transformerengine}"
# Ensure the 'wheel' CLI (unpack/pack) is available.
pip3 install wheel
cd $TE_PATH
# Remove previously installed TE packages so the freshly built wheels are
# the ones actually under test.
pip3 uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-torch
VERSION=`cat $TE_PATH/build_tools/VERSION.txt`
WHL_BASE="transformer_engine-${VERSION}"
# Core wheel.
NVTE_RELEASE_BUILD=1 python3 setup.py bdist_wheel
# Rename the wheel from 'transformer-engine' to the CUDA-12-specific name
# 'transformer-engine-cu12' by patching the dist-info metadata and repacking.
wheel unpack dist/*
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
wheel pack ${WHL_BASE}
# Replace the original wheel with the repacked (renamed) one.
rm dist/*.whl
mv *.whl dist/
# Metapackage wheel ('transformer_engine') that depends on the cu12 wheel.
NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python3 setup.py bdist_wheel
# PyTorch framework extension is distributed as an sdist.
cd transformer_engine/pytorch
NVTE_RELEASE_BUILD=1 python3 setup.py sdist
pip3 install dist/*
cd $TE_PATH
pip3 install dist/*.whl --no-deps
# Smoke test: importing TE from PyTorch must succeed.
python3 $TE_PATH/tests/pytorch/test_sanity_import.py
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -xe
: ${TE_PATH:=/opt/transformerengine}
# Run all JAX distributed test modules with the repo's pytest configuration.
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_*
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# No `set -e`: failures are accumulated in FAIL so every suite runs.
: ${TE_PATH:=/opt/transformerengine}
pip3 install pytest==8.2.1
FAIL=0
python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || FAIL=1
# python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || FAIL=1 ### TODO Debug UB support with te.Sequential
python3 -m pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || FAIL=1
# Non-zero exit if any suite above failed.
exit $FAIL
Megatron-LM
vocab.json
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
# Paths
: ${TE_PATH:=/opt/transformerengine}
: ${MCORE_PATH:=${TE_PATH}/qa/L1_pytorch_mcore_integration/Megatron-LM}
# Check whether FP8 is supported
# (strip non-digits from compute_cap, e.g. "9.0" -> 90; FP8 flag is only
# set for arch >= 89)
DEVICE_ARCH=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -n 1 | sed 's/[^0-9]//g')
if [[ ${DEVICE_ARCH} -ge 89 ]]; then
    WITH_FP8=1
fi
# Download Megatron-LM if needed (pinned to the core_r0.9.0 branch)
if [ ! -d "${MCORE_PATH}" ]; then
    pushd $(dirname ${MCORE_PATH})
    git clone -b core_r0.9.0 https://github.com/NVIDIA/Megatron-LM.git Megatron-LM
    popd
fi
# Create mock vocab: "<|endoftext|>" plus tokens "1".."4095" mapped to
# themselves, written as a single JSON object.
VOCAB_FILE=${TE_PATH}/qa/L1_pytorch_mcore_integration/vocab.json
printf "" > ${VOCAB_FILE}
printf "{" >> ${VOCAB_FILE}
printf "\"<|endoftext|>\": 0" >> ${VOCAB_FILE}
seq 1 4095 | awk '{ printf(", \"%d\": %d", $1, $1) }' >> ${VOCAB_FILE}
printf "}" >> ${VOCAB_FILE}
# Megatron-LM invocation: built as a multi-line string and flattened to a
# single line below; --fp8-format is only appended when WITH_FP8 was set.
COMMAND="
NVTE_TORCH_COMPILE=0
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0
NVTE_FLASH_ATTN=1
NVTE_FWD_LAYERNORM_SM_MARGIN=0
NVTE_BWD_LAYERNORM_SM_MARGIN=0
CUDA_DEVICE_MAX_CONNECTIONS=1
NVTE_BIAS_GELU_NVFUSION=0
NVTE_BIAS_DROPOUT_FUSION=0
python3
-m torch.distributed.launch
--use_env
--nnodes=1
--nproc_per_node=1
${MCORE_PATH}/pretrain_gpt.py
--tensor-model-parallel-size 1
--pipeline-model-parallel-size 1
--use-cpu-initialization
--num-layers 2
--hidden-size 128
--num-attention-heads 8
--seq-length 128
--max-position-embeddings 128
--micro-batch-size 1
--global-batch-size 8
--train-iters 10
--eval-iters 10
--lr 1e-4
--mock-data
--vocab-file ${VOCAB_FILE}
--merge-file ${TE_PATH}/qa/L1_pytorch_mcore_integration/merges.txt
--transformer-impl transformer_engine
${WITH_FP8:+--fp8-format hybrid}
"
COMMAND=$(echo "${COMMAND}" | tr '\n' ' ')
# Launch Megatron-LM
bash -c "${COMMAND}"
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Trace commands; no `set -e` so the exit code can be inspected below.
set -x
: ${THUNDER_PATH:=/opt/pytorch/lightning-thunder}
pip3 install pytest==8.1.1 pytest-benchmark==5.1.0
# Run Thunder's TE executor tests; $? is captured right after (the comment
# lines in between do not affect it).
python3 -m pytest -v -s ${THUNDER_PATH}/thunder/tests/test_transformer_engine_executor.py
# Check return code
# Note: Return code 5 is fine. Lightning tests are skipped on systems
# without FP8 support and Pytest returns 5 if no tests are run.
RC=$?
if [ ${RC} -eq 5 ]; then
    RC=0
fi
exit ${RC}
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
: ${TE_PATH:=/opt/transformerengine}
pip3 install pytest==8.2.1
# Limit parallel build jobs to avoid overwhelming system resources
export MAX_JOBS=4
# Iterate over Flash Attention versions
# Detect the GPU compute capability as a two-digit number (e.g. 90).
sm_arch=`python3 -c "import torch; sm = torch.cuda.get_device_capability(0); print(sm[0]*10+sm[1])"`
if [ $sm_arch -gt 90 ]
then
    # Arch newer than 90: only test the latest FA2 release.
    FA_versions=(2.7.3)
else
    FA_versions=(2.1.1 2.3.0 2.4.1 2.5.7 2.7.3 3.0.0b1)
fi
for fa_version in "${FA_versions[@]}"
do
    # Build Flash Attention
    # (lexicographic compare: versions below 3.0.0 install from PyPI, FA3
    # builds from the 'hopper' subdirectory of the source repo)
    if [ "${fa_version}" \< "3.0.0" ]
    then
        pip3 install flash-attn==${fa_version}
    else
        pip3 install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper"
        # The FA3 build does not ship the Python interface file; fetch it.
        python_path=`python3 -c "import site; print(site.getsitepackages()[0])"`
        mkdir -p $python_path/flashattn_hopper
        wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py
    fi
    # Run tests
    NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
done
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Utility file to run pre-commit hooks locally
# Usage: bash qa/format.sh
set -e
# Defaults to the current directory so it works from a repo checkout.
: "${TE_PATH:=.}"
cd $TE_PATH
pip3 install pre-commit
# Run every configured hook against the whole tree, not just staged files.
python3 -m pre_commit run --all-files
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Installation script."""
import os
import sys
import time
from pathlib import Path
from typing import List, Tuple
import setuptools
from wheel.bdist_wheel import bdist_wheel
from build_tools.build_ext import CMakeExtension, get_build_ext
from build_tools.te_version import te_version
from build_tools.utils import (
cuda_archs,
found_cmake,
found_ninja,
found_pybind11,
get_frameworks,
install_and_import,
remove_dups,
uninstall_te_wheel_packages,
)
frameworks = get_frameworks()
current_file_path = Path(__file__).parent.resolve()

# Default build_ext; overridden below when a framework toolchain is present.
from setuptools.command.build_ext import build_ext as BuildExtension

# Signal to build_tools helpers that the TE project itself is being built.
os.environ["NVTE_PROJECT_BUILDING"] = "1"

# Prefer the framework-specific build_ext: PyTorch's (which handles CUDA
# extensions) first, otherwise pybind11's for JAX builds.
if "pytorch" in frameworks:
    from torch.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
    install_and_import("pybind11[global]")
    from pybind11.setup_helpers import build_ext as BuildExtension

CMakeBuildExtension = get_build_ext(BuildExtension)
archs = cuda_archs()
class TimedBdist(bdist_wheel):
    """Helper class to measure build time"""

    def run(self):
        # Time the whole wheel build and report it for CI diagnostics.
        started = time.perf_counter()
        super().run()
        elapsed = time.perf_counter() - started
        print(f"Total time for bdist_wheel: {elapsed:.2f} seconds")
def setup_common_extension() -> CMakeExtension:
    """Setup CMake extension for common library"""
    cmake_flags = [f"-DCMAKE_CUDA_ARCHITECTURES={archs}"]

    # Optional features toggled through environment variables.
    if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
        assert (
            os.getenv("MPI_HOME") is not None
        ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
        cmake_flags.append("-DNVTE_UB_WITH_MPI=ON")
    if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
        cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")

    # Project directory root
    project_root = Path(__file__).resolve().parent
    return CMakeExtension(
        name="transformer_engine",
        cmake_path=project_root / Path("transformer_engine/common"),
        cmake_flags=cmake_flags,
    )
def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
    """Setup Python dependencies

    Returns dependencies for build, runtime, and testing.
    """
    # Common requirements
    setup_reqs: List[str] = []
    install_reqs: List[str] = ["pydantic", "importlib-metadata>=1.0", "packaging"]
    test_reqs: List[str] = ["pytest>=8.2.1"]

    # Build tools that may already be installed outside of Python packaging.
    for missing, requirement in (
        (not found_cmake(), "cmake>=3.21"),
        (not found_ninja(), "ninja"),
        (not found_pybind11(), "pybind11"),
    ):
        if missing:
            setup_reqs.append(requirement)

    # Framework-specific requirements (omitted from release wheels).
    if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
        if "pytorch" in frameworks:
            install_reqs.extend(["torch>=2.1"])
            # Blackwell is not supported as of Triton 3.2.0, need custom internal build
            # install_reqs.append("triton")
            test_reqs.extend(["numpy", "torchvision", "prettytable", "PyYAML"])
        if "jax" in frameworks:
            install_reqs.extend(["jax", "flax>=0.7.1"])
            # test_reqs.extend(["numpy", "praxis"])
            test_reqs.extend(["numpy"])

    return [remove_dups(reqs) for reqs in (setup_reqs, install_reqs, test_reqs)]
if __name__ == "__main__":
    __version__ = te_version()

    with open("README.rst", encoding="utf-8") as f:
        long_description = f.read()

    # Settings for building top level empty package for dependency management.
    if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))):
        assert bool(
            int(os.getenv("NVTE_RELEASE_BUILD", "0"))
        ), "NVTE_RELEASE_BUILD env must be set for metapackage build."
        ext_modules = []
        cmdclass = {}
        package_data = {}
        include_package_data = False
        setup_requires = []
        # BUG FIX: a stray trailing comma previously made this a 1-tuple
        # containing the requirements list instead of the list itself.
        install_requires = [f"transformer_engine_cu12=={__version__}"]
        extras_require = {
            "pytorch": [f"transformer_engine_torch=={__version__}"],
            "jax": [f"transformer_engine_jax=={__version__}"],
        }
    else:
        setup_requires, install_requires, test_requires = setup_requirements()
        ext_modules = [setup_common_extension()]
        cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}
        package_data = {"": ["VERSION.txt"]}
        include_package_data = True
        extras_require = {"test": test_requires}

        if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
            # Remove residual FW packages since compiling from source
            # results in a single binary with FW extensions included.
            uninstall_te_wheel_packages()
        if "pytorch" in frameworks:
            from build_tools.pytorch import setup_pytorch_extension

            ext_modules.append(
                setup_pytorch_extension(
                    "transformer_engine/pytorch/csrc",
                    current_file_path / "transformer_engine" / "pytorch" / "csrc",
                    current_file_path / "transformer_engine",
                )
            )
        if "jax" in frameworks:
            from build_tools.jax import setup_jax_extension

            ext_modules.append(
                setup_jax_extension(
                    "transformer_engine/jax/csrc",
                    current_file_path / "transformer_engine" / "jax" / "csrc",
                    current_file_path / "transformer_engine",
                )
            )

    # Configure package
    setuptools.setup(
        name="transformer_engine",
        version=__version__,
        packages=setuptools.find_packages(
            include=[
                "transformer_engine",
                "transformer_engine.*",
                "transformer_engine/build_tools",
            ],
        ),
        extras_require=extras_require,
        description="Transformer acceleration library",
        long_description=long_description,
        long_description_content_type="text/x-rst",
        ext_modules=ext_modules,
        # BUG FIX: a hard-coded dict was passed here before, silently
        # ignoring the `cmdclass` selected above (which is intentionally
        # empty for metapackage builds).
        cmdclass=cmdclass,
        python_requires=">=3.8, <3.13",
        classifiers=[
            "Programming Language :: Python :: 3.8",
            "Programming Language :: Python :: 3.9",
            "Programming Language :: Python :: 3.10",
            "Programming Language :: Python :: 3.11",
            "Programming Language :: Python :: 3.12",
        ],
        setup_requires=setup_requires,
        install_requires=install_requires,
        license_files=("LICENSE",),
        include_package_data=include_package_data,
        package_data=package_data,
    )
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Minimum version with first-class CUDA language support used below.
cmake_minimum_required(VERSION 3.18)
# Default CUDA architectures: the extra 100/120 targets require CUDA 12.8+.
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
  if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
    set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
  else ()
    set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90)
  endif()
endif()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
project(transformer_engine_tests LANGUAGES CUDA CXX)
# GoogleTest is vendored under 3rdparty.
add_subdirectory(../../3rdparty/googletest ${PROJECT_BINARY_DIR}/googletest)
enable_testing()
include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
# Locate the installed transformer_engine library via pip metadata unless
# the caller supplied TE_LIB_PATH explicitly.
if(NOT DEFINED TE_LIB_PATH)
  execute_process(COMMAND bash -c "pip3 show transformer-engine | grep Location | cut -d ' ' -f 2 | tr -d '\n'"
                  OUTPUT_VARIABLE TE_LIB_PATH)
endif()
find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/transformer_engine" ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED)
message(STATUS "Found transformer_engine library: ${TE_LIB}")
include_directories(../../transformer_engine/common/include)
include_directories(../../transformer_engine/common)
include_directories(${CMAKE_SOURCE_DIR})
find_package(CUDAToolkit REQUIRED)
# cuDNN imported targets come from the vendored cudnn-frontend helper.
include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)
add_subdirectory(operator)
add_subdirectory(util)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Unit-test binary for individual TE operators, one .cu file per operator
# plus shared helpers from ../test_common.cu.
add_executable(test_operator
               test_cast.cu
               test_cast_current_scaling.cu
               test_cast_dbias.cu
               test_cast_dbias_dgelu.cu
               test_cast_gated_swiglu.cu
               test_cast_mxfp8_gated_swiglu.cu
               test_qdq.cu
               test_cast_mxfp8.cu
               test_dequantize_mxfp8.cu
               test_transpose.cu
               test_cast_transpose.cu
               test_cast_transpose_current_scaling.cu
               test_cast_transpose_dbias.cu
               test_cast_transpose_dbias_dgelu.cu
               test_cast_transpose_dgeglu.cu
               test_act.cu
               test_normalization.cu
               test_normalization_mxfp8.cu
               test_multi_cast_transpose.cu
               test_multi_padding.cu
               test_causal_softmax.cu
               test_swizzle.cu
               ../test_common.cu)
# OpenMP is used by the CPU reference implementations in the tests.
find_package(OpenMP REQUIRED)
list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn)
target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS} OpenMP::OpenMP_CXX)
target_compile_options(test_operator PRIVATE -O2 -fopenmp)
include(GoogleTest)
# Generous discovery timeout: listing tests launches the CUDA binary.
gtest_discover_tests(test_operator DISCOVERY_TIMEOUT 600)
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstring>
#include <memory>
#include <iomanip>
#include <iostream>
#include <random>
#include <type_traits>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/activation.h>
#include "../test_common.h"
using namespace transformer_engine;
// CPU reference: apply the elementwise activation `act` to an N x H input,
// scale and cast the result to OT, and record the pre-scale absolute
// maximum in *amax_h.
template <float (*act)(const float), typename IT, typename OT, typename CT>
void compute_ref_act_cast(const IT *input_h,
                          OT *output_h,
                          const CT scale,
                          CT *amax_h,
                          const size_t N,
                          const size_t H) {
  CT amax = 0.;
  #pragma omp parallel for schedule(static) reduction(max: amax) proc_bind(spread)
  for (size_t i = 0; i < N; i++) {
    for (size_t j = 0; j < H; j++) {
      CT elt = static_cast<CT>(input_h[i * H + j]);
      elt = act(elt);
      // Output is the scaled activation; amax is taken before scaling.
      output_h[i * H + j] = static_cast<OT>(scale * elt);
      amax = std::abs(elt) > amax ? std::abs(elt) : amax;
    }
  }
  *amax_h = amax;
}
// CPU reference for the activation backward pass:
// output = grad * dact(input), computed elementwise in float.
template <float (*dact)(const float), typename IT, typename OT>
void compute_ref_dact_cast(const IT *input_h,
                           const IT *grad_h,
                           OT *output_h,
                           const size_t N,
                           const size_t H) {
  using CT = float;  // compute type
  #pragma omp parallel for schedule(static) proc_bind(spread)
  for (size_t i = 0; i < N; i++) {
    for (size_t j = 0; j < H; j++) {
      CT elt = static_cast<CT>(input_h[i * H + j]);
      elt = dact(elt);
      CT grad = static_cast<CT>(grad_h[i * H + j]);
      output_h[i * H + j] = static_cast<OT>(grad * elt);
    }
  }
}
// CPU reference for gated activations (GLU variants): each input row holds
// [activation half | gate half], each H wide. Computes
// out = scale * act(x) * gate and records the pre-scale absolute maximum
// in *amax_h.
template <float (*act)(const float), typename IT, typename OT, typename CT>
void compute_ref_glu_act_cast(const IT *input_h, OT *output_h, const CT scale, CT *amax_h,
                              const size_t N, const size_t H) {
  CT amax = 0.;
  // Input rows are 2*H wide. FIX: use size_t (was int) to avoid a
  // narrowing conversion / overflow for very large H.
  const size_t col = H * 2;
  #pragma omp parallel for schedule(static) reduction(max: amax) proc_bind(spread)
  for (size_t i = 0; i < N; i++) {
    for (size_t j = 0; j < H; j++) {
      CT gelu_elt = static_cast<CT>(input_h[i * col + j]);
      gelu_elt = act(gelu_elt);
      CT gate_elt = static_cast<CT>(input_h[i * col + H + j]);
      CT elt = gelu_elt * gate_elt;
      output_h[i * H + j] = static_cast<OT>(scale * elt);
      amax = std::abs(elt) > amax ? std::abs(elt) : amax;
    }
  }
  *amax_h = amax;
}
// CPU reference for the gated-activation backward pass. For input rows of
// layout [activation half | gate half]:
//   d(activation) = grad * dact(x) * gate   (first half of the output row)
//   d(gate)       = grad * act(x)           (second half of the output row)
template <float (*dact)(const float), float (*act)(const float),
          typename IT, typename OT>
void compute_ref_dglu_act_cast(const IT *input_h, const IT *grad_h, OT *output_h,
                               const size_t N, const size_t H) {
  // Input/output rows are 2*H wide. FIX: use size_t (was int) to avoid a
  // narrowing conversion / overflow for very large H.
  const size_t col = H * 2;
  using CT = float;  // compute type
  #pragma omp parallel for schedule(static) proc_bind(spread)
  for (size_t i = 0; i < N; i++) {
    for (size_t j = 0; j < H; j++) {
      CT grad = static_cast<CT>(grad_h[i * H + j]);
      CT gelu_elt = static_cast<CT>(input_h[i * col + j]);
      CT gate_elt = static_cast<CT>(input_h[i * col + H + j]);
      output_h[i * col + H + j] = static_cast<OT>(grad * act(gelu_elt));
      gelu_elt = dact(gelu_elt);
      CT elt = gelu_elt * gate_elt;
      output_h[i * col + j] = static_cast<OT>(grad * elt);
    }
  }
}
// End-to-end check of a plain activation: runs the nvte_* forward and
// backward kernels on the device and compares against the CPU references.
template <float (*ref_act)(const float),
          float (*ref_dact)(const float),
          void (*nvte_act)(const NVTETensor, NVTETensor, cudaStream_t),
          void (*nvte_dact)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t),
          typename IType, typename OType>
void performTest(const size_t N, const size_t H) {
  using namespace test;
  DType itype = TypeInfo<IType>::dtype;
  DType otype = TypeInfo<OType>::dtype;
  Tensor input("input", { N, H }, itype);
  Tensor output("output", { N, H }, otype);
  Tensor igrad("igrad", { N, H }, itype);
  Tensor ograd("ograd", { N, H }, itype);
  fillUniform(&input);
  fillUniform(&ograd);
  setRandomScale(&output);
  std::unique_ptr<OType[]> ref_output = std::make_unique<OType[]>(N*H);
  std::unique_ptr<IType[]> ref_igrad = std::make_unique<IType[]>(N*H);
  // Forward: device kernel on the default stream, then the CPU reference.
  nvte_act(input.data(), output.data(), 0);
  float ref_amax;
  compute_ref_act_cast<ref_act>(input.rowwise_cpu_dptr<IType>(), ref_output.get(),
                                output.scale(), &ref_amax, N, H);
  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
  // amax is only checked for FP8 output types.
  if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) {
    auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
    compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
  }
  auto [atol, rtol] = getTolerances(otype);
  compareResults("output_act", output, ref_output.get(), atol, rtol);
  // Backward: igrad = ograd * dact(input), device vs. CPU reference.
  nvte_dact(ograd.data(), input.data(), igrad.data(), 0);
  compute_ref_dact_cast<ref_dact>(input.rowwise_cpu_dptr<IType>(), ograd.rowwise_cpu_dptr<IType>(),
                                  ref_igrad.get(), N, H);
  cudaDeviceSynchronize();
  err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
  {
    // Inner scope: re-derives tolerances, shadowing the bindings above.
    auto [atol, rtol] = getTolerances(otype);
    compareResults("igrad_act", igrad, ref_igrad.get(), atol, rtol);
  }
}
// End-to-end check of a gated activation (GLU variant): input rows are
// 2*H wide ([activation half | gate half]); compares the nvte_* kernels
// against the CPU references.
template <float (*ref_act)(const float),
          float (*ref_dact)(const float),
          void (*nvte_act)(const NVTETensor, NVTETensor, cudaStream_t),
          void (*nvte_dact)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t),
          typename IType, typename OType>
void performTestGLU(const size_t N, const size_t H) {
  using namespace test;
  DType itype = TypeInfo<IType>::dtype;
  DType otype = TypeInfo<OType>::dtype;
  Tensor input("input", {N, H * 2}, itype);
  Tensor output("output", {N, H}, otype);
  Tensor igrad("igrad", { N, H * 2 }, itype);
  Tensor ograd("ograd", { N, H }, itype);
  fillUniform(&input);
  fillUniform(&ograd);
  setRandomScale(&output);
  std::unique_ptr<OType[]> ref_output = std::make_unique<OType[]>(N * H);
  std::unique_ptr<IType[]> ref_igrad = std::make_unique<IType[]>(2 * N * H);
  // Forward: device kernel on the default stream, then the CPU reference.
  nvte_act(input.data(), output.data(), 0);
  float ref_amax;
  compute_ref_glu_act_cast<ref_act>(input.rowwise_cpu_dptr<IType>(), ref_output.get(),
                                    output.scale(), &ref_amax, N, H);
  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
  // amax (and scale_inv for delayed tensor scaling) is only checked for
  // FP8 output types.
  if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) {
    auto [atol, rtol] = getTolerances(DType::kFloat32);
    compareResults("amax", output.amax(), ref_amax, atol, rtol);
    if (output.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
      const float ref_scale = 1.f / output.scale();
      compareResults("scale_inv", *output.rowwise_cpu_scale_inv_ptr<float>(), ref_scale, atol, rtol);
    }
  }
  auto [atol, rtol] = getTolerances(otype);
  compareResults("output_gelu", output, ref_output.get(), atol, rtol);
  // Backward: igrad rows (2*H wide) hold both d(activation) and d(gate).
  nvte_dact(ograd.data(), input.data(), igrad.data(), 0);
  compute_ref_dglu_act_cast<ref_dact, ref_act>(input.rowwise_cpu_dptr<IType>(), ograd.rowwise_cpu_dptr<IType>(),
                                               ref_igrad.get(), N, H);
  cudaDeviceSynchronize();
  err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
  {
    // Inner scope: re-derives tolerances, shadowing the bindings above.
    auto [atol, rtol] = getTolerances(otype);
    compareResults("igrad_act", igrad, ref_igrad.get(), atol, rtol);
  }
}
// Test fixture parametrized over (input dtype, output dtype, (N, H) shape).
class ActTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
                                                                transformer_engine::DType,
                                                                std::pair<size_t, size_t>>> {};
TEST_P(ActTestSuite, TestGELU) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<gelu, dgelu, nvte_gelu, nvte_dgelu,
InputType, OutputType>(size.first, size.second);
);
);
}
TEST_P(ActTestSuite, TestSILU) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<silu, dsilu, nvte_silu, nvte_dsilu,
InputType, OutputType>(size.first, size.second);
);
);
}
TEST_P(ActTestSuite, TestRELU) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<relu, drelu, nvte_relu, nvte_drelu,
InputType, OutputType>(size.first, size.second);
);
);
}
TEST_P(ActTestSuite, TestQGELU) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<qgelu, dqgelu, nvte_qgelu, nvte_dqgelu,
InputType, OutputType>(size.first, size.second);
);
);
}
TEST_P(ActTestSuite, TestSRELU) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<srelu, dsrelu, nvte_srelu, nvte_dsrelu,
InputType, OutputType>(size.first, size.second);
);
);
}
TEST_P(ActTestSuite, TestGeGLU) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
output_type, OutputType,
performTestGLU<gelu, dgelu, nvte_geglu, nvte_dgeglu, InputType,
OutputType>(size.first, size.second);););
}
TEST_P(ActTestSuite, TestReGLU) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
output_type, OutputType,
performTestGLU<relu, drelu, nvte_reglu, nvte_dreglu, InputType,
OutputType>(size.first, size.second);););
}
TEST_P(ActTestSuite, TestSwiGLU) {
  using namespace transformer_engine;
  using namespace test;
  // Unpack the (input dtype, output dtype, matrix size) test parameter.
  const DType in_dtype = std::get<0>(GetParam());
  const DType out_dtype = std::get<1>(GetParam());
  const auto dims = std::get<2>(GetParam());
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(in_dtype, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(out_dtype, OutputType,
      performTestGLU<silu, dsilu, nvte_swiglu, nvte_dswiglu,
                     InputType, OutputType>(dims.first, dims.second);
    );
  );
}
TEST_P(ActTestSuite, TestQGeGLU) {
  using namespace transformer_engine;
  using namespace test;
  // Unpack the (input dtype, output dtype, matrix size) test parameter.
  const DType in_dtype = std::get<0>(GetParam());
  const DType out_dtype = std::get<1>(GetParam());
  const auto dims = std::get<2>(GetParam());
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(in_dtype, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(out_dtype, OutputType,
      performTestGLU<qgelu, dqgelu, nvte_qgeglu, nvte_dqgeglu,
                     InputType, OutputType>(dims.first, dims.second);
    );
  );
}
TEST_P(ActTestSuite, TestSReGLU) {
  using namespace transformer_engine;
  using namespace test;
  // Unpack the (input dtype, output dtype, matrix size) test parameter.
  const DType in_dtype = std::get<0>(GetParam());
  const DType out_dtype = std::get<1>(GetParam());
  const auto dims = std::get<2>(GetParam());
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(in_dtype, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(out_dtype, OutputType,
      performTestGLU<srelu, dsrelu, nvte_sreglu, nvte_dsreglu,
                     InputType, OutputType>(dims.first, dims.second);
    );
  );
}
namespace {
// Activation test shapes: large transformer-like matrices plus small and
// odd-sized cases (257x259, 128x129) to exercise vectorized and tail paths.
std::vector<std::pair<size_t, size_t>> act_test_cases = {
    {2048, 12288}, {768, 2816}, {256, 65536}, {65536, 128},
    {256, 256},    {257, 259},  {128, 129}};
}  // namespace
// Instantiate the activation tests over every combination of:
//   input dtype  : fp32 / bf16 / fp16,
//   output dtype : all floating-point types supported by the test harness,
//   shape        : the act_test_cases list above.
// The name generator produces readable test names of the form
// <InType>X<OutType>X<rows>X<cols>.
INSTANTIATE_TEST_SUITE_P(
    OperatorTest,
    ActTestSuite,
    ::testing::Combine(
        ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
        ::testing::ValuesIn(test::all_fp_types),
        ::testing::ValuesIn(act_test_cases)),
    [](const testing::TestParamInfo<ActTestSuite::ParamType>& info) {
      std::string name = test::typeName(std::get<0>(info.param)) + "X" +
                         test::typeName(std::get<1>(info.param)) + "X" +
                         std::to_string(std::get<2>(info.param).first) + "X" +
                         std::to_string(std::get<2>(info.param).second);
      return name;
    });
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
// Reference (CPU) implementation of the quantizing cast: scales each element
// into the output type and records the absolute maximum of the input.
//
//   data:     input buffer of `size` elements.
//   output_c: receives OutputType(scale * data[i]) for each element.
//   amax:     receives max_i |data[i]| (0 for an empty buffer).
//   scale:    multiplicative scaling factor applied before the cast.
template <typename InputType, typename OutputType>
void compute_ref(const InputType *data, OutputType *output_c,
                 const size_t size,
                 float *amax, float scale) {
  using compute_t = float;
  // Start at 0: amax is a max of absolute values, so 0 is the correct
  // identity. The previous initializer (-1e100) is outside float's range,
  // which is undefined behavior on conversion and left a bogus amax for
  // empty inputs.
  compute_t current_max = 0.f;
  for (size_t i = 0; i < size; ++i) {
    const compute_t current = static_cast<compute_t>(data[i]);
    current_max = fmaxf(current_max, fabsf(current));
    output_c[i] = OutputType(scale * current);
  }
  *amax = current_max;
}
// delayed tensor scaling test
template <typename InputType, typename OutputType>
void performTest(const std::vector<size_t>& shape) {
using namespace test;
const size_t full_size = product(shape);
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
Tensor input("input", shape, itype);
Tensor output_c("output_c", shape, otype);
std::unique_ptr<OutputType[]> ref_output_c = std::make_unique<OutputType[]>(full_size);
fillUniform(&input);
setRandomScale(&output_c);
nvte_quantize(input.data(), output_c.data(), 0);
float ref_amax;
compute_ref<InputType, OutputType>(input.rowwise_cpu_dptr<InputType>(), ref_output_c.get(),
full_size, &ref_amax, output_c.scale());
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol);
}
// Shapes exercised by the cast test: 1-D vectors, large 2-D matrices,
// non-power-of-two widths (160, 1616, 1296), a 4-D tensor, and odd row
// counts to cover both vectorized and tail code paths.
std::vector<std::vector<size_t>> test_cases = {
  {16},
  {16000},
  {128, 128},
  {256, 256},
  {768, 1024},
  {256, 65536},
  {2048, 12288},
  {65536, 128},
  {65536, 160},
  {16384, 1616},
  {1, 128},
  {1, 1296},
  {1, 16},
  {5, 160},
  {5, 4, 3, 160},
  {217, 256},
};
} // namespace
// Parameterized cast test fixture over (input dtype, output dtype, shape).
class CastTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
                                                                 transformer_engine::DType,
                                                                 std::vector<size_t>>> {};
TEST_P(CastTestSuite, TestCast) {
  using namespace transformer_engine;
  using namespace test;
  // Unpack the (input dtype, output dtype, shape) test parameter.
  const DType in_dtype = std::get<0>(GetParam());
  const DType out_dtype = std::get<1>(GetParam());
  const auto shape = std::get<2>(GetParam());
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
      in_dtype, InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
          out_dtype, OutputType,
          // delayed tensor scaling
          performTest<InputType, OutputType>(shape);););
}
// Instantiate the cast tests over every combination of:
//   input dtype  : fp32 / bf16 / fp16,
//   output dtype : the two FP8 formats (E4M3, E5M2),
//   shape        : the test_cases list above.
// The name generator appends each shape dimension, producing names like
// <InType>X<OutType>X<d0>X<d1>...
INSTANTIATE_TEST_SUITE_P(
  OperatorTest,
  CastTestSuite,
  ::testing::Combine(
      ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
      ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
      ::testing::ValuesIn(test_cases)),
  [](const testing::TestParamInfo<CastTestSuite::ParamType>& info) {
    std::string name = test::typeName(std::get<0>(info.param)) + "X" +
                       test::typeName(std::get<1>(info.param));
    const auto& shape = std::get<2>(info.param);
    for ( const auto& s: shape) {
      name += "X" + std::to_string(s);
    }
    return name;
  });
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment