"qa/git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "1bbeab1c563e7b8551804cb5af0847d277e22951"
Commit 2b05e121 authored by yuguo's avatar yuguo
Browse files

Merge commit 'a69692ac' of...

Merge commit 'a69692ac' of https://github.com/NVIDIA/TransformerEngine
parents 0fd441c2 a69692ac
@@ -24,10 +24,10 @@ pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Fa
 # Make encoder tests to have run-to-run deterministic to have the stable CI results
 export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
-python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_multigpu_encoder.xml $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py"
-wait
-python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
-wait
+# python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_multigpu_encoder.xml $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py"
+# wait
+# python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
+# wait
 . $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh"
 if [ $RET -ne 0 ]; then
...
@@ -27,9 +27,6 @@ mkdir -p "$XML_LOG_DIR"
 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_helper.py || test_fail "tests/jax/*not_distributed_*"
-# Test without custom calls
-NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_custom_call_compute.xml $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py"
 pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"
@@ -37,6 +34,9 @@ pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Fa
 # Make encoder tests to have run-to-run deterministic to have the stable CI results
 export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
+# Test without custom calls
+export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
+NVTE_JAX_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls"
 if [ $RET -ne 0 ]; then
     echo "Error: some sub-tests failed: $FAILED_CASES"
...
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
: ${TE_PATH:=/opt/transformerengine}
: ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
: ${NVTE_TEST_NVINSPECT_CONFIGS_DIR:=$TE_PATH/tests/pytorch/debug/test_configs/}
# Config with the dummy feature which prevents nvinspect from being disabled.
# Nvinspect will be disabled if no feature is active.
: ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml}
FAIL=0
pip install pytest==8.2.1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
# standard numerics tests with initialized debug
NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1
exit $FAIL
@@ -20,5 +20,5 @@ if [ -z "${CPP_ONLY}" ]
 then
     cd $TE_PATH
     echo "Checking Python files"
-    python3 -m pylint --recursive=y transformer_engine/common transformer_engine/pytorch
+    python3 -m pylint --recursive=y transformer_engine/common transformer_engine/pytorch transformer_engine/debug
 fi
@@ -47,6 +47,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entro
 NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
 if [ "$RET" -ne 0 ]; then
     echo "Error in the following test cases:$FAILED_CASES"
...
@@ -20,6 +20,7 @@ FAILED_CASES=""
 : ${XML_LOG_DIR:=/logs}
 mkdir -p "$XML_LOG_DIR"

 pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py"
@@ -30,6 +31,19 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_use
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
+
+# debug tests
+# Config with the dummy feature which prevents nvinspect from being disabled.
+# Nvinspect will be disabled if no feature is active.
+: ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml}
+: ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
+
+pytest -v -s $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py"
+# standard numerics tests with initialized debug
+NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py"
+
 if [ "$RET" -ne 0 ]; then
     echo "Error in the following test cases:$FAILED_CASES"
     exit 1
...
@@ -25,18 +25,18 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
 : ${XML_LOG_DIR:=/logs}
 mkdir -p "$XML_LOG_DIR"

-python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"
+NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"
-# Test without custom calls
-NVTE_JAX_UNITTEST_LEVEL="L2" NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_custom_call_compute.xml $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py"
 pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
-python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"
+NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"
 pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements"
 # Make encoder tests to have run-to-run deterministic to have the stable CI results
 export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
-python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
+NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
+# Test without custom calls
+export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
+NVTE_JAX_CUSTOM_CALLS_RE="" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
 if [ $RET -ne 0 ]; then
     echo "Error: some sub-tests failed: $FAILED_CASES"
...
@@ -19,11 +19,7 @@ from build_tools.te_version import te_version
 from build_tools.utils import (
     rocm_build,
     cuda_archs,
-    found_cmake,
-    found_ninja,
-    found_pybind11,
     get_frameworks,
-    install_and_import,
     remove_dups,
 )
@@ -38,7 +34,6 @@ os.environ["NVTE_PROJECT_BUILDING"] = "1"
 if "pytorch" in frameworks:
     from torch.utils.cpp_extension import BuildExtension
 elif "jax" in frameworks:
-    install_and_import("pybind11[global]")
     from pybind11.setup_helpers import build_ext as BuildExtension
@@ -87,6 +82,11 @@ def setup_common_extension() -> CMakeExtension:
     if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
         cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")
+
+    # Add custom CMake arguments from environment variable
+    nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
+    if nvte_cmake_extra_args:
+        cmake_flags.extend(nvte_cmake_extra_args.split())
     # Project directory root
     root_path = Path(__file__).resolve().parent
     if rocm_build():
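A note on the new build option (annotation, not part of the commit): NVTE_CMAKE_EXTRA_ARGS is read from the environment, whitespace-split, and appended to the CMake flags, so arbitrary CMake arguments can be passed through at build time. A minimal sketch of the behavior, with illustrative flag values only:

import os

# Illustrative values only; any valid CMake arguments could be supplied here.
os.environ["NVTE_CMAKE_EXTRA_ARGS"] = (
    "-DCMAKE_BUILD_TYPE=RelWithDebInfo -DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON"
)

cmake_flags = []
extra = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
if extra:
    # Same whitespace-split as the setup.py hunk above.
    cmake_flags.extend(extra.split())
print(cmake_flags)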
@@ -102,22 +102,13 @@ def setup_common_extension() -> CMakeExtension:
     )

-def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
+def setup_requirements() -> Tuple[List[str], List[str]]:
     """Setup Python dependencies
-    Returns dependencies for build, runtime, and testing.
+    Returns dependencies for runtime and testing.
     """
     # Common requirements
-    setup_reqs: List[str] = [
-        "nvidia-cuda-runtime-cu12",
-        "nvidia-cublas-cu12",
-        "nvidia-cudnn-cu12",
-        "nvidia-cuda-cccl-cu12",
-        "nvidia-cuda-nvcc-cu12",
-        "nvidia-nvtx-cu12",
-        "nvidia-cuda-nvrtc-cu12",
-    ]
     install_reqs: List[str] = [
         "pydantic",
         "importlib-metadata>=1.0",
@@ -125,32 +116,20 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
     ]
     test_reqs: List[str] = ["pytest>=8.2.1"]
-    # Requirements that may be installed outside of Python
-    if not found_cmake():
-        setup_reqs.append("cmake>=3.21")
-    if not found_ninja():
-        setup_reqs.append("ninja")
-    if not found_pybind11():
-        setup_reqs.append("pybind11")
     # Framework-specific requirements
     if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
         if "pytorch" in frameworks:
-            setup_reqs.extend(["torch>=2.1"])
-            install_reqs.extend(["torch>=2.1"])
-            # install_reqs.append(
-            #     "nvdlfw-inspect @"
-            #     " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
-            # )
-            # Blackwell is not supported as of Triton 3.2.0, need custom internal build
-            # install_reqs.append("triton")
-            test_reqs.extend(["numpy", "torchvision"])
+            from build_tools.pytorch import install_requirements, test_requirements
+            install_reqs.extend(install_requirements())
+            test_reqs.extend(test_requirements())
         if "jax" in frameworks:
-            setup_reqs.extend(["jax[cuda12]", "flax>=0.7.1"])
-            install_reqs.extend(["jax", "flax>=0.7.1"])
-            test_reqs.extend(["numpy"])
+            from build_tools.jax import install_requirements, test_requirements
+            install_reqs.extend(install_requirements())
+            test_reqs.extend(test_requirements())
-    return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]]
+    return [remove_dups(reqs) for reqs in [install_reqs, test_reqs]]

 if __name__ == "__main__":
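For context, a hypothetical sketch of the per-framework helper modules now imported above; the real build_tools/pytorch.py and build_tools/jax.py ship with the repository and may differ. The version pins shown are taken from the requirement lists this commit removes from setup.py:

from typing import List

# build_tools/pytorch.py (sketch, not the actual file)
def install_requirements() -> List[str]:
    return ["torch>=2.1"]

def test_requirements() -> List[str]:
    return ["numpy", "torchvision"]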
@@ -167,14 +146,13 @@ if __name__ == "__main__":
     ext_modules = []
     package_data = {}
     include_package_data = False
-    setup_requires = []
     install_requires = ([f"transformer_engine_cu12=={__version__}"],)
     extras_require = {
         "pytorch": [f"transformer_engine_torch=={__version__}"],
         "jax": [f"transformer_engine_jax=={__version__}"],
     }
 else:
-    setup_requires, install_requires, test_requires = setup_requirements()
+    install_requires, test_requires = setup_requirements()
     ext_modules = [setup_common_extension()]
     package_data = {"": ["VERSION.txt"]}
     include_package_data = True
@@ -219,15 +197,8 @@ if __name__ == "__main__":
     long_description_content_type="text/x-rst",
     ext_modules=ext_modules,
     cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
-    python_requires=">=3.8, <3.13",
-    classifiers=[
-        "Programming Language :: Python :: 3.8",
-        "Programming Language :: Python :: 3.9",
-        "Programming Language :: Python :: 3.10",
-        "Programming Language :: Python :: 3.11",
-        "Programming Language :: Python :: 3.12",
-    ],
-    setup_requires=setup_requires,
+    python_requires=">=3.8",
+    classifiers=["Programming Language :: Python :: 3"],
     install_requires=install_requires,
     license_files=("LICENSE",),
     include_package_data=include_package_data,
...
@@ -375,7 +375,7 @@ std::vector<std::pair<size_t, size_t>> matrix_sizes = {
     {256, 256},
     {993, 512},
     {768, 1024},
-    {65536, 128},
+    {65504, 128},
     {16384, 1632},
 };
...
@@ -71,7 +71,8 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const
   // Remove the use_cudnn check here when it is supported by both backends.
   const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype;
-  if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3>){
+  if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3> ||
+                std::is_same_v<InputType, fp4e2m1>){
     compute_t g = static_cast<compute_t>(gamma);
     if (zero_centered_gamma) {
       g += static_cast<compute_t>(1.f);
...
@@ -45,7 +45,7 @@ bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2) {
   return true;
 }

-size_t typeToSize(DType type) {
+size_t typeToNumBits(DType type) {
   TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
     {
       return TypeInfo<T>::size;
@@ -62,7 +62,8 @@ const std::string &typeName(DType type) {
     {DType::kBFloat16, "bfloat16"},
     {DType::kFloat8E4M3, "float8e4m3"},
     {DType::kFloat8E5M2, "float8e5m2"},
-    {DType::kFloat8E8M0, "float8e8m0"}};
+    {DType::kFloat8E8M0, "float8e8m0"},
+    {DType::kFloat4E2M1, "float4e2m1"}};
   return name_map.at(type);
 }
@@ -109,9 +110,16 @@ size_t DIVUP(const size_t &x, const size_t &y){
 struct scale_inv_meta {
   std::vector<size_t> shape;
   DType type;
-  size_t type_size;
+  size_t type_size_bits;
+
+  size_t bytes() const noexcept {
+    return (product(shape) * type_size_bits) / 8;
+  }
 };

+size_t bytes(const NVTEShape& shape, const DType type) {
+  return (product(shape) * typeToNumBits(type)) / 8;
+}
+
 NVTEShape convertShape(const std::vector<size_t>& s) {
   return nvte_make_shape(s.data(), s.size());
 }
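The switch from a byte-sized type_size to type_size_bits is what makes FP4 representable: an fp4e2m1 element occupies half a byte, so sizes must be accumulated in bits and divided by 8 at the end, exactly as bytes() and scale_inv_meta::bytes() above do. A quick arithmetic check (a sketch, not the commit's code):

from math import prod

def nbytes(shape, bits_per_element):
    # Mirrors bytes(): total bits over 8.
    return prod(shape) * bits_per_element // 8

assert nbytes((128, 128), 8) == 16384  # one 8-bit element per byte
assert nbytes((128, 128), 4) == 8192   # two 4-bit elements per byte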
@@ -122,7 +130,7 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
     scale_inv_meta ret;
     ret.shape = {1};
     ret.type = DType::kFloat32;
-    ret.type_size = sizeof(float);
+    ret.type_size_bits = typeToNumBits(DType::kFloat32);
     return {ret, ret};
   }
   if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
@@ -152,8 +160,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
   }
   ret_rowwise.type = DType::kFloat8E8M0;
   ret_colwise.type = DType::kFloat8E8M0;
-  ret_rowwise.type_size = sizeof(uint8_t);
-  ret_colwise.type_size = sizeof(uint8_t);
+  ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
+  ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
   return {ret_rowwise, ret_colwise};
 }
@@ -179,8 +187,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
   }
   ret_rowwise.type = DType::kFloat32;
   ret_colwise.type = DType::kFloat32;
-  ret_rowwise.type_size = sizeof(float);
-  ret_colwise.type_size = sizeof(float);
+  ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32);
+  ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32);
   return {ret_rowwise, ret_colwise};
 }
@@ -205,8 +213,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
   }
   ret_rowwise.type = DType::kFloat32;
   ret_colwise.type = DType::kFloat32;
-  ret_rowwise.type_size = sizeof(float);
-  ret_colwise.type_size = sizeof(float);
+  ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32);
+  ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32);
   return {ret_rowwise, ret_colwise};
 }
@@ -222,8 +230,7 @@ Tensor::Tensor(const std::string& name,
   gen_.seed(seed);
   rowwise_ = rowwise;
   columnwise_ = columnwise;
-  size_t s = typeToSize(type);
-  size_t total_size = product(shape) * s;
+  size_t total_size = bytes(shape, type);
   void *dptr_rowwise = nullptr;
   void *dptr_columnwise = nullptr;
   cpu_data_rowwise_ = nullptr;
@@ -305,8 +312,8 @@ Tensor::Tensor(const std::string& name,
   } else {
     auto [rowwise_scale_meta, colwise_scale_meta] =
         get_scales(normalized_shape, tensor_.scaling_mode());
-    auto rowwise_scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
-    auto columnwise_scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size;
+    auto rowwise_scale_size = rowwise_scale_meta.bytes();
+    auto columnwise_scale_size = colwise_scale_meta.bytes();
     auto scale_shape = rowwise_scale_meta.shape;
     auto columnwise_scale_shape = colwise_scale_meta.shape;
     if (rowwise) {
@@ -331,7 +338,7 @@ Tensor::Tensor(const std::string& name,
 void Tensor::to_cpu() const {
   const NVTEShape s = tensor_.shape();
-  const size_t size = product(s) * typeToSize(tensor_.dtype());
+  const size_t size = bytes(s, tensor_.dtype());
   if (rowwise_) {
     cudaMemcpy(cpu_data_rowwise_.get(),
                tensor_.get_rowwise_data().data_ptr,
@@ -360,14 +367,14 @@ void Tensor::to_cpu() const {
   auto [rowwise_scale_meta, colwise_scale_meta] =
       get_scales(s, tensor_.scaling_mode());
   if (rowwise_) {
-    auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
+    auto scale_size = rowwise_scale_meta.bytes();
     cudaMemcpy(rowwise_scale_inv_cpu_data_.get(),
                tensor_.get_rowwise_scale_inv().data_ptr,
                scale_size,
                cudaMemcpyDeviceToHost);
   }
   if (columnwise_) {
-    auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size;
+    auto scale_size = colwise_scale_meta.bytes();
     cudaMemcpy(columnwise_scale_inv_cpu_data_.get(),
                tensor_.get_columnwise_scale_inv().data_ptr,
                scale_size,
@@ -378,34 +385,32 @@ void Tensor::to_cpu() const {

 void Tensor::from_cpu() const {
   const NVTEShape s = tensor_.shape();
-  const size_t size = product(s) * typeToSize(tensor_.dtype());
+  const size_t size = bytes(s, tensor_.dtype());
   if (rowwise_) {
-    cudaMemcpy(tensor_.get_rowwise_data().data_ptr,
-               cpu_data_rowwise_.get(), size, cudaMemcpyHostToDevice);
+    cudaMemcpy(tensor_.get_rowwise_data().data_ptr, cpu_data_rowwise_.get(), size,
+               cudaMemcpyHostToDevice);
   }
   if (columnwise_) {
-    cudaMemcpy(tensor_.get_columnwise_data().data_ptr,
-               cpu_data_columnwise_.get(), size, cudaMemcpyHostToDevice);
+    cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size,
+               cudaMemcpyHostToDevice);
   }
   if (isFp8Type(dtype())) {
     if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
       if (tensor_.amax() != nullptr){
-        cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float),
-                   cudaMemcpyHostToDevice);
+        cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
       }
-      cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float),
-                 cudaMemcpyHostToDevice);
+      cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
     }
     auto [rowwise_scale_meta, colwise_scale_meta] =
         get_scales(s, tensor_.scaling_mode());
     if (rowwise_) {
-      auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
+      auto scale_size = rowwise_scale_meta.bytes();
      cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr,
                 rowwise_scale_inv_cpu_data_.get(), scale_size,
                 cudaMemcpyHostToDevice);
     }
     if (columnwise_) {
-      auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size;
+      auto scale_size = colwise_scale_meta.bytes();
      cudaMemcpy(tensor_.get_columnwise_scale_inv().data_ptr,
                 columnwise_scale_inv_cpu_data_.get(), scale_size,
                 cudaMemcpyHostToDevice);
@@ -735,6 +740,19 @@ std::pair<double, double> getTolerances(const DType type) {

 template <typename T>
 void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
+  // Check how many RNG calls are required to generate one uniform random value
+  int rng_calls_per_val = 0;
+  {
+    std::mt19937 gen1 = *gen, gen2 = *gen;
+    std::uniform_real_distribution<> dis(-2.0, 1.0);
+    const float _ = dis(gen1);
+    while (gen2 != gen1) {
+      auto _ = gen2();
+      ++rng_calls_per_val;
+    }
+  }
+
+  // Generate uniform random values in parallel
 #pragma omp parallel proc_bind(spread)
   {
     std::mt19937 gen_local = *gen;
@@ -743,7 +761,7 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
     const int chunk_size = (size + threads_num - 1) / threads_num;
     const int idx_min = chunk_size * thread_ID;
     const int idx_max = std::min(chunk_size * (thread_ID + 1), static_cast<int>(size));
-    gen_local.discard(idx_min);
+    gen_local.discard(idx_min * rng_calls_per_val);
     std::uniform_real_distribution<> dis(-2.0, 1.0);
     for (int i = idx_min; i < idx_max; ++i) {
@@ -754,7 +772,7 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
 #endif
     }
   }
-  gen->discard(size);
+  gen->discard(size * rng_calls_per_val);
 }

 void fillUniform(Tensor *t) {
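The discard fix above accounts for std::uniform_real_distribution consuming more than one raw engine call per variate, so skipping idx_min engine states is not the same as skipping idx_min variates; the probe counts the raw calls per value and scales discard() accordingly. A minimal Python sketch of the same idea, with a counting wrapper standing in for the engine (all names illustrative):

import random

class CountingEngine:
    def __init__(self, seed=0):
        self._rng = random.Random(seed)
        self.calls = 0

    def raw32(self):
        self.calls += 1
        return self._rng.getrandbits(32)

def uniform(engine, lo=-2.0, hi=1.0):
    # Build one variate from two raw 32-bit draws, as a real-valued
    # distribution over a 32-bit engine typically does.
    x = ((engine.raw32() << 32) | engine.raw32()) / 2.0**64
    return lo + (hi - lo) * x

eng = CountingEngine()
uniform(eng)
rng_calls_per_val = eng.calls  # 2 here; discarding by index alone would desync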
...
@@ -10,11 +10,18 @@
 #include <vector>
 #include <array>
 #include <random>
+#include <cudaTypedefs.h>
+
+#define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)
+
 #include <cuda_runtime_api.h>
 #include <cuda_bf16.h>
 #include <cuda_fp8.h>
 #include <cuda_fp16.h>
-#include <cuda_fp8.h>
+#if FP4_TYPE_SUPPORTED
+#include <cuda_fp4.h>
+#endif
-#include <cuda_runtime_api.h>
 #include <transformer_engine/transformer_engine.h>
 #include "util/logging.h"
@@ -56,20 +63,32 @@ using fp8e4m3 = __nv_fp8_e4m3;
 using fp8e5m2 = __nv_fp8_e5m2;
 using fp8e8m0 = uint8_t;
 using int8 = int8_t;
+#if FP4_TYPE_SUPPORTED
+using fp4e2m1 = __nv_fp4_e2m1;
+#endif
+
+template <typename T>
+struct BitsNumber;
+
+#if FP4_TYPE_SUPPORTED
+template <>
+struct BitsNumber<fp4e2m1> {
+  static constexpr size_t num_bits = 4;
+};
+#endif
+
+template <typename T>
+struct BitsNumber {
+  static constexpr size_t num_bits = 8 * sizeof(T);
+};

 template <typename T>
-struct TypeInfo{
-  using types = std::tuple<byte,
-                           int16,
-                           int32,
-                           int64,
-                           fp32,
-                           fp16,
-                           bf16,
-                           fp8e4m3,
-                           fp8e5m2,
-                           fp8e8m0,
-                           int8>;
+struct TypeInfo {
+#if FP4_TYPE_SUPPORTED
+  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, fp4e2m1>;
+#else
+  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, int8>;
+#endif

   template <typename U, DType current>
   struct Helper {
@@ -96,7 +115,7 @@ struct TypeInfo{
   }

   constexpr static DType dtype = getType<T>();
-  constexpr static size_t size = sizeof(T);
+  constexpr static size_t size = BitsNumber<T>::num_bits;
 };

 class Tensor {
@@ -418,9 +437,10 @@ inline float dsilu(const float x) { return x * dsigmoid(x) + sigmoid(x); }
 inline float srelu(const float x) { return x > 0 ? x * x : 0; }
 inline float dsrelu(const float x) { return fmaxf(0, 2 * x); }

-size_t typeToSize(DType type);
+size_t typeToNumBits(DType type);
 size_t product(const NVTEShape &shape);
 size_t product(const std::vector<size_t> &shape);
+size_t bytes(const NVTEShape& shape, const DType type);
 size_t first_dimension(const std::vector<size_t> &shape);
 size_t last_dimension(const std::vector<size_t> &shape);
@@ -466,6 +486,16 @@ constexpr int32_t blackwellComputeCapability = 100;
 }  // namespace test

+#if FP4_TYPE_SUPPORTED
+#define SWITCH_FP4_TYPE_HANDLE(type, ...) \
+    case DType::kFloat4E2M1: { \
+      using type = fp4e2m1; \
+      { __VA_ARGS__ } \
+    } break;
+#else
+#define SWITCH_FP4_TYPE_HANDLE(type, ...)  // do nothing
+#endif
+
 #define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \
     switch (dtype) { \
       using namespace transformer_engine; \
@@ -517,8 +547,16 @@ constexpr int32_t blackwellComputeCapability = 100;
         {__VA_ARGS__} \
       } \
       break; \
+      case DType::kFloat8E8M0: \
+      { \
+        using type = fp8e8m0; \
+        {__VA_ARGS__} \
+      } \
+      break; \
+      SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \
       default: \
-        NVTE_ERROR("Invalid type."); \
+        printf("dtype: %d\n", static_cast<int>(dtype)); \
+        NVTE_ERROR("Invalid type MARKED TEST."); \
     }

 #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \
@@ -537,7 +575,15 @@ constexpr int32_t blackwellComputeCapability = 100;
       } \
       break; \
       default: \
-        NVTE_ERROR("Invalid type."); \
+        NVTE_ERROR("Invalid type MARKED TEST 2."); \
     }
+
+#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4_ONLY(dtype, type, ...) \
+    switch (dtype) { \
+      using namespace transformer_engine; \
+      SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \
+      default: \
+        NVTE_ERROR("Invalid type MARKED TEST 3."); \
+    }

 #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \
@@ -562,5 +608,5 @@ constexpr int32_t blackwellComputeCapability = 100;
       } \
       break; \
       default: \
-        NVTE_ERROR("Invalid type."); \
+        NVTE_ERROR("Invalid type MARKED TEST 4."); \
     }
...
@@ -68,6 +68,7 @@ class TestDistributedSelfAttn:
         batch, seqlen, num_head, hidden = data_shape

         if not is_fused_attn_kernel_available(
+            is_training,
             dtype,
             dtype,
             QKVLayout.BS3HD,
@@ -79,6 +80,7 @@ class TestDistributedSelfAttn:
             seqlen,
             seqlen,
             hidden,
+            hidden,
             None,  # no window
         ):
             pytest.skip("No FusedAttn backend found")
@@ -98,6 +100,7 @@ class TestDistributedSelfAttn:
             num_head,
             num_head,
             hidden,
+            hidden,
             attn_bias_type,
             attn_mask_type,
             dropout_prob,
@@ -214,6 +217,7 @@ class TestDistributedCrossAttn:
         batch, seqlen, num_head, hidden = data_shape

         if not is_fused_attn_kernel_available(
+            is_training,
             dtype,
             dtype,
             QKVLayout.BSHD_BS2HD,
@@ -225,6 +229,7 @@ class TestDistributedCrossAttn:
             seqlen,
             seqlen,
             hidden,
+            hidden,
             None,  # no window
         ):
             pytest.skip("No FusedAttn backend found")
@@ -237,6 +242,7 @@ class TestDistributedCrossAttn:
             num_head,
             num_head,
             hidden,
+            hidden,
             attn_bias_type,
             attn_mask_type,
             dropout_prob,
@@ -289,6 +295,7 @@ class TestDistributedContextParallelSelfAttn:
         cp_strategy,
         use_shardy,
         use_scan_ring=False,
+        window_size=None,
     ):
         if qkv_layout.is_thd():
             if cp_strategy == CPStrategy.ALL_GATHER:
@@ -326,6 +333,7 @@ class TestDistributedContextParallelSelfAttn:
             num_head,
             num_kv_heads,
             hidden,
+            hidden,
             attn_bias_type,
             attn_mask_type,
             dropout_prob,
@@ -333,7 +341,7 @@ class TestDistributedContextParallelSelfAttn:
             is_training,
             qkv_layout,
             bias_shape,
-            None,
+            window_size,
             SeqDescFormat.SegmentIDs,
             number_of_devices=device_count,
             mesh_shape=mesh_shape,
@@ -345,6 +353,7 @@ class TestDistributedContextParallelSelfAttn:
         def check_has_backend_for_mask(mask_type):
             return is_fused_attn_kernel_available(
+                is_training,
                 dtype,
                 dtype,
                 qkv_layout,
@@ -356,6 +365,7 @@ class TestDistributedContextParallelSelfAttn:
                 seqlen,
                 seqlen,
                 hidden,
+                hidden,
                 None,
             )  # no SWA for CP
@@ -476,6 +486,13 @@ class TestDistributedContextParallelSelfAttn:
         "use_scan",
         [pytest.param(False, id="NO_SCAN"), pytest.param(True, id="USE_SCAN")],
     )
+    @pytest.mark.parametrize(
+        "window_size",
+        [
+            pytest.param((-1, -1), id="window_size(-1, -1)"),
+            pytest.param((20, 0), id="window_size(20, 0)"),
+        ],
+    )
     def test_context_parallel_ring_attn(
         self,
         device_count,
@@ -489,7 +506,15 @@ class TestDistributedContextParallelSelfAttn:
         qkv_layout,
         load_balanced,
         use_scan,
+        window_size,
     ):
+        if window_size != (-1, -1) and not qkv_layout.is_thd():
+            pytest.skip("Sliding window attention is only supported for THD layout")
+        if window_size != (-1, -1) and qkv_layout.is_thd() and use_scan:
+            pytest.skip(
+                "When context parallelism and sliding window attention are used, "
+                "scanloop is not supported"
+            )
         self.impl_test_context_parallel_attn(
             device_count,
             mesh_shape,
@@ -504,6 +529,7 @@ class TestDistributedContextParallelSelfAttn:
             CPStrategy.RING,
             use_shardy=False,
             use_scan_ring=use_scan,
+            window_size=window_size,
         )

     @pytest.mark.parametrize(
...
@@ -106,7 +106,8 @@ def general_dot_product_attention(
     softmax_out = softmax_out * multiplier

     context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)
-    context = jnp.reshape(context, query.shape)
+    context_shape = query.shape[:-1] + (value.shape[-1],)
+    context = jnp.reshape(context, context_shape)
     return context
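This reshape fix matters once head_dim_v differs from head_dim_qk: the attention output inherits the value head dimension, so it can no longer be reshaped to query.shape. A minimal shape check in JAX (dimensions illustrative):

import jax.numpy as jnp

b, s_q, s_kv, heads, groups, d_qk, d_v = 2, 8, 8, 4, 1, 64, 32
softmax_out = jnp.ones((b, heads, groups, s_q, s_kv))
value = jnp.ones((b, s_kv, heads, d_v))
context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)

query_shape = (b, s_q, heads * groups, d_qk)           # what query.shape holds
context_shape = query_shape[:-1] + (value.shape[-1],)  # swap in d_v, as above
context = jnp.reshape(context, context_shape)
assert context.shape == (b, s_q, heads * groups, d_v)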
@@ -294,7 +295,8 @@ class FusedAttnRunner:
     max_seqlen_kv: int
     num_heads_q: int
     num_heads_kv: int
-    head_dim: int
+    head_dim_qk: int
+    head_dim_v: int
     attn_bias_type: AttnBiasType
     attn_mask_type: AttnMaskType
     dropout_prob: float
@@ -346,7 +348,16 @@ class FusedAttnRunner:
                 "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
             )

+        # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
+        # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
+        if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate():
+            pytest.skip(
+                "For head_dim_qk != head_dim_v, it is necessary that the QKV layout "
+                "is either BSHD_BSHD_BSHD or THD_THD_THD"
+            )
+
         self.backend = FusedAttnHelper(
+            self.is_training,
             self.dtype,
             self.dtype,
             self.qkv_layout,
@@ -357,7 +368,8 @@ class FusedAttnRunner:
             self.num_heads_kv,
             self.max_seqlen_q,
             self.max_seqlen_kv,
-            self.head_dim,
+            self.head_dim_qk,
+            self.head_dim_v,
             (-1, -1) if self.window_size is None else self.window_size,
         ).get_fused_attn_backend()
         if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
@@ -390,13 +402,9 @@ class FusedAttnRunner:
         key = jax.random.PRNGKey(0)
         q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)

-        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
-        k_shape = v_shape = (
-            self.batch_size,
-            self.max_seqlen_kv,
-            self.num_heads_kv,
-            self.head_dim,
-        )
+        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim_qk)
+        k_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_qk)
+        v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_v)

         if self.attn_bias_type == AttnBiasType.NO_BIAS:
             bias_shape = None
@@ -615,7 +623,7 @@ class FusedAttnRunner:
                 raise ValueError(f"Unknown {self.seq_desc_format=}")

         self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
-        self.scaling_factor = 1.0 / sqrt(self.head_dim)
+        self.scaling_factor = 1.0 / sqrt(self.head_dim_qk)

         # Setup distributed sharding specs
         # Setup shardings for distributed tests
@@ -934,9 +942,31 @@ class FusedAttnRunner:
     ],
 )
 @pytest.mark.parametrize(
-    "b, s_q, s_kv, h_q, h_kv, d, dtype",
+    "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype",
     [
-        pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
-        pytest.param(
-            2,
-            2048,
-            1024,
-            12,
-            12,
-            64,
-            jnp.bfloat16,
-            id="2-2048-1024-12-12-64-BF16-CROSS",
-        ),
-        pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
-        pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
+        pytest.param(
+            2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-64-BF16-SELF"
+        ),
+        pytest.param(
+            2,
+            2048,
+            1024,
+            12,
+            12,
+            64,
+            64,
+            jnp.bfloat16,
+            id="2-2048-1024-12-12-64-64-BF16-CROSS",
+        ),
+        pytest.param(
+            2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-64-BF16-GQA"
+        ),
+        pytest.param(
+            4, 128, 128, 16, 16, 64, 64, jnp.float16, id="4-128-128-16-16-64-64-FP16-SELF"
+        ),
+        pytest.param(
+            4, 128, 128, 16, 16, 64, 32, jnp.float16, id="4-128-128-16-16-64-32-FP16-SELF"
+        ),
+        pytest.param(
+            2,
+            2048,
+            1024,
+            12,
+            12,
+            64,
+            32,
+            jnp.bfloat16,
+            id="2-2048-1024-12-12-64-32-BF16-CROSS",
+        ),
+        pytest.param(
+            2, 2048, 2048, 12, 6, 128, 64, jnp.float16, id="2-2048-2048-12-6-128-64-FP16-GQA"
+        ),
     ],
 )
 @pytest.mark.parametrize(
@@ -1002,7 +1034,8 @@ class TestFusedAttn:
         s_kv,
         h_q,
         h_kv,
-        d,
+        d_qk,
+        d_v,
         attn_bias_type,
         attn_mask_type,
         dropout_prob,
@@ -1027,7 +1060,8 @@ class TestFusedAttn:
             s_kv,
             h_q,
             h_kv,
-            d,
+            d_qk,
+            d_v,
             attn_bias_type,
             attn_mask_type,
             dropout_prob,
@@ -1054,7 +1088,8 @@ class TestFusedAttn:
             s_kv,
             h_q,
             h_kv,
-            d,
+            d_qk,
+            d_v,
             attn_bias_type,
             attn_mask_type,
             dropout_prob,
@@ -1076,7 +1111,8 @@ class TestFusedAttn:
             s_kv,
             h_q,
             h_kv,
-            d,
+            d_qk,
+            d_v,
             attn_bias_type,
             attn_mask_type,
             dropout_prob,
...
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest


def pytest_addoption(parser):
    parser.addoption(
        "--feature_dirs", nargs="+", action="store", default="", help="List of feature directories"
    )
    parser.addoption(
        "--configs_dir",
        action="store",
        default="",
        type=str,
        help="Path to the directory with configs.",
    )


@pytest.fixture
def feature_dirs(request):
    return request.config.getoption("--feature_dirs")


@pytest.fixture
def configs_dir(request):
    return request.config.getoption("--configs_dir")
...
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
import nvdlfw_inspect.api as debug_api
try:
import transformer_engine
import transformer_engine_torch as tex
except (ImportError, ModuleNotFoundError):
print("Could not find TransformerEngine package.")
exit(1)
def test_transformer_engine_no_config(feature_dirs):
debug_api.initialize("", feature_dirs=feature_dirs)
try:
tensor = torch.rand(24, 2046).cuda()
# FP8 enabled - true by the default
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0
)
# modify_tensor_enabled - False by default
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
)
# inspect_tensor_enabled - False by default
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.attn.qkv", tensor_name="activation", iteration=0
)
# inspect_tensor_postquantize - False by default
assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
)
finally:
debug_api.end_debug()
def test_disable_fp8_gemm(configs_dir, feature_dirs):
try:
debug_api.initialize(configs_dir + "disable_fp8_gemms.yaml", feature_dirs=feature_dirs)
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0
)
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="dgrad", iteration=0
)
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="wgrad", iteration=0
)
# caching
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0
)
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="dgrad", iteration=0
)
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="wgrad", iteration=0
)
finally:
debug_api.end_debug()
def test_disable_fp8_layer(configs_dir, feature_dirs):
try:
debug_api.initialize(configs_dir + "disable_fp8_layer.yaml", feature_dirs=feature_dirs)
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.mlp.fc1", gemm="fprop", iteration=0
)
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.mlp.fc1", gemm="wgrad", iteration=0
)
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", iteration=0
)
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0
)
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="wgrad", iteration=0
)
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="dgrad", iteration=0
)
finally:
debug_api.end_debug()
def test_per_tensor_scaling(configs_dir, feature_dirs):
try:
debug_api.initialize(configs_dir + "per_tensor_scaling.yaml", feature_dirs=feature_dirs)
tensor = torch.rand(24, 2046).cuda()
# check modify_tensor_enabled
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0
)
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="fprop", tensor_name="weight", iteration=0
)
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0
)
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", tensor_name="weight", iteration=0
)
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="wgrad", tensor_name="gradient", iteration=0
)
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="wgrad", tensor_name="activation", iteration=0
)
# check modify_tensor
default_quantizer1 = Float8Quantizer(
scale=torch.tensor([1]).cuda(),
amax=torch.tensor([0]).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
)
default_quantizer2 = Float8Quantizer(
scale=torch.tensor([1]).cuda(),
amax=torch.tensor([0]).cuda(),
fp8_dtype=tex.DType.kFloat8E5M2,
)
output1 = debug_api.transformer_engine.modify_tensor(
layer_name="decoder.1.mlp.fc1",
gemm="fprop",
tensor_name="activation",
default_quantizer=default_quantizer1,
iteration=0,
tensor=tensor,
)
assert type(output1) == Float8Tensor
assert output1._fp8_dtype == tex.DType.kFloat8E4M3
output2 = debug_api.transformer_engine.modify_tensor(
"decoder.1.mlp.fc1",
gemm="dgrad",
tensor=tensor,
tensor_name="gradient",
default_quantizer=default_quantizer2,
iteration=0,
)
assert type(output2) == Float8Tensor
assert output2._fp8_dtype == tex.DType.kFloat8E5M2
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1",
gemm="wgrad",
tensor_name="gradient",
iteration=0,
)
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc4",
gemm="fprop",
tensor_name="activation",
iteration=0,
)
finally:
debug_api.end_debug()
def test_fake_quant(configs_dir, feature_dirs):
try:
debug_api.initialize(
configs_dir + "fake_quantization_config.yaml", feature_dirs=feature_dirs
)
tensor = torch.rand(24, 2046).cuda()
# modify_tensor_enabled
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0
)
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0
)
# modify_tensor
debug_api.transformer_engine.modify_tensor(
"decoder.1.mlp.fc1",
gemm="fprop",
tensor=tensor,
tensor_name="activation",
iteration=0,
default_quantizer=None,
)
debug_api.transformer_engine.modify_tensor(
"decoder.1.mlp.fc1",
gemm="dgrad",
tensor=tensor,
tensor_name="gradient",
iteration=0,
default_quantizer=None,
)
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.fc2", gemm="wgrad", iteration=0
)
# caching
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.fc2", gemm="wgrad", iteration=0
)
finally:
debug_api.end_debug()
def test_statistics_collection(configs_dir, feature_dirs):
try:
debug_api.initialize(
config_file=configs_dir + "stats_collection_test_config.yaml",
feature_dirs=feature_dirs,
default_logging_enabled=False,
)
tensor = torch.randn((100, 100, 5)).cuda()
tensor_fp8 = Float8Tensor(
data=tensor.to(torch.uint8).cuda(),
fp8_scale_inv=torch.full([1], 1.0).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
shape=tensor.shape,
dtype=torch.float32,
)
def log():
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
return STATS_BUFFERS.log_stats()
def assert_empty():
stats = log()
assert len(stats) == 0
# TE tensor stats --
debug_api.transformer_engine.inspect_tensor(
"decoder.1.mlp.fc1",
tensor=tensor,
tensor_name="activation",
iteration=200,
tp_group=None,
)
stats = log()
assert stats[("decoder.1.mlp.fc1", "activation", "cur_amax", 200)] == tensor.abs().max()
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.mlp.fc1", tensor_name="activation", iteration=201
)
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.2.mlp.fc1", tensor_name="activation", iteration=200
)
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.mlp.fc1", tensor_name="gradient", iteration=200
)
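        # Bit-pattern bookkeeping: a payload byte of 0 encodes FP8 zero (an
        # underflow after scaling), and 126 (0x7E) is the largest finite positive
        # E4M3 value, i.e. an element saturated during quantization. Both are
        # expressed as percentages of the element count.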
expected_underflows = (tensor_fp8._data == 0).sum() * 100 / (100 * 100 * 5)
expected_overflows = (tensor_fp8._data == 126).sum() * 100 / (100 * 100 * 5)
# TE FP8 tensor stats --
assert debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.1.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200
)
debug_api.transformer_engine.inspect_tensor_postquantize(
"decoder.1.mlp.fc1",
tensor=tensor_fp8,
tensor_name="gradient",
iteration=200,
rowwise=True,
tp_group=None,
)
stats = log()
torch.testing.assert_close(
stats[("decoder.1.mlp.fc1", "gradient", "underflows%", 200)], expected_underflows
)
assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.1.mlp.fc1", tensor_name="activation", gemm="fprop", iteration=201
)
assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.2.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200
)
        # Second config section in the same YAML file
tensor = torch.rand((100, 100, 5))
debug_api.transformer_engine.inspect_tensor(
"decoder.6.mlp.fc1",
tensor=tensor,
tensor_name="activation",
iteration=200,
tp_group=None,
)
stats = log()
stats_names = [x[3] for x in stats.keys()]
        assert all(s in stats_names for s in ["cur_amax", "dynamic_range", "mean", "std", "l1_norm"])
assert stats[("decoder.6.mlp.fc1", "activation", "mean", 200)] == tensor.mean()
debug_api.transformer_engine.inspect_tensor(
"decoder.7.mlp.fc1",
tensor=tensor,
tensor_name="weight",
iteration=200,
tp_group=None,
)
stats = log()
stats_names = [x[3] for x in stats.keys()]
        assert all(s in stats_names for s in ["mean", "std", "l1_norm", "min", "max"])
assert stats[("decoder.7.mlp.fc1", "weight", "max", 200)] == tensor.max()
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.7.mlp.fc1", tensor_name="weight", iteration=201
)
assert_empty()
finally:
debug_api.end_debug()
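

# Multi-run check: feeding two batches through inspect_tensor separately and then
# feeding their concatenation in a single call should log identical statistics,
# i.e. the stat buffers reduce across calls instead of overwriting.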
def test_statistics_multi_run(configs_dir, feature_dirs):
try:
debug_api.initialize(
config_file=configs_dir + "stats_collection_test_config.yaml",
feature_dirs=feature_dirs,
default_logging_enabled=False,
)
def feed(tensor, tensor_fp8):
debug_api.transformer_engine.inspect_tensor(
"decoder.5.mlp.fc1",
tensor=tensor,
tensor_name="activation",
iteration=1,
tp_group=None,
)
debug_api.transformer_engine.inspect_tensor_postquantize(
"decoder.5.mlp.fc1",
tensor=tensor_fp8,
tensor_name="activation",
iteration=1,
rowwise=True,
tp_group=None,
)
def log_stats():
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
return STATS_BUFFERS.log_stats()
def fp8_tensor(t):
return Float8Tensor(
data=t.to(torch.uint8).cuda(),
fp8_scale_inv=torch.ones([1]).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
shape=t.shape,
dtype=torch.float32,
)
shape = [1024, 1024]
tensors = [torch.randn(shape) for _ in range(2)]
tensors_fp8 = [fp8_tensor(tensors[i]) for i in range(2)]
feed(tensors[0], tensors_fp8[0])
feed(tensors[1], tensors_fp8[1])
stats1 = log_stats()
tensor2 = torch.cat((tensors[0], tensors[1])).cuda()
fp8tensor2 = fp8_tensor(tensor2)
feed(tensor2, fp8tensor2)
stats2 = log_stats()
assert len(stats1.keys()) > 0
for k in stats1.keys():
torch.testing.assert_close(stats1[k], stats2[k])
finally:
debug_api.end_debug()
if __name__ == "__main__":
pass

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import pathlib
from nvdlfw_inspect.config_manager import ConfigManager
import nvdlfw_inspect.api as debug_api
try:
import transformer_engine
from transformer_engine.debug.features.api import TEConfigAPIMapper
except (ImportError, ModuleNotFoundError):
print("Could not find TransformerEngine debug module.")
exit(1)


def test_transformer_engine_config_parsing(feature_dirs):
debug_api.initialize(
config_file=pathlib.Path(__file__).resolve().parent
/ "test_configs/tensor_manipulation_transformer_engine.yaml",
feature_dirs=feature_dirs,
log_dir="./log",
)
cfg_fc1 = ConfigManager.get_config_for_layer("decoder.1.mlp.fc1")["transformer_engine"]
cfg_fc2 = ConfigManager.get_config_for_layer("decoder.1.mlp.fc2")["transformer_engine"]
assert cfg_fc1 and cfg_fc2
gemm_parsing = True
tensor_parsing = True
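    # parse_config_and_api returns a (matched, parsed_config) pair: `matched` is
    # False when the queried gemm/tensor combination is filtered out by the
    # config, and otherwise parsed_config echoes the matched gemm and tensor plus
    # any feature-specific fields (e.g. quant_format for FakeQuant).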
    # PerTensorScaling is configured only for the dgrad GEMM, so a wgrad query is filtered out by gemm
ret, _ = TEConfigAPIMapper().parse_config_and_api(
cfg_fc1["PerTensorScaling"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="wgrad",
tensor_name="activation",
)
assert not ret
    # PerTensorScaling is configured only for the gradient tensor, so an activation query is filtered out by tensor name
ret, _ = TEConfigAPIMapper().parse_config_and_api(
cfg_fc1["PerTensorScaling"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="dgrad",
tensor_name="activation",
)
assert not ret
ret, parsed_cfg_fc1 = TEConfigAPIMapper().parse_config_and_api(
cfg_fc1["PerTensorScaling"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="dgrad",
tensor_name="gradient",
)
assert ret
assert parsed_cfg_fc1 == {"gemm": "dgrad", "tensor": "gradient"}
# Test tensor struct
ret, parsed_cfg_fc1_act = TEConfigAPIMapper().parse_config_and_api(
cfg_fc1["FakeQuant"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="fprop",
tensor_name="activation",
    )
    assert ret
ret, parsed_cfg_fc1_wei = TEConfigAPIMapper().parse_config_and_api(
cfg_fc1["FakeQuant"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="fprop",
tensor_name="weight",
)
assert ret
assert parsed_cfg_fc1_act == {
"gemm": "fprop",
"tensor": "activation",
"quant_format": "FP8E4M3",
}
assert parsed_cfg_fc1_wei == {
"gemm": "fprop",
"tensor": "weight",
"quant_format": "FP8E4M3",
}
# Test gemms struct
ret, parsed_cfg_fc2_grad = TEConfigAPIMapper().parse_config_and_api(
cfg_fc2["FakeQuant"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="dgrad",
tensor_name="gradient",
)
assert ret
assert parsed_cfg_fc2_grad == {"gemm": "dgrad", "tensor": "gradient", "quant_format": "FP8E5M2"}
ret, parsed_cfg_fc2_wei = TEConfigAPIMapper().parse_config_and_api(
cfg_fc2["FakeQuant"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="dgrad",
tensor_name="weight",
)
assert ret
assert parsed_cfg_fc2_wei == {"gemm": "dgrad", "tensor": "weight", "quant_format": "FP8E5M2"}
# Test gemm + tensor struct
ret, parsed_cfg_fc2_fprop_act = TEConfigAPIMapper().parse_config_and_api(
cfg_fc2["PerTensorScaling"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="fprop",
tensor_name="activation",
)
assert ret
assert parsed_cfg_fc2_fprop_act == {"gemm": "fprop", "tensor": "activation"}
ret, parsed_cfg_fc2_fprop_wei = TEConfigAPIMapper().parse_config_and_api(
cfg_fc2["PerTensorScaling"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="fprop",
tensor_name="weight",
)
assert ret
assert parsed_cfg_fc2_fprop_wei == {"gemm": "fprop", "tensor": "weight"}
ret, parsed_cfg_fc2_wgrad_act = TEConfigAPIMapper().parse_config_and_api(
cfg_fc2["PerTensorScaling"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="wgrad",
tensor_name="activation",
)
assert ret
assert parsed_cfg_fc2_wgrad_act == {"gemm": "wgrad", "tensor": "activation"}
ret, parsed_cfg_fc2_wgrad_grad = TEConfigAPIMapper().parse_config_and_api(
cfg_fc2["PerTensorScaling"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="wgrad",
tensor_name="gradient",
)
assert ret
assert parsed_cfg_fc2_wgrad_grad == {"gemm": "wgrad", "tensor": "gradient"}
ConfigManager.reset()

test_disable_fp8_gemm_1:
  enabled: True
  layers:
    layer_types: [qkv, fc2]
  transformer_engine:
    DisableFP8GEMM:
      enabled: True
      gemms: [dgrad, wgrad]
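
# A minimal sketch of how a section like this is consumed (hypothetical layer
# name): after debug_api.initialize(config_file=..., feature_dirs=...), a call
# such as debug_api.transformer_engine.fp8_gemm_enabled(
#     "decoder.1.self_attention.qkv", gemm="dgrad", iteration=0)
# would report False, since DisableFP8GEMM matches qkv/fc2 layers for the
# dgrad and wgrad GEMMs.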