Commit 2b05e121 authored by yuguo's avatar yuguo
Browse files

Merge commit 'a69692ac' of...

Merge commit 'a69692ac' of https://github.com/NVIDIA/TransformerEngine
parents 0fd441c2 a69692ac
...@@ -24,10 +24,10 @@ pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Fa ...@@ -24,10 +24,10 @@ pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Fa
# Make encoder tests to have run-to-run deterministic to have the stable CI results # Make encoder tests to have run-to-run deterministic to have the stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_multigpu_encoder.xml $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py" # python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_multigpu_encoder.xml $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py"
wait # wait
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py" # python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
wait # wait
. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh" . $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh"
if [ $RET -ne 0 ]; then if [ $RET -ne 0 ]; then
......
...@@ -27,9 +27,6 @@ mkdir -p "$XML_LOG_DIR" ...@@ -27,9 +27,6 @@ mkdir -p "$XML_LOG_DIR"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_helper.py || test_fail "tests/jax/*not_distributed_*" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_helper.py || test_fail "tests/jax/*not_distributed_*"
# Test without custom calls
NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_custom_call_compute.xml $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py"
pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements" pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"
...@@ -37,6 +34,9 @@ pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Fa ...@@ -37,6 +34,9 @@ pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Fa
# Make encoder tests to have run-to-run deterministic to have the stable CI results # Make encoder tests to have run-to-run deterministic to have the stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
# Test without custom calls
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
NVTE_JAX_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls"
if [ $RET -ne 0 ]; then if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES" echo "Error: some sub-tests failed: $FAILED_CASES"
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# CI driver for the TransformerEngine PyTorch "nvinspect" debug-feature tests.
# All paths below are overridable from the environment; `: ${VAR:=default}`
# only assigns when the variable is unset or empty.
: ${TE_PATH:=/opt/transformerengine}
# Directory containing the debug feature implementations exercised by the tests.
: ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
# Directory of YAML test configs passed to the API-feature tests below.
: ${NVTE_TEST_NVINSPECT_CONFIGS_DIR:=$TE_PATH/tests/pytorch/debug/test_configs/}
# Config with the dummy feature which prevents nvinspect from being disabled.
# Nvinspect will be disabled if no feature is active.
: ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml}
# Aggregate failure flag: each suite sets FAIL=1 on error instead of aborting,
# so every suite always runs and the script's exit code reflects the whole run.
FAIL=0
# Pin pytest so CI results are reproducible across environments.
pip install pytest==8.2.1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
# NVTE_TORCH_COMPILE=0: presumably disables torch.compile code paths for this
# suite — NOTE(review): confirm against TransformerEngine's env-var handling.
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
# standard numerics tests with initialized debug
# The dummy-feature config keeps nvinspect active (see note above the default);
# PYTORCH_JIT=0 / NVTE_TORCH_COMPILE=0 / NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 are
# assumed to force deterministic, non-compiled execution — TODO confirm.
NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1
# Non-zero exit iff any suite above failed.
exit $FAIL
...@@ -20,5 +20,5 @@ if [ -z "${CPP_ONLY}" ] ...@@ -20,5 +20,5 @@ if [ -z "${CPP_ONLY}" ]
then then
cd $TE_PATH cd $TE_PATH
echo "Checking Python files" echo "Checking Python files"
python3 -m pylint --recursive=y transformer_engine/common transformer_engine/pytorch python3 -m pylint --recursive=y transformer_engine/common transformer_engine/pytorch transformer_engine/debug
fi fi
...@@ -47,6 +47,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entro ...@@ -47,6 +47,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entro
NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
if [ "$RET" -ne 0 ]; then if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES" echo "Error in the following test cases:$FAILED_CASES"
......
...@@ -20,6 +20,7 @@ FAILED_CASES="" ...@@ -20,6 +20,7 @@ FAILED_CASES=""
: ${XML_LOG_DIR:=/logs} : ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR" mkdir -p "$XML_LOG_DIR"
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py"
...@@ -30,6 +31,19 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_use ...@@ -30,6 +31,19 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_use
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
# debug tests
# Config with the dummy feature which prevents nvinspect from being disabled.
# Nvinspect will be disabled if no feature is active.
: ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml}
: ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
pytest -v -s $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py"
# standard numerics tests with initialized debug
NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py"
if [ "$RET" -ne 0 ]; then if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES" echo "Error in the following test cases:$FAILED_CASES"
exit 1 exit 1
......
...@@ -25,18 +25,18 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" ...@@ -25,18 +25,18 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
: ${XML_LOG_DIR:=/logs} : ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR" mkdir -p "$XML_LOG_DIR"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"
# Test without custom calls
NVTE_JAX_UNITTEST_LEVEL="L2" NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_custom_call_compute.xml $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py"
pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements" pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"
pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements" pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements"
# Make encoder tests to have run-to-run deterministic to have the stable CI results # Make encoder tests to have run-to-run deterministic to have the stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
# Test without custom calls
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
NVTE_JAX_CUSTOM_CALLS_RE="" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
if [ $RET -ne 0 ]; then if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES" echo "Error: some sub-tests failed: $FAILED_CASES"
......
...@@ -19,11 +19,7 @@ from build_tools.te_version import te_version ...@@ -19,11 +19,7 @@ from build_tools.te_version import te_version
from build_tools.utils import ( from build_tools.utils import (
rocm_build, rocm_build,
cuda_archs, cuda_archs,
found_cmake,
found_ninja,
found_pybind11,
get_frameworks, get_frameworks,
install_and_import,
remove_dups, remove_dups,
) )
...@@ -38,7 +34,6 @@ os.environ["NVTE_PROJECT_BUILDING"] = "1" ...@@ -38,7 +34,6 @@ os.environ["NVTE_PROJECT_BUILDING"] = "1"
if "pytorch" in frameworks: if "pytorch" in frameworks:
from torch.utils.cpp_extension import BuildExtension from torch.utils.cpp_extension import BuildExtension
elif "jax" in frameworks: elif "jax" in frameworks:
install_and_import("pybind11[global]")
from pybind11.setup_helpers import build_ext as BuildExtension from pybind11.setup_helpers import build_ext as BuildExtension
...@@ -87,6 +82,11 @@ def setup_common_extension() -> CMakeExtension: ...@@ -87,6 +82,11 @@ def setup_common_extension() -> CMakeExtension:
if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))): if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON") cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")
# Add custom CMake arguments from environment variable
nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
if nvte_cmake_extra_args:
cmake_flags.extend(nvte_cmake_extra_args.split())
# Project directory root # Project directory root
root_path = Path(__file__).resolve().parent root_path = Path(__file__).resolve().parent
if rocm_build(): if rocm_build():
...@@ -102,22 +102,13 @@ def setup_common_extension() -> CMakeExtension: ...@@ -102,22 +102,13 @@ def setup_common_extension() -> CMakeExtension:
) )
def setup_requirements() -> Tuple[List[str], List[str], List[str]]: def setup_requirements() -> Tuple[List[str], List[str]]:
"""Setup Python dependencies """Setup Python dependencies
Returns dependencies for build, runtime, and testing. Returns dependencies for runtime and testing.
""" """
# Common requirements # Common requirements
setup_reqs: List[str] = [
"nvidia-cuda-runtime-cu12",
"nvidia-cublas-cu12",
"nvidia-cudnn-cu12",
"nvidia-cuda-cccl-cu12",
"nvidia-cuda-nvcc-cu12",
"nvidia-nvtx-cu12",
"nvidia-cuda-nvrtc-cu12",
]
install_reqs: List[str] = [ install_reqs: List[str] = [
"pydantic", "pydantic",
"importlib-metadata>=1.0", "importlib-metadata>=1.0",
...@@ -125,32 +116,20 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ...@@ -125,32 +116,20 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
] ]
test_reqs: List[str] = ["pytest>=8.2.1"] test_reqs: List[str] = ["pytest>=8.2.1"]
# Requirements that may be installed outside of Python
if not found_cmake():
setup_reqs.append("cmake>=3.21")
if not found_ninja():
setup_reqs.append("ninja")
if not found_pybind11():
setup_reqs.append("pybind11")
# Framework-specific requirements # Framework-specific requirements
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks: if "pytorch" in frameworks:
setup_reqs.extend(["torch>=2.1"]) from build_tools.pytorch import install_requirements, test_requirements
install_reqs.extend(["torch>=2.1"])
# install_reqs.append( install_reqs.extend(install_requirements())
# "nvdlfw-inspect @" test_reqs.extend(test_requirements())
# " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
# )
# Blackwell is not supported as of Triton 3.2.0, need custom internal build
# install_reqs.append("triton")
test_reqs.extend(["numpy", "torchvision"])
if "jax" in frameworks: if "jax" in frameworks:
setup_reqs.extend(["jax[cuda12]", "flax>=0.7.1"]) from build_tools.jax import install_requirements, test_requirements
install_reqs.extend(["jax", "flax>=0.7.1"])
test_reqs.extend(["numpy"]) install_reqs.extend(install_requirements())
test_reqs.extend(test_requirements())
return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]] return [remove_dups(reqs) for reqs in [install_reqs, test_reqs]]
if __name__ == "__main__": if __name__ == "__main__":
...@@ -167,14 +146,13 @@ if __name__ == "__main__": ...@@ -167,14 +146,13 @@ if __name__ == "__main__":
ext_modules = [] ext_modules = []
package_data = {} package_data = {}
include_package_data = False include_package_data = False
setup_requires = []
install_requires = ([f"transformer_engine_cu12=={__version__}"],) install_requires = ([f"transformer_engine_cu12=={__version__}"],)
extras_require = { extras_require = {
"pytorch": [f"transformer_engine_torch=={__version__}"], "pytorch": [f"transformer_engine_torch=={__version__}"],
"jax": [f"transformer_engine_jax=={__version__}"], "jax": [f"transformer_engine_jax=={__version__}"],
} }
else: else:
setup_requires, install_requires, test_requires = setup_requirements() install_requires, test_requires = setup_requirements()
ext_modules = [setup_common_extension()] ext_modules = [setup_common_extension()]
package_data = {"": ["VERSION.txt"]} package_data = {"": ["VERSION.txt"]}
include_package_data = True include_package_data = True
...@@ -219,15 +197,8 @@ if __name__ == "__main__": ...@@ -219,15 +197,8 @@ if __name__ == "__main__":
long_description_content_type="text/x-rst", long_description_content_type="text/x-rst",
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
python_requires=">=3.8, <3.13", python_requires=">=3.8",
classifiers=[ classifiers=["Programming Language :: Python :: 3"],
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
],
setup_requires=setup_requires,
install_requires=install_requires, install_requires=install_requires,
license_files=("LICENSE",), license_files=("LICENSE",),
include_package_data=include_package_data, include_package_data=include_package_data,
......
...@@ -375,7 +375,7 @@ std::vector<std::pair<size_t, size_t>> matrix_sizes = { ...@@ -375,7 +375,7 @@ std::vector<std::pair<size_t, size_t>> matrix_sizes = {
{256, 256}, {256, 256},
{993, 512}, {993, 512},
{768, 1024}, {768, 1024},
{65536, 128}, {65504, 128},
{16384, 1632}, {16384, 1632},
}; };
......
...@@ -71,7 +71,8 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const ...@@ -71,7 +71,8 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const
// Remove the use_cudnn check here when it is supported by both backends. // Remove the use_cudnn check here when it is supported by both backends.
const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype; const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype;
if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3>){ if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3> ||
std::is_same_v<InputType, fp4e2m1>){
compute_t g = static_cast<compute_t>(gamma); compute_t g = static_cast<compute_t>(gamma);
if (zero_centered_gamma) { if (zero_centered_gamma) {
g += static_cast<compute_t>(1.f); g += static_cast<compute_t>(1.f);
......
...@@ -45,7 +45,7 @@ bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2) { ...@@ -45,7 +45,7 @@ bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2) {
return true; return true;
} }
size_t typeToSize(DType type) { size_t typeToNumBits(DType type) {
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
{ {
return TypeInfo<T>::size; return TypeInfo<T>::size;
...@@ -62,7 +62,8 @@ const std::string &typeName(DType type) { ...@@ -62,7 +62,8 @@ const std::string &typeName(DType type) {
{DType::kBFloat16, "bfloat16"}, {DType::kBFloat16, "bfloat16"},
{DType::kFloat8E4M3, "float8e4m3"}, {DType::kFloat8E4M3, "float8e4m3"},
{DType::kFloat8E5M2, "float8e5m2"}, {DType::kFloat8E5M2, "float8e5m2"},
{DType::kFloat8E8M0, "float8e8m0"}}; {DType::kFloat8E8M0, "float8e8m0"},
{DType::kFloat4E2M1, "float4e2m1"}};
return name_map.at(type); return name_map.at(type);
} }
...@@ -109,9 +110,16 @@ size_t DIVUP(const size_t &x, const size_t &y){ ...@@ -109,9 +110,16 @@ size_t DIVUP(const size_t &x, const size_t &y){
struct scale_inv_meta { struct scale_inv_meta {
std::vector<size_t> shape; std::vector<size_t> shape;
DType type; DType type;
size_t type_size; size_t type_size_bits;
size_t bytes() const noexcept {
return (product(shape) * type_size_bits) / 8;
}
}; };
size_t bytes(const NVTEShape& shape, const DType type) {
return (product(shape) * typeToNumBits(type)) / 8;
}
NVTEShape convertShape(const std::vector<size_t>& s) { NVTEShape convertShape(const std::vector<size_t>& s) {
return nvte_make_shape(s.data(), s.size()); return nvte_make_shape(s.data(), s.size());
} }
...@@ -122,7 +130,7 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape, ...@@ -122,7 +130,7 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
scale_inv_meta ret; scale_inv_meta ret;
ret.shape = {1}; ret.shape = {1};
ret.type = DType::kFloat32; ret.type = DType::kFloat32;
ret.type_size = sizeof(float); ret.type_size_bits = typeToNumBits(DType::kFloat32);
return {ret, ret}; return {ret, ret};
} }
if (scaling_mode == NVTE_MXFP8_1D_SCALING) { if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
...@@ -152,8 +160,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape, ...@@ -152,8 +160,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
} }
ret_rowwise.type = DType::kFloat8E8M0; ret_rowwise.type = DType::kFloat8E8M0;
ret_colwise.type = DType::kFloat8E8M0; ret_colwise.type = DType::kFloat8E8M0;
ret_rowwise.type_size = sizeof(uint8_t); ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
ret_colwise.type_size = sizeof(uint8_t); ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
return {ret_rowwise, ret_colwise}; return {ret_rowwise, ret_colwise};
} }
...@@ -179,8 +187,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape, ...@@ -179,8 +187,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
} }
ret_rowwise.type = DType::kFloat32; ret_rowwise.type = DType::kFloat32;
ret_colwise.type = DType::kFloat32; ret_colwise.type = DType::kFloat32;
ret_rowwise.type_size = sizeof(float); ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32);
ret_colwise.type_size = sizeof(float); ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32);
return {ret_rowwise, ret_colwise}; return {ret_rowwise, ret_colwise};
} }
...@@ -205,8 +213,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape, ...@@ -205,8 +213,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
} }
ret_rowwise.type = DType::kFloat32; ret_rowwise.type = DType::kFloat32;
ret_colwise.type = DType::kFloat32; ret_colwise.type = DType::kFloat32;
ret_rowwise.type_size = sizeof(float); ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32);
ret_colwise.type_size = sizeof(float); ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32);
return {ret_rowwise, ret_colwise}; return {ret_rowwise, ret_colwise};
} }
...@@ -222,8 +230,7 @@ Tensor::Tensor(const std::string& name, ...@@ -222,8 +230,7 @@ Tensor::Tensor(const std::string& name,
gen_.seed(seed); gen_.seed(seed);
rowwise_ = rowwise; rowwise_ = rowwise;
columnwise_ = columnwise; columnwise_ = columnwise;
size_t s = typeToSize(type); size_t total_size = bytes(shape, type);
size_t total_size = product(shape) * s;
void *dptr_rowwise = nullptr; void *dptr_rowwise = nullptr;
void *dptr_columnwise = nullptr; void *dptr_columnwise = nullptr;
cpu_data_rowwise_ = nullptr; cpu_data_rowwise_ = nullptr;
...@@ -305,8 +312,8 @@ Tensor::Tensor(const std::string& name, ...@@ -305,8 +312,8 @@ Tensor::Tensor(const std::string& name,
} else { } else {
auto [rowwise_scale_meta, colwise_scale_meta] = auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(normalized_shape, tensor_.scaling_mode()); get_scales(normalized_shape, tensor_.scaling_mode());
auto rowwise_scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; auto rowwise_scale_size = rowwise_scale_meta.bytes();
auto columnwise_scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size; auto columnwise_scale_size = colwise_scale_meta.bytes();
auto scale_shape = rowwise_scale_meta.shape; auto scale_shape = rowwise_scale_meta.shape;
auto columnwise_scale_shape = colwise_scale_meta.shape; auto columnwise_scale_shape = colwise_scale_meta.shape;
if (rowwise) { if (rowwise) {
...@@ -331,7 +338,7 @@ Tensor::Tensor(const std::string& name, ...@@ -331,7 +338,7 @@ Tensor::Tensor(const std::string& name,
void Tensor::to_cpu() const { void Tensor::to_cpu() const {
const NVTEShape s = tensor_.shape(); const NVTEShape s = tensor_.shape();
const size_t size = product(s) * typeToSize(tensor_.dtype()); const size_t size = bytes(s, tensor_.dtype());
if (rowwise_) { if (rowwise_) {
cudaMemcpy(cpu_data_rowwise_.get(), cudaMemcpy(cpu_data_rowwise_.get(),
tensor_.get_rowwise_data().data_ptr, tensor_.get_rowwise_data().data_ptr,
...@@ -360,14 +367,14 @@ void Tensor::to_cpu() const { ...@@ -360,14 +367,14 @@ void Tensor::to_cpu() const {
auto [rowwise_scale_meta, colwise_scale_meta] = auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(s, tensor_.scaling_mode()); get_scales(s, tensor_.scaling_mode());
if (rowwise_) { if (rowwise_) {
auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; auto scale_size = rowwise_scale_meta.bytes();
cudaMemcpy(rowwise_scale_inv_cpu_data_.get(), cudaMemcpy(rowwise_scale_inv_cpu_data_.get(),
tensor_.get_rowwise_scale_inv().data_ptr, tensor_.get_rowwise_scale_inv().data_ptr,
scale_size, scale_size,
cudaMemcpyDeviceToHost); cudaMemcpyDeviceToHost);
} }
if (columnwise_) { if (columnwise_) {
auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size; auto scale_size = colwise_scale_meta.bytes();
cudaMemcpy(columnwise_scale_inv_cpu_data_.get(), cudaMemcpy(columnwise_scale_inv_cpu_data_.get(),
tensor_.get_columnwise_scale_inv().data_ptr, tensor_.get_columnwise_scale_inv().data_ptr,
scale_size, scale_size,
...@@ -378,34 +385,32 @@ void Tensor::to_cpu() const { ...@@ -378,34 +385,32 @@ void Tensor::to_cpu() const {
void Tensor::from_cpu() const { void Tensor::from_cpu() const {
const NVTEShape s = tensor_.shape(); const NVTEShape s = tensor_.shape();
const size_t size = product(s) * typeToSize(tensor_.dtype()); const size_t size = bytes(s, tensor_.dtype());
if (rowwise_) { if (rowwise_) {
cudaMemcpy(tensor_.get_rowwise_data().data_ptr, cudaMemcpy(tensor_.get_rowwise_data().data_ptr, cpu_data_rowwise_.get(), size,
cpu_data_rowwise_.get(), size, cudaMemcpyHostToDevice); cudaMemcpyHostToDevice);
} }
if (columnwise_) { if (columnwise_) {
cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size,
cpu_data_columnwise_.get(), size, cudaMemcpyHostToDevice); cudaMemcpyHostToDevice);
} }
if (isFp8Type(dtype())) { if (isFp8Type(dtype())) {
if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
if (tensor_.amax() != nullptr){ if (tensor_.amax() != nullptr){
cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpyHostToDevice);
} }
cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpyHostToDevice);
} }
auto [rowwise_scale_meta, colwise_scale_meta] = auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(s, tensor_.scaling_mode()); get_scales(s, tensor_.scaling_mode());
if (rowwise_) { if (rowwise_) {
auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; auto scale_size = rowwise_scale_meta.bytes();
cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr, cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr,
rowwise_scale_inv_cpu_data_.get(), scale_size, rowwise_scale_inv_cpu_data_.get(), scale_size,
cudaMemcpyHostToDevice); cudaMemcpyHostToDevice);
} }
if (columnwise_) { if (columnwise_) {
auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size; auto scale_size = colwise_scale_meta.bytes();
cudaMemcpy(tensor_.get_columnwise_scale_inv().data_ptr, cudaMemcpy(tensor_.get_columnwise_scale_inv().data_ptr,
columnwise_scale_inv_cpu_data_.get(), scale_size, columnwise_scale_inv_cpu_data_.get(), scale_size,
cudaMemcpyHostToDevice); cudaMemcpyHostToDevice);
...@@ -735,6 +740,19 @@ std::pair<double, double> getTolerances(const DType type) { ...@@ -735,6 +740,19 @@ std::pair<double, double> getTolerances(const DType type) {
template <typename T> template <typename T>
void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
// Check how many RNG calls are required to generate one uniform random value
int rng_calls_per_val = 0;
{
std::mt19937 gen1 = *gen, gen2 = *gen;
std::uniform_real_distribution<> dis(-2.0, 1.0);
const float _ = dis(gen1);
while (gen2 != gen1) {
auto _ = gen2();
++rng_calls_per_val;
}
}
// Generate uniform random values in parallel
#pragma omp parallel proc_bind(spread) #pragma omp parallel proc_bind(spread)
{ {
std::mt19937 gen_local = *gen; std::mt19937 gen_local = *gen;
...@@ -743,7 +761,7 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { ...@@ -743,7 +761,7 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
const int chunk_size = (size + threads_num - 1) / threads_num; const int chunk_size = (size + threads_num - 1) / threads_num;
const int idx_min = chunk_size * thread_ID; const int idx_min = chunk_size * thread_ID;
const int idx_max = std::min(chunk_size * (thread_ID + 1), static_cast<int>(size)); const int idx_max = std::min(chunk_size * (thread_ID + 1), static_cast<int>(size));
gen_local.discard(idx_min); gen_local.discard(idx_min * rng_calls_per_val);
std::uniform_real_distribution<> dis(-2.0, 1.0); std::uniform_real_distribution<> dis(-2.0, 1.0);
for (int i = idx_min; i < idx_max; ++i) { for (int i = idx_min; i < idx_max; ++i) {
...@@ -754,7 +772,7 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { ...@@ -754,7 +772,7 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
#endif #endif
} }
} }
gen->discard(size); gen->discard(size * rng_calls_per_val);
} }
void fillUniform(Tensor *t) { void fillUniform(Tensor *t) {
......
...@@ -10,11 +10,18 @@ ...@@ -10,11 +10,18 @@
#include <vector> #include <vector>
#include <array> #include <array>
#include <random> #include <random>
#include <cudaTypedefs.h>
#define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <cuda_fp8.h> #include <cuda_fp8.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_fp8.h>
#if FP4_TYPE_SUPPORTED
#include <cuda_fp4.h>
#endif
#include <cuda_runtime_api.h>
#include <transformer_engine/transformer_engine.h> #include <transformer_engine/transformer_engine.h>
#include "util/logging.h" #include "util/logging.h"
...@@ -56,20 +63,32 @@ using fp8e4m3 = __nv_fp8_e4m3; ...@@ -56,20 +63,32 @@ using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2; using fp8e5m2 = __nv_fp8_e5m2;
using fp8e8m0 = uint8_t; using fp8e8m0 = uint8_t;
using int8 = int8_t; using int8 = int8_t;
#if FP4_TYPE_SUPPORTED
using fp4e2m1 = __nv_fp4_e2m1;
#endif
template <typename T>
struct BitsNumber;
#if FP4_TYPE_SUPPORTED
template <>
struct BitsNumber<fp4e2m1> {
static constexpr size_t num_bits = 4;
};
#endif
template <typename T>
struct BitsNumber {
static constexpr size_t num_bits = 8 * sizeof(T);
};
template <typename T> template <typename T>
struct TypeInfo{ struct TypeInfo {
using types = std::tuple<byte, #if FP4_TYPE_SUPPORTED
int16, using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, fp4e2m1>;
int32, #else
int64, using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, int8>;
fp32, #endif
fp16,
bf16,
fp8e4m3,
fp8e5m2,
fp8e8m0,
int8>;
template <typename U, DType current> template <typename U, DType current>
struct Helper { struct Helper {
...@@ -96,7 +115,7 @@ struct TypeInfo{ ...@@ -96,7 +115,7 @@ struct TypeInfo{
} }
constexpr static DType dtype = getType<T>(); constexpr static DType dtype = getType<T>();
constexpr static size_t size = sizeof(T); constexpr static size_t size = BitsNumber<T>::num_bits;;
}; };
class Tensor { class Tensor {
...@@ -418,9 +437,10 @@ inline float dsilu(const float x) { return x * dsigmoid(x) + sigmoid(x); } ...@@ -418,9 +437,10 @@ inline float dsilu(const float x) { return x * dsigmoid(x) + sigmoid(x); }
inline float srelu(const float x) { return x > 0 ? x * x : 0; } inline float srelu(const float x) { return x > 0 ? x * x : 0; }
inline float dsrelu(const float x) { return fmaxf(0, 2 * x); } inline float dsrelu(const float x) { return fmaxf(0, 2 * x); }
size_t typeToSize(DType type); size_t typeToNumBits(DType type);
size_t product(const NVTEShape &shape); size_t product(const NVTEShape &shape);
size_t product(const std::vector<size_t> &shape); size_t product(const std::vector<size_t> &shape);
size_t bytes(const NVTEShape& shape, const DType type);
size_t first_dimension(const std::vector<size_t> &shape); size_t first_dimension(const std::vector<size_t> &shape);
size_t last_dimension(const std::vector<size_t> &shape); size_t last_dimension(const std::vector<size_t> &shape);
...@@ -466,6 +486,16 @@ constexpr int32_t blackwellComputeCapability = 100; ...@@ -466,6 +486,16 @@ constexpr int32_t blackwellComputeCapability = 100;
} // namespace test } // namespace test
#if FP4_TYPE_SUPPORTED
#define SWITCH_FP4_TYPE_HANDLE(type, ...) \
case DType::kFloat4E2M1: { \
using type = fp4e2m1; \
{ __VA_ARGS__ } \
} break;
#else
#define SWITCH_FP4_TYPE_HANDLE(type, ...) // do nothing
#endif
#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \ #define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \
switch (dtype) { \ switch (dtype) { \
using namespace transformer_engine; \ using namespace transformer_engine; \
...@@ -517,8 +547,16 @@ constexpr int32_t blackwellComputeCapability = 100; ...@@ -517,8 +547,16 @@ constexpr int32_t blackwellComputeCapability = 100;
{__VA_ARGS__} \ {__VA_ARGS__} \
} \ } \
break; \ break; \
case DType::kFloat8E8M0: \
{ \
using type = fp8e8m0; \
{__VA_ARGS__} \
} \
break; \
SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \
default: \ default: \
NVTE_ERROR("Invalid type."); \ printf("dtype: %d\n", static_cast<int>(dtype)); \
NVTE_ERROR("Invalid type MARKED TEST."); \
} }
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \ #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \
...@@ -537,7 +575,15 @@ constexpr int32_t blackwellComputeCapability = 100; ...@@ -537,7 +575,15 @@ constexpr int32_t blackwellComputeCapability = 100;
} \ } \
break; \ break; \
default: \ default: \
NVTE_ERROR("Invalid type."); \ NVTE_ERROR("Invalid type MARKED TEST 2."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4_ONLY(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
SWITCH_FP4_HANDLE(type, __VA_ARGS__) \
default: \
NVTE_ERROR("Invalid type MARKED TEST 3."); \
} }
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \ #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \
...@@ -562,5 +608,5 @@ constexpr int32_t blackwellComputeCapability = 100; ...@@ -562,5 +608,5 @@ constexpr int32_t blackwellComputeCapability = 100;
} \ } \
break; \ break; \
default: \ default: \
NVTE_ERROR("Invalid type."); \ NVTE_ERROR("Invalid type MARKED TEST 4."); \
} }
...@@ -4,15 +4,14 @@ ...@@ -4,15 +4,14 @@
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np
import pytest import pytest
from jax import jit, value_and_grad from jax import jit, value_and_grad
from functools import reduce from functools import reduce
from typing import Union
import operator import operator
from utils import ( from utils import (
assert_allclose, assert_allclose,
assert_tree_like_allclose,
pytest_parametrize_wrapper, pytest_parametrize_wrapper,
) )
from transformer_engine.jax.layernorm import layernorm from transformer_engine.jax.layernorm import layernorm
...@@ -33,15 +32,18 @@ from transformer_engine.jax import cpp_extensions as tex ...@@ -33,15 +32,18 @@ from transformer_engine.jax import cpp_extensions as tex
from transformer_engine.jax.quantize import ( from transformer_engine.jax.quantize import (
DelayedScaleQuantizer, DelayedScaleQuantizer,
ScaledTensor, ScaledTensor,
ScaledTensor1x,
ScaledTensor2x,
GroupedScaledTensor1x,
ScalingMode, ScalingMode,
QuantizerFactory, QuantizerFactory,
QuantizeLayout, QuantizeLayout,
noop_quantizer_set,
) )
from transformer_engine.jax.quantize import helper from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation from transformer_engine.jax.activation import activation
from transformer_engine.jax.dense import dense from transformer_engine.jax.dense import dense, grouped_dense
from transformer_engine.jax.layernorm_dense import layernorm_dense from transformer_engine.jax.layernorm_dense import layernorm_dense
from transformer_engine.jax.quantize import ScaledTensor1x, ScaledTensor2x
GEMM_CASES = [ GEMM_CASES = [
(256, 256, 512), (256, 256, 512),
...@@ -53,8 +55,8 @@ GEMM_CASES = [ ...@@ -53,8 +55,8 @@ GEMM_CASES = [
FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2] FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
LN_CASES = [(256, 128), (128, 256)] LN_CASES = [(256, 128), (128, 256)]
DTYPES = [jnp.bfloat16, jnp.float32] DTYPES = [jnp.bfloat16, jnp.float32]
is_fp8_supported, reason = helper.is_fp8_available() is_fp8_supported, fp8_unsupported_reason = helper.is_fp8_available()
is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING) is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
supported_scaling_modes = [] supported_scaling_modes = []
""" Find supported scaling modes""" """ Find supported scaling modes"""
...@@ -113,6 +115,38 @@ def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray): ...@@ -113,6 +115,38 @@ def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray):
pytest.fail("a must be a ScaledTensor object") pytest.fail("a must be a ScaledTensor object")
def assert_dequantized_grouped_scaled_tensor(
a: Union[GroupedScaledTensor1x, ScaledTensor2x], b: jnp.ndarray
):
if isinstance(a, GroupedScaledTensor1x):
assert a.group_sizes.sum() == b.shape[0]
b = jnp.split(b, jnp.cumulative_sum(a.group_sizes)[:-1], axis=0)
dq_a = a.dequantize()
for dq_a_i, b_i in zip(dq_a, b):
if len(dq_a_i) == 0:
continue
if a.data_layout == "T":
data_ndim = len(a.original_shape)
flatten_axis = a.flatten_axis
if b_i.shape[0] == 1:
b_i = jnp.transpose(
b_i, (0, *range(flatten_axis, data_ndim), *range(1, flatten_axis))
)
else:
b_i = jnp.transpose(
b_i, (*range(flatten_axis, data_ndim), *range(flatten_axis))
)
dq_a_i = dq_a_i.reshape(b_i.shape)
assert_allclose(dq_a_i, b_i, dtype=a.data.dtype)
elif isinstance(a, ScaledTensor2x):
assert isinstance(a.get_rowwise_tensor(), GroupedScaledTensor1x)
assert isinstance(a.get_colwise_tensor(), GroupedScaledTensor1x)
assert_dequantized_grouped_scaled_tensor(a.get_rowwise_tensor(), b)
assert_dequantized_grouped_scaled_tensor(a.get_colwise_tensor(), b)
else:
pytest.fail("a must be a GroupedScaledTensor object")
ALL_ACTIVATION_SHAPES = [(32, 64), (16, 128, 256)] ALL_ACTIVATION_SHAPES = [(32, 64), (16, 128, 256)]
ALL_ACTIVATION_TYPES = [ ALL_ACTIVATION_TYPES = [
("gelu",), ("gelu",),
...@@ -173,7 +207,7 @@ class TestActivation: ...@@ -173,7 +207,7 @@ class TestActivation:
assert_allclose(prim_out, ref_out, dtype=x.dtype) assert_allclose(prim_out, ref_out, dtype=x.dtype)
assert_allclose(prim_grad, ref_grad, dtype=x.dtype) assert_allclose(prim_grad, ref_grad, dtype=x.dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES) @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
...@@ -204,7 +238,7 @@ class TestActivation: ...@@ -204,7 +238,7 @@ class TestActivation:
assert_allclose(prim_out, ref_out, dtype=output_type) assert_allclose(prim_out, ref_out, dtype=output_type)
assert_allclose(prim_grad, ref_grad, dtype=output_type) assert_allclose(prim_grad, ref_grad, dtype=output_type)
@pytest.mark.skipif(not is_mxfp8_supported, reason=reason) @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
@pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES) @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
...@@ -234,7 +268,7 @@ class TestActivation: ...@@ -234,7 +268,7 @@ class TestActivation:
assert_bitwise_scaled_tensors(te_output, jax_output) assert_bitwise_scaled_tensors(te_output, jax_output)
@pytest.mark.skipif(not is_mxfp8_supported, reason=reason) @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
@pytest_parametrize_wrapper("shape", [(2, 64, 1, 256)]) @pytest_parametrize_wrapper("shape", [(2, 64, 1, 256)])
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
...@@ -355,7 +389,7 @@ class TestNorm: ...@@ -355,7 +389,7 @@ class TestNorm:
n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer=None n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer=None
) )
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
# No Norm FWD E5M2 in TE backend # No Norm FWD E5M2 in TE backend
@pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn]) @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
@pytest_parametrize_wrapper( @pytest_parametrize_wrapper(
...@@ -470,7 +504,7 @@ class TestNorm: ...@@ -470,7 +504,7 @@ class TestNorm:
if norm_type == "layernorm": if norm_type == "layernorm":
assert_allclose(mu, ref_mu, dtype=inp_dtype) assert_allclose(mu, ref_mu, dtype=inp_dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
# No Norm FWD E5M2 in TE backend # No Norm FWD E5M2 in TE backend
@pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn]) @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
@pytest_parametrize_wrapper( @pytest_parametrize_wrapper(
...@@ -506,7 +540,7 @@ class TestNorm: ...@@ -506,7 +540,7 @@ class TestNorm:
q_layout=q_layout, q_layout=q_layout,
) )
@pytest.mark.skipif(not is_mxfp8_supported, reason=reason) @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
@pytest.mark.parametrize("out_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest.mark.parametrize("out_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
def test_norm_forward_with_block_scaling_fp8( def test_norm_forward_with_block_scaling_fp8(
self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype
...@@ -532,7 +566,7 @@ QUANTIZE_OUTPUT_DTYPES = { ...@@ -532,7 +566,7 @@ QUANTIZE_OUTPUT_DTYPES = {
ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [ ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [
((32, 64), -1), ((32, 64), -1),
((2, 64, 32), -1), ((2, 64, 32), -1),
((2, 64, 32), -2), ((64, 2, 32), -2),
((32, 256, 128), -1), ((32, 256, 128), -1),
((32, 256, 128), -2), ((32, 256, 128), -2),
((64, 32, 32, 256), -1), ((64, 32, 32, 256), -1),
...@@ -544,7 +578,7 @@ QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = { ...@@ -544,7 +578,7 @@ QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = {
"L0": [ "L0": [
((32, 64), -1), ((32, 64), -1),
((2, 64, 32), -1), ((2, 64, 32), -1),
((2, 64, 32), -2), ((64, 2, 32), -2),
], ],
"L2": ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES, "L2": ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES,
} }
...@@ -555,7 +589,7 @@ QUANTIZATION_INPUT_DTYPE = { ...@@ -555,7 +589,7 @@ QUANTIZATION_INPUT_DTYPE = {
} }
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES) @pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
...@@ -577,9 +611,6 @@ class TestQuantize: ...@@ -577,9 +611,6 @@ class TestQuantize:
q_dtype=q_dtype, q_dtype=q_dtype,
q_layout=q_layout, q_layout=q_layout,
) )
# Adding dimension to test if padding is done correctly when flatten 3D to 2D
if flatten_axis == -2:
input_shape = input_shape[:-1] + (2,) + input_shape[-1:]
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations): for _ in range(n_iterations):
...@@ -593,8 +624,6 @@ class TestQuantize: ...@@ -593,8 +624,6 @@ class TestQuantize:
): ):
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
if flatten_axis == -2:
input_shape = input_shape[:-1] + (2,) + input_shape[-1:]
input = jax.random.uniform(key, input_shape, in_dtype) input = jax.random.uniform(key, input_shape, in_dtype)
te_quantizer, jax_quantizer = QuantizerFactory.create( te_quantizer, jax_quantizer = QuantizerFactory.create(
...@@ -607,10 +636,65 @@ class TestQuantize: ...@@ -607,10 +636,65 @@ class TestQuantize:
assert_bitwise_scaled_tensors(te_output, jax_output) assert_bitwise_scaled_tensors(te_output, jax_output)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("input_shape", [(8, 16, 32)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("flatten_axis", [-1])
@pytest_parametrize_wrapper("with_group_sizes", [True, False])
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE]
)
class TestGroupedQuantize:
def test_grouped_qdq(
self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis, with_group_sizes
):
n_groups, m, n = input_shape
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
# *32 so that the input shapes works for MXFP8
input_shape = (m * 32, n)
if with_group_sizes:
group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
group_sizes = jnp.diff(group_sizes)
assert group_sizes.sum() == m
assert jnp.any(group_sizes == 0) # make sure that at least one group has 0 row
group_sizes = group_sizes * 32
else:
group_sizes = None
input_shape = (n_groups, input_shape[0] // n_groups, input_shape[1])
if flatten_axis == -2:
input_shape = input_shape[:-1] + (2,) + input_shape[-1:]
x = jax.random.uniform(subkeys[1], input_shape, in_dtype)
grouped_quantizer = QuantizerFactory.create(
scaling_mode=scaling_mode,
q_dtype=q_dtype,
q_layout=q_layout,
n_groups=n_groups,
)
# grouped_quantize does not work with cudaGraph yet, so the jitting will breaks
# To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to
# disable cudaGraph, then use the following jitted function
scaled_tensor = tex.grouped_quantize(
x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer
)
assert_dequantized_grouped_scaled_tensor(scaled_tensor, x)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
class TestFusedQuantize: class TestFusedQuantize:
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES) @pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
...@@ -625,12 +709,6 @@ class TestFusedQuantize: ...@@ -625,12 +709,6 @@ class TestFusedQuantize:
): ):
pytest.skip(f"Input shape {input_shape} is not supported by MXFP8") pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")
if (flatten_axis < 0 and flatten_axis + len(input_shape) <= 0) or flatten_axis <= 0:
pytest.skip(
f"Flatten axis {flatten_axis} is not supported for input shape {input_shape}. There"
" must be at least one axis on either side of the flatten_axis split."
)
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
input = jax.random.uniform(key, input_shape, in_dtype) input = jax.random.uniform(key, input_shape, in_dtype)
...@@ -717,7 +795,7 @@ class TestFusedQuantize: ...@@ -717,7 +795,7 @@ class TestFusedQuantize:
q_layout=QuantizeLayout.ROWWISE, q_layout=QuantizeLayout.ROWWISE,
) )
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES) @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
...@@ -741,7 +819,7 @@ class TestFusedQuantize: ...@@ -741,7 +819,7 @@ class TestFusedQuantize:
q_layout=q_layout, q_layout=q_layout,
) )
@pytest.mark.skipif(not is_mxfp8_supported, reason=reason) @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper( @pytest_parametrize_wrapper(
"input_shape", [s for s in ALL_ACTIVATION_SHAPES if is_shape_supported_by_mxfp8(s)] "input_shape", [s for s in ALL_ACTIVATION_SHAPES if is_shape_supported_by_mxfp8(s)]
...@@ -810,7 +888,7 @@ class TestDense: ...@@ -810,7 +888,7 @@ class TestDense:
assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
...@@ -852,7 +930,7 @@ class TestDense: ...@@ -852,7 +930,7 @@ class TestDense:
assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16) assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16)
assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16) assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
...@@ -916,7 +994,7 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan ...@@ -916,7 +994,7 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan
class TestFusedDense: class TestFusedDense:
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
@pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
...@@ -1001,7 +1079,7 @@ class TestFusedDense: ...@@ -1001,7 +1079,7 @@ class TestFusedDense:
if beta is not None: if beta is not None:
assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype) assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
@pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")]) @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
@pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
...@@ -1129,24 +1207,6 @@ class TestFusedDense: ...@@ -1129,24 +1207,6 @@ class TestFusedDense:
assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype) assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)
# This function is modified from transformer_engine/jax/cpp_extensions/gemm.py::_jax_gemm()
def _quantize_gemm_pair(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer):
((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims
lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1
rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1
lhs_q = lhs_quantizer.quantize(
lhs,
is_rowwise=lhs_is_rowwise,
is_colwise=not lhs_is_rowwise,
)
rhs_q = rhs_quantizer.quantize(
rhs,
is_rowwise=rhs_is_rowwise,
is_colwise=not rhs_is_rowwise,
)
return lhs_q, rhs_q
# E5M2 * E5M2 is not supported # E5M2 * E5M2 is not supported
fwd_bwd_dtypes = [ fwd_bwd_dtypes = [
[jnp.float8_e4m3fn, jnp.float8_e4m3fn], [jnp.float8_e4m3fn, jnp.float8_e4m3fn],
...@@ -1154,219 +1214,217 @@ fwd_bwd_dtypes = [ ...@@ -1154,219 +1214,217 @@ fwd_bwd_dtypes = [
[jnp.float8_e5m2, jnp.float8_e4m3fn], [jnp.float8_e5m2, jnp.float8_e4m3fn],
] ]
""" GROUPED_DENSE_INPUT_SHAPES = [
@pytest_parametrize_wrapper( # (n_groups, m, n, k), the actual m will be multiplied by 32
"shape_list", [[(512, 128, 256), (256, 128, 256), (256, 128, 128), (512, 256, 128)]] (5, 32, 128, 64), # Test the case where n_groups is not a multiple of 4
) (8, 64, 32, 128),
(8, 64, 128, 256),
]
@pytest_parametrize_wrapper("input_shape", GROUPED_DENSE_INPUT_SHAPES)
class TestGroupedDense: class TestGroupedDense:
def _ref_grouped_gemm_with_jnp_dot(self, lhs_list, rhs_list, contracting_dims_list): def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims):
ref_out_list = [] lhs_contract_dim, _ = contracting_dims
for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list): assert len(lhs_contract_dim) == 1 and lhs.ndim == 2 and rhs.ndim == 3
dim_nums = (contracting_dims, ((), ())) if bias is None:
ref_out_list.append(jax.lax.dot_general(lhs, rhs, dim_nums)) bias = jnp.zeros((rhs.shape[0], rhs.shape[2]), dtype=lhs.dtype)
return ref_out_list else:
assert bias.ndim == 2 and bias.shape == (rhs.shape[0], rhs.shape[2])
def _generate_grouped_gemm_input(self, dtype, shape_list, layout_list): remaining_axis = (set(range(lhs.ndim)) - set(lhs_contract_dim)).pop()
lhs = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=remaining_axis)
rhs = jnp.split(rhs, rhs.shape[0], axis=0)
bias = jnp.split(bias, bias.shape[0], axis=0)
ref_out = []
dim_num = (contracting_dims, ((), ()))
for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias):
out_i = jax.lax.dot_general(lhs_i, rhs_i, dim_num) + jnp.expand_dims(bias_i, axis=0)
ref_out.append(jnp.squeeze(out_i))
return ref_out
def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False):
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, len(shape_list) * 2) subkeys = jax.random.split(key, 4)
n_groups, m, n, k = input_shape
lhs_list, rhs_list, contracting_dims_list = [], [], []
for i, ((m, n, k), data_layout) in enumerate(zip(shape_list, layout_list)): group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
lhs = jax.random.uniform( group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
subkeys[2 * i], group_sizes = jnp.diff(group_sizes)
(m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m), assert group_sizes.sum() == m
dtype=dtype,
) # *32 to make sure that input shape works for MXFP8
rhs = jax.random.uniform( group_sizes = group_sizes * 32
subkeys[2 * i + 1], m = m * 32
(k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
dtype=dtype, lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m)
) rhs_shape = (n_groups, k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k)
lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,) bias_shape = (n_groups, n)
rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
lhs_list.append(lhs) lhs = jax.random.uniform(subkeys[1], lhs_shape, dtype=dtype)
rhs_list.append(rhs) rhs = jax.random.uniform(subkeys[2], rhs_shape, dtype=dtype)
contracting_dims_list.append(contracting_dims) bias = jax.random.uniform(subkeys[3], bias_shape, dtype=dtype) if with_bias else None
lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
rhs_contracting_dim = (1,) if data_layout[1] == "N" else (2,)
contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
return lhs_list, rhs_list, contracting_dims_list return lhs, rhs, group_sizes, contracting_dims, bias
def _assert_grouped_gemm_output(self, out, group_sizes, ref_list, dtype):
assert out.dtype == ref_list[0].dtype
out_list = jnp.split(out, jnp.cumulative_sum(group_sizes)[:-1], axis=0)
for i in range(len(ref_list)):
assert_allclose(out_list[i], ref_list[i], dtype=dtype)
@pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16]) @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
@pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]]) @pytest_parametrize_wrapper("layout", ["NN"])
def test_grouped_gemm_fp16(self, dtype, shape_list, layout_list): def test_grouped_gemm_fp16(self, dtype, input_shape, layout):
lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input( lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
dtype, shape_list, layout_list dtype, input_shape, layout
) )
ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
primitive_out = tex.grouped_gemm(lhs_list, rhs_list, contracting_dims_list)
for i in range(len(shape_list)): # grouped_gemm does not work with cudaGraph yet, so the jitting will breaks
assert_allclose(primitive_out[i], ref_out[i], dtype=dtype) # To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to
# disable cudaGraph, then use the following jitted function
# jitting grouped_gemm
# prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
# lhs, rhs, group_sizes, contracting_dims,
# )
@pytest.mark.skipif(not is_fp8_supported, reason=reason) prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims)
self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes) @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]]) @pytest_parametrize_wrapper("layout", ["NN"])
def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list, layout_list): def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout):
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
pytest.skip("MXFP8 is not supported in grouped_gemm yet")
fwd_dtype, bwd_dtype = fwd_bwd_dtype fwd_dtype, bwd_dtype = fwd_bwd_dtype
quantizer_set = QuantizerFactory.create_set( quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=False scaling_mode=scaling_mode,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
is_2x2x=False,
n_groups=input_shape[0],
) )
# quantizer_set.{x, kernel} has fwd_dtype, while quantizer_set.grad has bwd_dtype
# We want to test E4M3 * E5M2, manually set the quantizer_set.kernel.q_dtype to bwd_dtype
quantizer_set.kernel.q_dtype = bwd_dtype
for quantizer in quantizer_set.kernel.quantizers:
quantizer.q_dtype = bwd_dtype
out_dtype = jnp.bfloat16 out_dtype = jnp.bfloat16
lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input( lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
out_dtype, shape_list, layout_list out_dtype, input_shape, layout
) )
q_lhs_list = [] ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
q_rhs_list = []
for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list): # jitting grouped_gemm
# quantizer_set.x and quantizer_set.kernel have the same q_dtype, we want to # prim_out = jax.jit(tex.grouped_gemm, static_argnames=('contracting_dims',))(
# test the case where lhs and rhs have different q_dtypes # lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
q_lhs, q_rhs = _quantize_gemm_pair( # )
lhs, rhs, contracting_dims, quantizer_set.x, quantizer_set.dgrad
)
q_lhs_list.append(q_lhs)
q_rhs_list.append(q_rhs)
ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list) prim_out = tex.grouped_gemm(
primitive_out = tex.grouped_gemm(q_lhs_list, q_rhs_list, contracting_dims_list) lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
)
allclose_dtype = jnp.float8_e4m3fn allclose_dtype = jnp.float8_e4m3fn
if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2: if jnp.float8_e5m2 in fwd_bwd_dtype:
allclose_dtype = jnp.float8_e5m2 allclose_dtype = jnp.float8_e5m2
for i in range(len(shape_list)):
assert_allclose(primitive_out[i], ref_out[i], dtype=allclose_dtype)
@pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16]) self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, allclose_dtype)
def test_grouped_dense_grad_fp16(self, dtype, shape_list):
group_size = len(shape_list) def _ref_sum_grouped_dense(self, x, kernel, bias, group_sizes, contracting_dims):
layout_list = ["NN" for _ in range(group_size)] out_list = self._ref_grouped_dense(x, kernel, bias, group_sizes, contracting_dims)
# Note: we use jnp.sum instead of jnp.mean to make the gradient larger
# and prevent them from being clamp to zero
out_sum_list = [jnp.sum(out) for out in out_list]
return jnp.sum(jnp.asarray(out_sum_list))
x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input( def _primitive_sum_grouped_dense(
dtype, shape_list, layout_list self, x, kernel, bias, group_sizes, contracting_dims, quantizer_set=noop_quantizer_set
):
out = grouped_dense(
x, kernel, group_sizes, contracting_dims, bias=bias, quantizer_set=quantizer_set
) )
bias_list = [] return jnp.sum(jnp.asarray(out))
key = jax.random.PRNGKey(1)
for shape in shape_list:
n = shape[1]
bias = jax.random.uniform(key, n, dtype=dtype)
bias_list.append(bias)
def ref_func(x_list, kernel_list, bias_list, contracting_dims_list):
out_list = []
for i in range(len(x_list)):
out_list.append(
dense(
x_list[i],
kernel_list[i],
bias_list[i],
contracting_dims=contracting_dims_list[i],
)
)
# Note: we use jnp.sum instead of jnp.mean to make the gradient larger
# and prevent them from being clamp to zero
out_sum_list = [jnp.sum(out) for out in out_list]
return jnp.sum(jnp.asarray(out_sum_list))
def primitive_func(x_list, kernel_list, bias_list, contracting_dims_list): @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
out_list = grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list) def test_grouped_dense_grad_fp16(self, dtype, input_shape):
out_sum_list = [jnp.sum(out) for out in out_list] x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
return jnp.sum(jnp.asarray(out_sum_list)) dtype,
input_shape,
with_bias=True,
)
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2)) # jitting the grouped_dense
# value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)),
# static_argnums=(4,))
value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))
ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func( ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
x_list, kernel_list, bias_list, contracting_dims_list x, kernel, bias, group_sizes, contracting_dims
) )
primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = ( prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
value_n_grad_primitive_func(x_list, kernel_list, bias_list, contracting_dims_list) x, kernel, bias, group_sizes, contracting_dims
) )
assert_allclose(primitive_out_mean, ref_out_mean, dtype=dtype) assert_allclose(prim_out_sum, ref_out_sum, dtype=dtype)
for i in range(group_size): assert_allclose(prim_dgrad, ref_dgrad, dtype=dtype)
assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=dtype) assert_allclose(prim_wgrad, ref_wgrad, dtype=dtype)
assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=dtype) assert_allclose(prim_dbias, ref_dbias, dtype=dtype)
assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes) @pytest.mark.parametrize(
"fwd_bwd_dtype",
[(jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)],
)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list): def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
group_size = len(shape_list) if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
layout_list = ["NN" for _ in range(group_size)] pytest.skip("MXFP8 is not supported in grouped_dense yet")
fwd_dtype, bwd_dtype = fwd_bwd_dtype
if fwd_dtype == jnp.float8_e5m2:
pytest.skip("We never use E5M2 for fwd_dtype in training")
# Question: should we use different quantizers for different groups?
ref_quantizer_set_list = []
quantizer_set_list = []
for _ in range(group_size):
ref_quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True
)
ref_quantizer_set_list.append(ref_quantizer_set)
quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True
)
quantizer_set_list.append(quantizer_set)
out_dtype = jnp.bfloat16 fwd_dtype, bwd_dtype = fwd_bwd_dtype
x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input( dtype = jnp.bfloat16
out_dtype, shape_list, layout_list x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
dtype,
input_shape,
with_bias=True,
) )
bias_list = []
key = jax.random.PRNGKey(1)
for shape in shape_list:
n = shape[1]
bias = jax.random.uniform(key, n, dtype=out_dtype)
bias_list.append(bias)
def ref_func(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list):
out_list = []
for i in range(len(x_list)):
out_list.append(
dense(
x_list[i],
kernel_list[i],
bias_list[i],
contracting_dims=contracting_dims_list[i],
quantizer_set=quantizer_set_list[i],
)
)
# Note: we use jnp.sum instead of jnp.mean to make the gradient larger
# and prevent them from being clamp to zero
out_sum_list = [jnp.sum(out) for out in out_list]
return jnp.sum(jnp.asarray(out_sum_list))
def primitive_func( quantizer_set = QuantizerFactory.create_set(
x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list scaling_mode=scaling_mode,
): fwd_dtype=fwd_dtype,
out_list = grouped_dense( bwd_dtype=bwd_dtype,
x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list is_2x2x=True,
) n_groups=group_sizes.size,
out_sum_list = [jnp.sum(out) for out in out_list]
return jnp.sum(jnp.asarray(out_sum_list))
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func(
x_list, kernel_list, bias_list, contracting_dims_list, ref_quantizer_set_list
) )
primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = ( value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
value_n_grad_primitive_func(
x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list # jitting the grouped_dense
) # value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)),
# static_argnums=(4,))
value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))
ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
x,
kernel,
bias,
group_sizes,
contracting_dims,
)
prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
x, kernel, bias, group_sizes, contracting_dims, quantizer_set=quantizer_set
) )
allclose_dtype = jnp.float8_e4m3fn assert_allclose(prim_out_sum, ref_out_sum, dtype=fwd_dtype)
if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2: assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype)
allclose_dtype = jnp.float8_e5m2 assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype)
assert_allclose(primitive_out_mean, ref_out_mean, dtype=allclose_dtype) assert_allclose(prim_dbias, ref_dbias, dtype=dtype)
for i in range(group_size):
assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=allclose_dtype)
assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=allclose_dtype)
assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=allclose_dtype)
"""
...@@ -68,6 +68,7 @@ class TestDistributedSelfAttn: ...@@ -68,6 +68,7 @@ class TestDistributedSelfAttn:
batch, seqlen, num_head, hidden = data_shape batch, seqlen, num_head, hidden = data_shape
if not is_fused_attn_kernel_available( if not is_fused_attn_kernel_available(
is_training,
dtype, dtype,
dtype, dtype,
QKVLayout.BS3HD, QKVLayout.BS3HD,
...@@ -79,6 +80,7 @@ class TestDistributedSelfAttn: ...@@ -79,6 +80,7 @@ class TestDistributedSelfAttn:
seqlen, seqlen,
seqlen, seqlen,
hidden, hidden,
hidden,
None, # no window None, # no window
): ):
pytest.skip("No FusedAttn backend found") pytest.skip("No FusedAttn backend found")
...@@ -98,6 +100,7 @@ class TestDistributedSelfAttn: ...@@ -98,6 +100,7 @@ class TestDistributedSelfAttn:
num_head, num_head,
num_head, num_head,
hidden, hidden,
hidden,
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
dropout_prob, dropout_prob,
...@@ -214,6 +217,7 @@ class TestDistributedCrossAttn: ...@@ -214,6 +217,7 @@ class TestDistributedCrossAttn:
batch, seqlen, num_head, hidden = data_shape batch, seqlen, num_head, hidden = data_shape
if not is_fused_attn_kernel_available( if not is_fused_attn_kernel_available(
is_training,
dtype, dtype,
dtype, dtype,
QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BS2HD,
...@@ -225,6 +229,7 @@ class TestDistributedCrossAttn: ...@@ -225,6 +229,7 @@ class TestDistributedCrossAttn:
seqlen, seqlen,
seqlen, seqlen,
hidden, hidden,
hidden,
None, # no window None, # no window
): ):
pytest.skip("No FusedAttn backend found") pytest.skip("No FusedAttn backend found")
...@@ -237,6 +242,7 @@ class TestDistributedCrossAttn: ...@@ -237,6 +242,7 @@ class TestDistributedCrossAttn:
num_head, num_head,
num_head, num_head,
hidden, hidden,
hidden,
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
dropout_prob, dropout_prob,
...@@ -289,6 +295,7 @@ class TestDistributedContextParallelSelfAttn: ...@@ -289,6 +295,7 @@ class TestDistributedContextParallelSelfAttn:
cp_strategy, cp_strategy,
use_shardy, use_shardy,
use_scan_ring=False, use_scan_ring=False,
window_size=None,
): ):
if qkv_layout.is_thd(): if qkv_layout.is_thd():
if cp_strategy == CPStrategy.ALL_GATHER: if cp_strategy == CPStrategy.ALL_GATHER:
...@@ -326,6 +333,7 @@ class TestDistributedContextParallelSelfAttn: ...@@ -326,6 +333,7 @@ class TestDistributedContextParallelSelfAttn:
num_head, num_head,
num_kv_heads, num_kv_heads,
hidden, hidden,
hidden,
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
dropout_prob, dropout_prob,
...@@ -333,7 +341,7 @@ class TestDistributedContextParallelSelfAttn: ...@@ -333,7 +341,7 @@ class TestDistributedContextParallelSelfAttn:
is_training, is_training,
qkv_layout, qkv_layout,
bias_shape, bias_shape,
None, window_size,
SeqDescFormat.SegmentIDs, SeqDescFormat.SegmentIDs,
number_of_devices=device_count, number_of_devices=device_count,
mesh_shape=mesh_shape, mesh_shape=mesh_shape,
...@@ -345,6 +353,7 @@ class TestDistributedContextParallelSelfAttn: ...@@ -345,6 +353,7 @@ class TestDistributedContextParallelSelfAttn:
def check_has_backend_for_mask(mask_type): def check_has_backend_for_mask(mask_type):
return is_fused_attn_kernel_available( return is_fused_attn_kernel_available(
is_training,
dtype, dtype,
dtype, dtype,
qkv_layout, qkv_layout,
...@@ -356,6 +365,7 @@ class TestDistributedContextParallelSelfAttn: ...@@ -356,6 +365,7 @@ class TestDistributedContextParallelSelfAttn:
seqlen, seqlen,
seqlen, seqlen,
hidden, hidden,
hidden,
None, None,
) # no SWA for CP ) # no SWA for CP
...@@ -476,6 +486,13 @@ class TestDistributedContextParallelSelfAttn: ...@@ -476,6 +486,13 @@ class TestDistributedContextParallelSelfAttn:
"use_scan", "use_scan",
[pytest.param(False, id="NO_SCAN"), pytest.param(True, id="USE_SCAN")], [pytest.param(False, id="NO_SCAN"), pytest.param(True, id="USE_SCAN")],
) )
@pytest.mark.parametrize(
"window_size",
[
pytest.param((-1, -1), id="window_size(-1, -1)"),
pytest.param((20, 0), id="window_size(20, 0)"),
],
)
def test_context_parallel_ring_attn( def test_context_parallel_ring_attn(
self, self,
device_count, device_count,
...@@ -489,7 +506,15 @@ class TestDistributedContextParallelSelfAttn: ...@@ -489,7 +506,15 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout, qkv_layout,
load_balanced, load_balanced,
use_scan, use_scan,
window_size,
): ):
if window_size != (-1, -1) and not qkv_layout.is_thd():
pytest.skip("Sliding window attention is only supported for THD layout")
if window_size != (-1, -1) and qkv_layout.is_thd() and use_scan:
pytest.skip(
"When context parallelism and sliding window attention are used, "
"scanloop is not supported"
)
self.impl_test_context_parallel_attn( self.impl_test_context_parallel_attn(
device_count, device_count,
mesh_shape, mesh_shape,
...@@ -504,6 +529,7 @@ class TestDistributedContextParallelSelfAttn: ...@@ -504,6 +529,7 @@ class TestDistributedContextParallelSelfAttn:
CPStrategy.RING, CPStrategy.RING,
use_shardy=False, use_shardy=False,
use_scan_ring=use_scan, use_scan_ring=use_scan,
window_size=window_size,
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
...@@ -106,7 +106,8 @@ def general_dot_product_attention( ...@@ -106,7 +106,8 @@ def general_dot_product_attention(
softmax_out = softmax_out * multiplier softmax_out = softmax_out * multiplier
context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value) context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)
context = jnp.reshape(context, query.shape) context_shape = query.shape[:-1] + (value.shape[-1],)
context = jnp.reshape(context, context_shape)
return context return context
...@@ -294,7 +295,8 @@ class FusedAttnRunner: ...@@ -294,7 +295,8 @@ class FusedAttnRunner:
max_seqlen_kv: int max_seqlen_kv: int
num_heads_q: int num_heads_q: int
num_heads_kv: int num_heads_kv: int
head_dim: int head_dim_qk: int
head_dim_v: int
attn_bias_type: AttnBiasType attn_bias_type: AttnBiasType
attn_mask_type: AttnMaskType attn_mask_type: AttnMaskType
dropout_prob: float dropout_prob: float
...@@ -346,7 +348,16 @@ class FusedAttnRunner: ...@@ -346,7 +348,16 @@ class FusedAttnRunner:
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN" "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
) )
# Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
# are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate():
pytest.skip(
"For head_dim_qk != head_dim_v, it is necessary that the QKV layout "
"is either BSHD_BSHD_BSHD or THD_THD_THD"
)
self.backend = FusedAttnHelper( self.backend = FusedAttnHelper(
self.is_training,
self.dtype, self.dtype,
self.dtype, self.dtype,
self.qkv_layout, self.qkv_layout,
...@@ -357,7 +368,8 @@ class FusedAttnRunner: ...@@ -357,7 +368,8 @@ class FusedAttnRunner:
self.num_heads_kv, self.num_heads_kv,
self.max_seqlen_q, self.max_seqlen_q,
self.max_seqlen_kv, self.max_seqlen_kv,
self.head_dim, self.head_dim_qk,
self.head_dim_v,
(-1, -1) if self.window_size is None else self.window_size, (-1, -1) if self.window_size is None else self.window_size,
).get_fused_attn_backend() ).get_fused_attn_backend()
if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend: if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
...@@ -390,13 +402,9 @@ class FusedAttnRunner: ...@@ -390,13 +402,9 @@ class FusedAttnRunner:
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5) q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)
q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim) q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim_qk)
k_shape = v_shape = ( k_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_qk)
self.batch_size, v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_v)
self.max_seqlen_kv,
self.num_heads_kv,
self.head_dim,
)
if self.attn_bias_type == AttnBiasType.NO_BIAS: if self.attn_bias_type == AttnBiasType.NO_BIAS:
bias_shape = None bias_shape = None
...@@ -615,7 +623,7 @@ class FusedAttnRunner: ...@@ -615,7 +623,7 @@ class FusedAttnRunner:
raise ValueError(f"Unknown {self.seq_desc_format=}") raise ValueError(f"Unknown {self.seq_desc_format=}")
self.dropout_rng = dropout_key if self.dropout_prob > 0 else None self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
self.scaling_factor = 1.0 / sqrt(self.head_dim) self.scaling_factor = 1.0 / sqrt(self.head_dim_qk)
# Setup distributed sharding specs # Setup distributed sharding specs
# Setup shardings for distributed tests # Setup shardings for distributed tests
...@@ -934,9 +942,31 @@ class FusedAttnRunner: ...@@ -934,9 +942,31 @@ class FusedAttnRunner:
], ],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"b, s_q, s_kv, h_q, h_kv, d, dtype", "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype",
[ [
pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"), pytest.param(
2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-64-BF16-SELF"
),
pytest.param(
2,
2048,
1024,
12,
12,
64,
64,
jnp.bfloat16,
id="2-2048-1024-12-12-64-64-BF16-CROSS",
),
pytest.param(
2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-64-BF16-GQA"
),
pytest.param(
4, 128, 128, 16, 16, 64, 64, jnp.float16, id="4-128-128-16-16-64-64-FP16-SELF"
),
pytest.param(
4, 128, 128, 16, 16, 64, 32, jnp.float16, id="4-128-128-16-16-64-32-FP16-SELF"
),
pytest.param( pytest.param(
2, 2,
2048, 2048,
...@@ -944,11 +974,13 @@ class FusedAttnRunner: ...@@ -944,11 +974,13 @@ class FusedAttnRunner:
12, 12,
12, 12,
64, 64,
32,
jnp.bfloat16, jnp.bfloat16,
id="2-2048-1024-12-12-64-BF16-CROSS", id="2-2048-1024-12-12-64-32-BF16-CROSS",
),
pytest.param(
2, 2048, 2048, 12, 6, 128, 64, jnp.float16, id="2-2048-2048-12-6-128-64-FP16-GQA"
), ),
pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
], ],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -1002,7 +1034,8 @@ class TestFusedAttn: ...@@ -1002,7 +1034,8 @@ class TestFusedAttn:
s_kv, s_kv,
h_q, h_q,
h_kv, h_kv,
d, d_qk,
d_v,
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
dropout_prob, dropout_prob,
...@@ -1027,7 +1060,8 @@ class TestFusedAttn: ...@@ -1027,7 +1060,8 @@ class TestFusedAttn:
s_kv, s_kv,
h_q, h_q,
h_kv, h_kv,
d, d_qk,
d_v,
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
dropout_prob, dropout_prob,
...@@ -1054,7 +1088,8 @@ class TestFusedAttn: ...@@ -1054,7 +1088,8 @@ class TestFusedAttn:
s_kv, s_kv,
h_q, h_q,
h_kv, h_kv,
d, d_qk,
d_v,
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
dropout_prob, dropout_prob,
...@@ -1076,7 +1111,8 @@ class TestFusedAttn: ...@@ -1076,7 +1111,8 @@ class TestFusedAttn:
s_kv, s_kv,
h_q, h_q,
h_kv, h_kv,
d, d_qk,
d_v,
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
dropout_prob, dropout_prob,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
def pytest_addoption(parser):
    """Register the command-line options used by the debug-tools test suite.

    Args:
        parser: pytest's ``Parser`` object, supplied by the hook machinery.
    """
    parser.addoption(
        "--feature_dirs",
        nargs="+",
        action="store",
        # nargs="+" makes the parsed value a list, so the default should be
        # list-typed as well (the previous default of "" was a string).
        default=[],
        help="List of feature directories",
    )
    parser.addoption(
        "--configs_dir",
        action="store",
        default="",
        type=str,
        help="Path to the directory with configs.",
    )
@pytest.fixture
def feature_dirs(request):
    """Feature directories supplied on the command line via --feature_dirs."""
    dirs = request.config.getoption("--feature_dirs")
    return dirs
@pytest.fixture
def configs_dir(request):
    """Path to the configs directory supplied via --configs_dir."""
    path = request.config.getoption("--configs_dir")
    return path
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import tempfile
import functools
import os
import itertools
import random
import argparse
import re
import torch
import torch.distributed as dist
import transformer_engine
import transformer_engine_torch as tex
import nvdlfw_inspect.api as debug_api
from transformer_engine.debug import set_weight_tensor_tp_group_reduce
from test_numerics import (
_emulate_linear,
_init_debug,
disable_fp8_gemms_create_config,
DISABLE_FP8_LAYER_CONFIG,
_cmp,
IN_SIZE,
OUT_SIZE,
_init_model,
SEED,
SEQ_LEN,
BATCH_SIZE,
FP8_RECIPE,
fake_quant_fp8_create_config,
_get_current_scale,
_prepare_per_tensor_scaling_config,
AMAX_HISTORY_LEN,
set_scaling_factors,
set_current_scaling_factors,
)
# Distributed-run state. NOTE(review): these are None placeholders here and
# are presumably assigned after torch.distributed initialization in a part of
# the file not shown — confirm against the entry point.
WORLD_RANK, WORLD_SIZE = None, None
NCCL_WORLD = None  # default (WORLD) process group handle
FEATURE_DIRS = None  # debug-feature directories, passed to _init_debug
all_boolean = [True, False]  # common parametrization values for the tests
TEST_NR = 0  # running test counter, incremented by run_debug_test on rank 0
def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None):
    """Create this rank's input and weight shard for a tensor-parallel linear.

    The full weight/input are generated identically on every rank (seeded),
    then sliced according to ``parallel_mode``:
    - "row": the input is split along its feature (last) dimension and the
      weight along its input dimension;
    - "column": the weight is split along its output dimension, the input is
      kept whole.
    Returns ``(x, weight)`` with ``weight`` made contiguous.
    """
    if tp_size is None:
        tp_size = WORLD_SIZE
        tp_rank = WORLD_RANK

    # Draw the full weight first under its own seed so every rank slices the
    # same (OUT_SIZE, IN_SIZE) matrix.
    torch.manual_seed(weight_seed)
    weight = torch.randn((OUT_SIZE, IN_SIZE)).cuda()

    torch.manual_seed(data_seed)
    cols_per_rank = IN_SIZE // tp_size
    rows_per_rank = OUT_SIZE // tp_size
    x = torch.randn((SEQ_LEN * BATCH_SIZE, IN_SIZE), requires_grad=True).cuda()
    if parallel_mode == "row":
        x = x[:, tp_rank * cols_per_rank : (tp_rank + 1) * cols_per_rank]
    # x is a non-leaf tensor (after .cuda()/slicing); keep its grad for checks.
    x.retain_grad()

    with torch.no_grad():
        if parallel_mode == "column":
            weight = weight[tp_rank * rows_per_rank : (tp_rank + 1) * rows_per_rank, :]
        else:
            weight = weight[:, tp_rank * cols_per_rank : (tp_rank + 1) * cols_per_rank]
    return x, weight.contiguous()
def _init_model(weight, parallel_mode=None, tp_group=None, name="linear"):
    """Build a TransformerEngine Linear layer preloaded with ``weight``.

    NOTE(review): this intentionally shadows the ``_init_model`` imported from
    test_numerics above — this variant adds tensor-parallel group plumbing.

    When ``parallel_mode`` is set, the layer joins ``tp_group`` (falling back
    to the global NCCL group); otherwise no process group is attached.
    """
    if parallel_mode:
        group = tp_group or NCCL_WORLD
    else:
        group = None
    model = transformer_engine.pytorch.Linear(
        IN_SIZE,
        OUT_SIZE,
        name=name,
        parallel_mode=parallel_mode,
        tp_group=group,
    )
    # Copy outside autograd so the assignment is not tracked.
    with torch.no_grad():
        model.weight.copy_(weight)
    return model
class AllGather(torch.autograd.Function):
    """Autograd-aware all_gather along an arbitrary dimension.

    Forward concatenates every rank's tensor along ``dim``; backward hands
    each rank back only the chunk of the incoming gradient that corresponds
    to its own contribution.
    """

    @staticmethod
    def forward(ctx, tensor, dim, group=None):
        # group=None selects the default (WORLD) process group, matching the
        # defaults of get_world_size/get_rank.
        world_size = torch.distributed.get_world_size(group=group)
        rank = torch.distributed.get_rank(group=group)

        dist.barrier()
        # Gather one same-shaped buffer per participating rank.
        gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
        torch.distributed.all_gather(gathered, tensor, group=group)

        # Remember the layout so backward can carve out this rank's slice.
        ctx.world_size = world_size
        ctx.rank = rank
        ctx.dim = dim

        return torch.cat(gathered, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        # Only the slice this rank contributed flows back; dim/group take no grad.
        chunks = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)
        return chunks[ctx.rank], None, None
def _run_forward_backward(x, model, parallel_mode=None, group=None):
    # Run one FP8 forward/backward step through `model` and advance the debug
    # API iteration counter. Returns the (pre-gather) layer output `y` so the
    # caller can inspect y.grad.
    # NOTE(review): indentation was reconstructed — confirm whether backward()
    # is meant to run inside the fp8_autocast context in the original file.
    with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
        y = model(x)
        # Keep the gradient on the non-leaf output for later checks.
        y.requires_grad_(True)
        y.retain_grad()
        if parallel_mode == "column":
            # Column parallel: outputs are sharded along features — gather
            # them across the TP group before reducing to a scalar loss.
            y = AllGather.apply(y, -1, group)
            y.requires_grad_(True)
            y.retain_grad()
            l = y.sum()
            l.backward()
        elif parallel_mode == "row":
            # Row parallel: output is already full-sized on every rank.
            l = y.sum()
            l.backward()
    debug_api.step()
    return y
def _emulate_linear_distributed(*args, parallel_mode=None, **kwargs):
    """Reference (emulated) linear pass with tensor-parallel synchronization.

    Wraps ``_emulate_linear`` with the activation/gradient syncs that a real
    "column"- or "row"-parallel layer would perform, so its outputs can be
    compared against the distributed TE layer.
    """
    assert parallel_mode in ["column", "row"]

    def _split_gradient(gradient):
        # Each rank keeps only its OUT_SIZE // WORLD_SIZE slice of the grad.
        width = OUT_SIZE // WORLD_SIZE
        return gradient[:, WORLD_RANK * width : (WORLD_RANK + 1) * width]

    def _allreduce_activation(activation):
        # all_reduce is in-place; return the tensor for chaining.
        dist.all_reduce(activation, op=dist.ReduceOp.SUM)
        return activation

    if parallel_mode == "column":
        activation_sync = lambda t: AllGather.apply(t, -1)
        gradient_sync = _split_gradient
    else:
        activation_sync = _allreduce_activation
        gradient_sync = None

    output = _emulate_linear(
        *args, activation_sync=activation_sync, gradient_sync=gradient_sync, **kwargs
    )
    if parallel_mode == "column":
        # Input grads are partial per-rank in column parallel; sum them up.
        dist.all_reduce(output["dgrad"], op=dist.ReduceOp.SUM)
    return output
def check_debug_log(msg):
    """Return True iff ``msg`` appears on any line of this rank's debug log."""
    log_path = f"log/debug_logs/debug_log_globalrank-{WORLD_RANK}.log"
    with open(log_path, "r") as log_file:
        return any(msg in line for line in log_file)
def run_debug_test(func):
    """Decorator coordinating a multi-rank debug test.

    Rank 0 creates a temporary config file and a temporary log directory and
    broadcasts their names; every rank then opens the config file and calls
    ``func`` with ``config_file`` and ``log_dir`` injected into kwargs.
    Cleanup (file close/unlink, debug_api shutdown, log-dir removal) always
    runs, even if ``func`` raises.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank = dist.get_rank()
        temp_file_name = None
        temp_logdir_name = None
        if rank == 0:
            # delete=False: the file must outlive this `with` so other ranks
            # can open it by name.
            with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
                temp_file_name = temp_file.name
            temp_dir_obj = tempfile.TemporaryDirectory()
            temp_logdir_name = temp_dir_obj.name
            # Store the TemporaryDirectory object to prevent it from being deleted
            wrapper.temp_dir_obj = temp_dir_obj
        temp_file_name_list = [temp_file_name]
        temp_logdir_name_list = [temp_logdir_name]
        # Broadcast the temporary file and directory names to all processes
        dist.broadcast_object_list(temp_file_name_list, src=0)
        dist.broadcast_object_list(temp_logdir_name_list, src=0)
        temp_file_name = temp_file_name_list[0]
        temp_logdir_name = temp_logdir_name_list[0]
        dist.barrier()
        # Line-buffered so rank 0's config writes become visible promptly.
        config_file = open(temp_file_name, mode="r+", buffering=1)
        try:
            kwargs["config_file"] = config_file
            kwargs["log_dir"] = temp_logdir_name
            if rank == 0:
                global TEST_NR
                print(f"Running test {TEST_NR} {func.__name__} with args = {args}.")
                TEST_NR += 1
            func(*args, **kwargs)
        finally:
            # Fix: close the handle (it was previously leaked) before rank 0
            # unlinks the underlying temporary file.
            config_file.close()
            if rank == 0 and temp_file_name is not None:
                os.unlink(temp_file_name)
            debug_api.end_debug()
            if rank == 0 and hasattr(wrapper, "temp_dir_obj"):
                wrapper.temp_dir_obj.cleanup()

    return wrapper
CONFIG_LOG_TEST_DISTRIBUTED = """log_distributed:
layers:
layer_types: [linear]
enabled:
True
transformer_engine:
LogTensorStats:
enabled: True
tensors: [activation, gradient, weight, output, wgrad, dgrad]
stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
start_step : 0
end_step: 1
LogFp8TensorStats:
enabled: True
tensors: [activation, gradient, weight]
stats: [underflows%]
start_step : 0
end_step: 1
"""
def _prepare_config_test_log_distributed(config_file):
    """Write the distributed-logging YAML config; only rank 0 writes."""
    if WORLD_RANK == 0:
        config_file.write(CONFIG_LOG_TEST_DISTRIBUTED)
        config_file.flush()
def _compute_dynamic_range(tensor):
tensor_abs = tensor.abs()
tensor_abs = tensor_abs[tensor_abs != 0]
if tensor_abs.any():
amin = tensor_abs.min().float()
else:
amin = torch.tensor(1, device=tensor.device).to(torch.float)
amax = tensor_abs.max().float()
if not amax.all():
amax = torch.tensor(1, device=tensor.device).to(torch.float)
dynamic_range = torch.log2(amax) - torch.log2(amin)
return dynamic_range
@run_debug_test
def test_log_distributed(parallel_mode, gather_weight, **kwargs):
    # End-to-end check of distributed statistics logging: run a TP+DP sharded
    # linear layer, then parse the rank-0 statistics log and verify every
    # logged stat matches the value recomputed locally from gathered tensors.
    _prepare_config_test_log_distributed(kwargs["config_file"])
    _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
    set_weight_tensor_tp_group_reduce(gather_weight)
    if WORLD_SIZE % 2 != 0:
        return  # skip
    # 2-way data parallel on top of TP over the remaining ranks.
    TP_SIZE = WORLD_SIZE // 2
    DP_SIZE = 2
    TP_RANK = WORLD_RANK % TP_SIZE
    DP_RANK = (WORLD_RANK - TP_RANK) // TP_SIZE
    debug_api.set_tensor_reduction_group(NCCL_WORLD)
    # Same weight within a TP group (seeded by TP_RANK), distinct data per
    # DP replica (seeded by DP_RANK).
    x, weight = _get_tensors(
        parallel_mode,
        weight_seed=TP_RANK * 1234,
        data_seed=DP_RANK * 1234,
        tp_size=TP_SIZE,
        tp_rank=TP_RANK,
    )
    tp_group_ranks = [i for i in range(DP_RANK * TP_SIZE, (DP_RANK + 1) * TP_SIZE)]
    tp_group = dist.new_group(ranks=tp_group_ranks)
    dp_group_ranks = [i for i in range(TP_RANK, WORLD_SIZE, TP_SIZE)]
    dp_group = dist.new_group(ranks=dp_group_ranks)
    model = _init_model(weight, parallel_mode=parallel_mode, tp_group=tp_group)
    output = _run_forward_backward(x, model, parallel_mode=parallel_mode, group=tp_group)
    # Reassemble full tensors so local stats can be compared with the logged
    # (reduced) ones.
    gathered_activation = AllGather.apply(x.contiguous(), 0)
    gathered_weight = AllGather.apply(weight.contiguous(), 0, tp_group)
    gathered_gradient = AllGather.apply(output.grad.contiguous(), 0, dp_group)
    if parallel_mode == "row":
        # Row parallel: gradients are additionally sharded across the TP group.
        gathered_gradient = AllGather.apply(gathered_gradient, 0, tp_group)
    log_file = kwargs["log_dir"] + "/nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log"
    dist.barrier()
    if WORLD_RANK != 0:
        return  # stats are gathered on node 0
    with open(log_file) as f:
        content = f.read()

    def get_stat(tensor, stat):
        # Extract the logged value for `<tensor>_<stat>` from the log text;
        # returns None if no matching line is found.
        regex = r".*_{tensor}_{stat}\s+.*iteration=(\d+)\s+.*value=([-+]?\d*\.?\d+)".format(
            tensor=tensor, stat=stat
        )
        for line in content.splitlines():
            match = re.search(regex, line)
            if match:
                value = float(match.group(2))
                return value

    # Logged values are rounded to 4 decimals; mirror that locally.
    rf = lambda x: round(float(x), 4)
    stats = []  # NOTE(review): dead assignment — immediately overwritten below
    tensors = {
        "activation": gathered_activation,
        # Weight stats are TP-reduced only when gather_weight is enabled.
        "weight": gathered_weight if gather_weight else weight,
        "gradient": gathered_gradient,
    }
    stats = {
        "min": torch.min,
        "max": torch.max,
        "mean": torch.mean,
        "std": torch.std,
        "l1_norm": lambda x: torch.norm(x, p=1),
        "l2_norm": lambda x: torch.norm(x, p=2),
        "cur_amax": lambda x: x.abs().max(),
        "dynamic_range": _compute_dynamic_range,
    }
    # Every (stat, tensor) pair logged must match the locally recomputed value.
    for stat_key in stats.keys():
        for tensor_key in tensors.keys():
            torch.testing.assert_close(
                get_stat(tensor_key, stat_key),
                rf(stats[stat_key](tensors[tensor_key])),
                atol=0.0001,
                rtol=0.0001,
            )
    set_weight_tensor_tp_group_reduce(True)  # reset
@run_debug_test
def test_log_expert_parallel(**kwargs):
    """
    Test the scenario where one data-parallel rank does not invoke a debug layer.

    This naturally occurs with expert parallelism, when an expert receives no
    input on one rank but does on the others. If forward() performed an
    all_gather internally, this imbalance would deadlock.
    """
    _prepare_config_test_log_distributed(kwargs["config_file"])
    _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
    debug_api.set_tensor_reduction_group(NCCL_WORLD)
    x, weight = _get_tensors(
        "row", weight_seed=WORLD_RANK * 1234, data_seed=WORLD_RANK * 1234, tp_size=1, tp_rank=0
    )  # data parallel
    # Two independent layers so one can be skipped on a single rank below.
    model = _init_model(weight, parallel_mode=None, name="linear1")
    model1 = _init_model(weight, parallel_mode=None, name="linear2")
    # Step 1: both layers run on every rank.
    with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
        y1 = model(x)
        y2 = model1(x)
        y = y1 + y2
    y.sum().backward()
    debug_api.step()
    # Step 2: rank 0 skips `linear2` — must not deadlock the other ranks.
    with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
        y = model(x)
        if WORLD_RANK != 0:
            y = y + model1(x)
    y.sum().backward()
@run_debug_test
def test_disable_fp8_gemms(fprop_fp8, dgrad_fp8, wgrad_fp8, parallel_mode, **kwargs):
    """Selectively disable FP8 for the fprop/dgrad/wgrad GEMMs via the debug
    config and compare the debug layer against an emulated reference pass
    run with the same precision choices."""
    disable_fp8_gemms_create_config(fprop_fp8, dgrad_fp8, wgrad_fp8, kwargs["config_file"])
    # Precision flags forwarded to the emulated reference linear pass.
    precision_flags = dict(fprop_fp8=fprop_fp8, dgrad_fp8=dgrad_fp8, wgrad_fp8=wgrad_fp8)
    _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
    x, weight = _get_tensors(parallel_mode)
    model = _init_model(weight, parallel_mode=parallel_mode)
    y = _run_forward_backward(x, model, parallel_mode=parallel_mode)
    actual = {
        "activation": y.clone(),
        "wgrad": model.weight.grad.clone(),
        "dgrad": x.grad.clone(),
    }
    # Reset the input gradient so the emulated pass accumulates from scratch.
    x.grad.zero_()
    expected = _emulate_linear_distributed(x, weight, parallel_mode=parallel_mode, **precision_flags)
    _cmp(expected, actual)
@run_debug_test
def test_disable_fp8_layer(parallel_mode, **kwargs):
    """Disable FP8 for the whole layer through the debug config and check that
    the debug layer matches a plain (non-FP8) emulated linear pass."""
    # Only rank 0 writes the shared config file; all ranks wait for it.
    if WORLD_RANK == 0:
        kwargs["config_file"].write(DISABLE_FP8_LAYER_CONFIG)
        kwargs["config_file"].flush()
    dist.barrier()
    x, weight = _get_tensors(parallel_mode)
    # Reference pass runs first; it populates x.grad, which is reset below.
    expected = _emulate_linear_distributed(x, weight, parallel_mode=parallel_mode)
    x.grad.zero_()
    _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
    model = _init_model(weight, parallel_mode)
    y = _run_forward_backward(x, model, parallel_mode)
    actual = {
        "activation": y.clone(),
        "wgrad": model.weight.grad.clone(),
        "dgrad": x.grad.clone(),
    }
    _cmp(expected, actual)
@run_debug_test
def test_per_tensor_scaling(
    fprop_inp,
    fprop_weight,
    dgrad_weight,
    dgrad_grad,
    wgrad_input,
    wgrad_grad,
    parallel_mode,
    **kwargs,
):
    """
    Runs a test to validate per-tensor (current) scaling in FP8 computations.

    The function performs warm-up iterations to populate the amax buffer of the
    model and compute scaling factors based on delayed scaling. Subsequently,
    weights and inputs are switched to ensure their current scaling factors
    differ from those based on delayed scaling; similarly, the loss is
    multiplied by a large factor to alter the gradient's magnitude, creating a
    discrepancy between the original (delayed) and per-tensor (current) scaling
    factors. Finally, a linear pass is emulated, and the results are compared.
    """
    # NOTE: the docstring above was previously placed after the first two
    # assignments (making it a dead string statement) and ended with a stray
    # non-ASCII quote; it is now a real docstring.
    # Which GEMM inputs the config switches to per-tensor (current) scaling.
    input_kwargs = {
        "fprop_inp": fprop_inp,
        "fprop_weight": fprop_weight,
        "dgrad_weight": dgrad_weight,
        "dgrad_grad": dgrad_grad,
        "wgrad_input": wgrad_input,
        "wgrad_grad": wgrad_grad,
    }
    # All three GEMMs stay in FP8; only the scaling strategy differs.
    fp8_kwargs = {
        "fprop_fp8": True,
        "dgrad_fp8": True,
        "wgrad_fp8": True,
    }
    _prepare_per_tensor_scaling_config(
        fprop_inp,
        fprop_weight,
        dgrad_weight,
        dgrad_grad,
        wgrad_input,
        wgrad_grad,
        kwargs["config_file"],
    )
    _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
    warmup_input, warmup_weight = _get_tensors(parallel_mode=parallel_mode)
    model = _init_model(warmup_weight, parallel_mode=parallel_mode)
    # Warmup run to setup amax and scaling factors.
    for _ in range(AMAX_HISTORY_LEN):
        _run_forward_backward(warmup_input, model, parallel_mode=parallel_mode)
    # Switch to fresh tensors so current scales differ from the delayed ones.
    x, weight = _get_tensors(
        parallel_mode=parallel_mode, weight_seed=WORLD_RANK * 2137, data_seed=WORLD_RANK * 2137
    )
    model.weight.data = weight.data
    x.retain_grad()
    # Delayed scaling factors need to be collected before the forward pass with
    # the test data, because that forward pass changes the scaling factors.
    set_scaling_factors(model, input_kwargs, fp8_kwargs)
    LOSS_MULTIPLIER = 100
    with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
        y = model(x)
    model.zero_grad()
    if parallel_mode == "column":
        y = AllGather.apply(y, -1)
    y.retain_grad()
    (
        LOSS_MULTIPLIER * y.sum()
    ).backward()  # Loss multiplication to change gradient's order of magnitude
    output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()}
    # Per-tensor (current) scaling factors need to be collected after the
    # forward pass with the test data, because the gradient (y.grad) cannot be
    # accessed before forward, but it needs to be collected.
    set_current_scaling_factors(x, weight, y, input_kwargs, fp8_kwargs)
    ground_truth = _emulate_linear_distributed(
        x, weight, parallel_mode=parallel_mode, loss_multiplier=LOSS_MULTIPLIER, **fp8_kwargs
    )
    _cmp(ground_truth, output)
@run_debug_test
def test_fake_quant_fp8(
    fprop_inp,
    fprop_weight,
    dgrad_weight,
    dgrad_grad,
    wgrad_input,
    wgrad_grad,
    parallel_mode,
    **kwargs,
):
    """Fake-quantize selected GEMM inputs and compare the debug layer against
    an emulated linear pass configured with the same quantization decisions."""
    # A GEMM stays in real FP8 only if none of its inputs is fake-quantized.
    fp8_kwargs = {
        "fprop_input_fake_quant": fprop_inp,
        "fprop_weight_fake_quant": fprop_weight,
        "dgrad_gradient_fake_quant": dgrad_grad,
        "dgrad_weight_fake_quant": dgrad_weight,
        "wgrad_gradient_fake_quant": wgrad_grad,
        "wgrad_input_fake_quant": wgrad_input,
        "fprop_fp8": not (fprop_inp or fprop_weight),
        "dgrad_fp8": not (dgrad_weight or dgrad_grad),
        "wgrad_fp8": not (wgrad_grad or wgrad_input),
    }
    # Rank 0 writes the shared config; all ranks wait before reading it.
    if WORLD_RANK == 0:
        fake_quant_fp8_create_config(
            fprop_inp,
            fprop_weight,
            dgrad_weight,
            dgrad_grad,
            wgrad_input,
            wgrad_grad,
            kwargs["config_file"],
        )
    dist.barrier()
    _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
    x, weight = _get_tensors(parallel_mode)
    model = _init_model(weight, parallel_mode)
    y = _run_forward_backward(x, model, parallel_mode)
    output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()}
    # Current scales for the fake-quantized tensors; None when the
    # corresponding GEMM ran in real FP8 and no fake-quant scale is needed.
    for key, source, quant_dtype, fp8_flag in (
        ("fprop_input_scale", x, fprop_inp, "fprop_fp8"),
        ("fprop_weight_scale", weight, fprop_weight, "fprop_fp8"),
        ("dgrad_gradient_scale", y.grad, dgrad_grad, "dgrad_fp8"),
        ("dgrad_weight_scale", weight, dgrad_weight, "dgrad_fp8"),
        ("wgrad_gradient_scale", y.grad, wgrad_grad, "wgrad_fp8"),
        ("wgrad_input_scale", x, wgrad_input, "wgrad_fp8"),
    ):
        fp8_kwargs[key] = (
            _get_current_scale(source, quant_dtype) if not fp8_kwargs[fp8_flag] else None
        )
    ground_truth = _emulate_linear_distributed(x, weight, parallel_mode=parallel_mode, **fp8_kwargs)
    _cmp(ground_truth, output)
def _init_distributed():
    """Initialize the single-node NCCL process group and the module-level
    globals (WORLD_RANK, WORLD_SIZE, NCCL_WORLD) used by the tests."""
    global WORLD_RANK, WORLD_SIZE, NCCL_WORLD, FP8

    WORLD_RANK = int(os.getenv("RANK", "0"))
    WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    local_size = int(os.getenv("LOCAL_WORLD_SIZE", "1"))

    # Single-node test: every rank must live on this machine and own a GPU.
    assert WORLD_SIZE == local_size  # this test supports only 1 node
    assert local_size <= torch.cuda.device_count()
    assert dist.is_nccl_available()

    torch.cuda.set_device(local_rank)
    dist.init_process_group(
        backend="nccl",
        rank=WORLD_RANK,
        world_size=WORLD_SIZE,
        init_method="env://",
        device_id=torch.device(f"cuda:{local_rank}"),
    )
    NCCL_WORLD = dist.new_group(backend="nccl")
    WORLD_SIZE = dist.get_world_size()
def _run_test_with_combinations(
test_function, values_list, num_repeat, extra_args, sample_size=None
):
combinations = itertools.product(values_list, repeat=num_repeat)
total_combinations = itertools.product(combinations, extra_args)
if sample_size is not None:
total_combinations = random.sample(list(total_combinations), sample_size)
for comb, arg in total_combinations:
test_function(*comb, arg)
if __name__ == "__main__":
    # Entry point: launched once per rank (e.g. via torchrun); the feature
    # directories for the debug framework are passed on the command line.
    parser = argparse.ArgumentParser()
    parser.add_argument("--feature_dirs", type=str)
    args = parser.parse_args()
    FEATURE_DIRS = args.feature_dirs
    # Seed the sampler so every rank draws the same test combinations.
    random.seed(SEED)
    _init_distributed()
    # Expert-parallel scenario: one rank skips a layer; must not deadlock.
    test_log_expert_parallel()
    for parallel_mode in ["column", "row"]:
        for gather_weight in [True, False]:
            test_log_distributed(parallel_mode, gather_weight)
    for parallel_mode in ["row", "column"]:
        test_disable_fp8_layer(parallel_mode)
    # test_disable_fp8_gemms
    _run_test_with_combinations(
        test_disable_fp8_gemms, all_boolean, num_repeat=3, extra_args=["column", "row"]
    )
    # test_fake_quant_fp8
    dtype_options = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2, None]
    # Sample 20 of the 3^6 * 2 combinations to keep the runtime bounded.
    _run_test_with_combinations(
        test_fake_quant_fp8,
        dtype_options,
        num_repeat=6,
        extra_args=["column", "row"],
        sample_size=20,
    )
    # Sample 20 of the 2^6 boolean combinations (column parallel only).
    _run_test_with_combinations(
        test_per_tensor_scaling,
        all_boolean,
        num_repeat=6,
        extra_args=["column"],
        sample_size=20,
    )
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
import nvdlfw_inspect.api as debug_api
try:
import transformer_engine
import transformer_engine_torch as tex
except (ImportError, ModuleNotFoundError):
print("Could not find TransformerEngine package.")
exit(1)
def test_transformer_engine_no_config(feature_dirs):
    """With an empty config string, the debug defaults apply: FP8 GEMMs are
    enabled and every tensor inspection/modification hook is disabled.

    Fix: removed an unused ``tensor = torch.rand(...).cuda()`` local that
    allocated GPU memory without ever being used.
    """
    debug_api.initialize("", feature_dirs=feature_dirs)
    try:
        # FP8 enabled - true by the default
        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="fprop", iteration=0
        )
        # modify_tensor_enabled - False by default
        assert not debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
        )
        # inspect_tensor_enabled - False by default
        assert not debug_api.transformer_engine.inspect_tensor_enabled(
            "decoder.1.attn.qkv", tensor_name="activation", iteration=0
        )
        # inspect_tensor_postquantize - False by default
        assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
            "decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
        )
    finally:
        debug_api.end_debug()
def test_disable_fp8_gemm(configs_dir, feature_dirs):
    """The disable_fp8_gemms config keeps fprop in FP8 but disables it for
    dgrad and wgrad; the second round of queries exercises the caching path."""
    try:
        debug_api.initialize(configs_dir + "disable_fp8_gemms.yaml", feature_dirs=feature_dirs)
        # First pass computes the answers, second pass must hit the cache.
        for _ in range(2):
            assert debug_api.transformer_engine.fp8_gemm_enabled(
                "decoder.1.attn.qkv", gemm="fprop", iteration=0
            )
            assert not debug_api.transformer_engine.fp8_gemm_enabled(
                "decoder.1.attn.qkv", gemm="dgrad", iteration=0
            )
            assert not debug_api.transformer_engine.fp8_gemm_enabled(
                "decoder.1.attn.qkv", gemm="wgrad", iteration=0
            )
    finally:
        debug_api.end_debug()
def test_disable_fp8_layer(configs_dir, feature_dirs):
    """The disable_fp8_layer config disables FP8 for all GEMMs of the qkv
    layer while leaving the fc1 layer fully FP8-enabled."""
    try:
        debug_api.initialize(configs_dir + "disable_fp8_layer.yaml", feature_dirs=feature_dirs)
        # fc1 is untouched by the config: every GEMM stays in FP8.
        for gemm in ("fprop", "wgrad", "dgrad"):
            assert debug_api.transformer_engine.fp8_gemm_enabled(
                "decoder.1.mlp.fc1", gemm=gemm, iteration=0
            )
        # qkv is disabled as a whole layer: no GEMM runs in FP8.
        for gemm in ("fprop", "wgrad", "dgrad"):
            assert not debug_api.transformer_engine.fp8_gemm_enabled(
                "decoder.1.attn.qkv", gemm=gemm, iteration=0
            )
    finally:
        debug_api.end_debug()
def test_per_tensor_scaling(configs_dir, feature_dirs):
    """Verify the PerTensorScaling feature: which GEMM/tensor pairs it enables
    modification for, and that modify_tensor quantizes to the FP8 dtype of the
    supplied default quantizer.

    Fix: ``type(x) == Cls`` comparisons replaced with ``isinstance`` (flake8
    E721); behavior is unchanged for these exact-type checks.
    """
    try:
        debug_api.initialize(configs_dir + "per_tensor_scaling.yaml", feature_dirs=feature_dirs)
        tensor = torch.rand(24, 2046).cuda()
        # check modify_tensor_enabled
        assert debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0
        )
        assert debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="fprop", tensor_name="weight", iteration=0
        )
        assert debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0
        )
        assert not debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="weight", iteration=0
        )
        assert not debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="wgrad", tensor_name="gradient", iteration=0
        )
        assert not debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="wgrad", tensor_name="activation", iteration=0
        )
        # check modify_tensor
        default_quantizer1 = Float8Quantizer(
            scale=torch.tensor([1]).cuda(),
            amax=torch.tensor([0]).cuda(),
            fp8_dtype=tex.DType.kFloat8E4M3,
        )
        default_quantizer2 = Float8Quantizer(
            scale=torch.tensor([1]).cuda(),
            amax=torch.tensor([0]).cuda(),
            fp8_dtype=tex.DType.kFloat8E5M2,
        )
        output1 = debug_api.transformer_engine.modify_tensor(
            layer_name="decoder.1.mlp.fc1",
            gemm="fprop",
            tensor_name="activation",
            default_quantizer=default_quantizer1,
            iteration=0,
            tensor=tensor,
        )
        # The modified tensor must carry the quantizer's FP8 dtype.
        assert isinstance(output1, Float8Tensor)
        assert output1._fp8_dtype == tex.DType.kFloat8E4M3
        output2 = debug_api.transformer_engine.modify_tensor(
            "decoder.1.mlp.fc1",
            gemm="dgrad",
            tensor=tensor,
            tensor_name="gradient",
            default_quantizer=default_quantizer2,
            iteration=0,
        )
        assert isinstance(output2, Float8Tensor)
        assert output2._fp8_dtype == tex.DType.kFloat8E5M2
        # Pairs outside the config (or unknown layers) stay disabled.
        assert not debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1",
            gemm="wgrad",
            tensor_name="gradient",
            iteration=0,
        )
        assert not debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc4",
            gemm="fprop",
            tensor_name="activation",
            iteration=0,
        )
    finally:
        debug_api.end_debug()
def test_fake_quant(configs_dir, feature_dirs):
    """Fake-quantization config: fc1's fprop activation and dgrad gradient are
    modifiable, while fc2's wgrad GEMM keeps FP8 enabled (answer is cached)."""
    try:
        debug_api.initialize(
            configs_dir + "fake_quantization_config.yaml", feature_dirs=feature_dirs
        )
        sample = torch.rand(24, 2046).cuda()
        targets = (("fprop", "activation"), ("dgrad", "gradient"))
        # modify_tensor_enabled
        for gemm, tensor_name in targets:
            assert debug_api.transformer_engine.modify_tensor_enabled(
                "decoder.1.mlp.fc1", gemm=gemm, tensor_name=tensor_name, iteration=0
            )
        # modify_tensor runs without a default quantizer (pure fake-quant).
        for gemm, tensor_name in targets:
            debug_api.transformer_engine.modify_tensor(
                "decoder.1.mlp.fc1",
                gemm=gemm,
                tensor=sample,
                tensor_name=tensor_name,
                iteration=0,
                default_quantizer=None,
            )
        # Query twice: the second call exercises the caching path.
        for _ in range(2):
            assert debug_api.transformer_engine.fp8_gemm_enabled(
                "decoder.1.fc2", gemm="wgrad", iteration=0
            )
    finally:
        debug_api.end_debug()
def test_statistics_collection(configs_dir, feature_dirs):
    """Exercise inspect_tensor / inspect_tensor_postquantize stat collection.

    Checks that statistics are collected only for the layers, tensors, and
    iterations enabled in the config, and that the computed values match the
    fed tensors.

    Fixes: the two ``all(...)`` membership checks previously discarded their
    result (no ``assert``) and indexed ``x[3]`` — the iteration — instead of
    ``x[2]`` — the stat name — in the ``(layer, tensor, stat, iteration)``
    keys used throughout this function. The unused ``expected_overflows``
    local was removed.
    """
    try:
        debug_api.initialize(
            config_file=configs_dir + "stats_collection_test_config.yaml",
            feature_dirs=feature_dirs,
            default_logging_enabled=False,
        )
        tensor = torch.randn((100, 100, 5)).cuda()
        tensor_fp8 = Float8Tensor(
            data=tensor.to(torch.uint8).cuda(),
            fp8_scale_inv=torch.full([1], 1.0).cuda(),
            fp8_dtype=tex.DType.kFloat8E4M3,
            shape=tensor.shape,
            dtype=torch.float32,
        )

        def log():
            # Drain the accumulated statistics buffers.
            from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS

            return STATS_BUFFERS.log_stats()

        def assert_empty():
            stats = log()
            assert len(stats) == 0

        # TE tensor stats --
        debug_api.transformer_engine.inspect_tensor(
            "decoder.1.mlp.fc1",
            tensor=tensor,
            tensor_name="activation",
            iteration=200,
            tp_group=None,
        )
        stats = log()
        assert stats[("decoder.1.mlp.fc1", "activation", "cur_amax", 200)] == tensor.abs().max()
        # Disabled for other iterations, layers, and tensor names.
        assert not debug_api.transformer_engine.inspect_tensor_enabled(
            "decoder.1.mlp.fc1", tensor_name="activation", iteration=201
        )
        assert not debug_api.transformer_engine.inspect_tensor_enabled(
            "decoder.2.mlp.fc1", tensor_name="activation", iteration=200
        )
        assert not debug_api.transformer_engine.inspect_tensor_enabled(
            "decoder.1.mlp.fc1", tensor_name="gradient", iteration=200
        )
        # Percentage of FP8 payload bytes equal to 0 (underflow).
        expected_underflows = (tensor_fp8._data == 0).sum() * 100 / (100 * 100 * 5)
        # TE FP8 tensor stats --
        assert debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
            "decoder.1.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200
        )
        debug_api.transformer_engine.inspect_tensor_postquantize(
            "decoder.1.mlp.fc1",
            tensor=tensor_fp8,
            tensor_name="gradient",
            iteration=200,
            rowwise=True,
            tp_group=None,
        )
        stats = log()
        torch.testing.assert_close(
            stats[("decoder.1.mlp.fc1", "gradient", "underflows%", 200)], expected_underflows
        )
        assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
            "decoder.1.mlp.fc1", tensor_name="activation", gemm="fprop", iteration=201
        )
        assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
            "decoder.2.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200
        )
        # Second config in same yaml
        tensor = torch.rand((100, 100, 5))
        debug_api.transformer_engine.inspect_tensor(
            "decoder.6.mlp.fc1",
            tensor=tensor,
            tensor_name="activation",
            iteration=200,
            tp_group=None,
        )
        stats = log()
        # Keys are (layer, tensor, stat, iteration): index 2 is the stat name.
        stats_names = [x[2] for x in stats.keys()]
        assert all(
            s in stats_names for s in ["cur_amax", "dynamic_range", "mean", "std", "l1_norm"]
        )
        assert stats[("decoder.6.mlp.fc1", "activation", "mean", 200)] == tensor.mean()
        debug_api.transformer_engine.inspect_tensor(
            "decoder.7.mlp.fc1",
            tensor=tensor,
            tensor_name="weight",
            iteration=200,
            tp_group=None,
        )
        stats = log()
        stats_names = [x[2] for x in stats.keys()]
        assert all(s in stats_names for s in ["mean", "std", "l1_norm", "min", "max"])
        assert stats[("decoder.7.mlp.fc1", "weight", "max", 200)] == tensor.max()
        assert not debug_api.transformer_engine.inspect_tensor_enabled(
            "decoder.7.mlp.fc1", tensor_name="weight", iteration=201
        )
        assert_empty()
    finally:
        debug_api.end_debug()
def test_statistics_multi_run(configs_dir, feature_dirs):
    """Statistics accumulated over two separate feeds must equal statistics
    computed in one pass over the concatenation of the same data."""
    try:
        debug_api.initialize(
            config_file=configs_dir + "stats_collection_test_config.yaml",
            feature_dirs=feature_dirs,
            default_logging_enabled=False,
        )

        def to_fp8(t):
            # Wrap the raw bytes of ``t`` in a Float8Tensor with unit scale.
            return Float8Tensor(
                data=t.to(torch.uint8).cuda(),
                fp8_scale_inv=torch.ones([1]).cuda(),
                fp8_dtype=tex.DType.kFloat8E4M3,
                shape=t.shape,
                dtype=torch.float32,
            )

        def feed(tensor, tensor_fp8):
            debug_api.transformer_engine.inspect_tensor(
                "decoder.5.mlp.fc1",
                tensor=tensor,
                tensor_name="activation",
                iteration=1,
                tp_group=None,
            )
            debug_api.transformer_engine.inspect_tensor_postquantize(
                "decoder.5.mlp.fc1",
                tensor=tensor_fp8,
                tensor_name="activation",
                iteration=1,
                rowwise=True,
                tp_group=None,
            )

        def collect():
            from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS

            return STATS_BUFFERS.log_stats()

        halves = [torch.randn([1024, 1024]) for _ in range(2)]
        for half in halves:
            feed(half, to_fp8(half))
        stats_split = collect()

        combined = torch.cat((halves[0], halves[1])).cuda()
        feed(combined, to_fp8(combined))
        stats_joint = collect()

        assert len(stats_split.keys()) > 0
        for key in stats_split.keys():
            torch.testing.assert_close(stats_split[key], stats_joint[key])
    finally:
        debug_api.end_debug()
if __name__ == "__main__":
    # This module is driven by pytest; direct execution is a deliberate no-op.
    pass
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pathlib, os
from nvdlfw_inspect.config_manager import ConfigManager
import nvdlfw_inspect.api as debug_api
try:
import transformer_engine
from transformer_engine.debug.features.api import TEConfigAPIMapper
except (ImportError, ModuleNotFoundError):
print("Could not find TransformerEngine debug module.")
exit(1)
def test_transformer_engine_config_parsing(feature_dirs):
    """Parse the tensor-manipulation YAML config and verify TEConfigAPIMapper
    filtering by gemm and tensor name across the config structure variants
    (tensor struct, gemms struct, and combined gemm+tensor struct).

    Fix: the FakeQuant "tensor struct" section parsed twice into the same
    ``ret`` variable, so the first (activation) parse result was overwritten
    before ever being asserted; both results are now checked.
    """
    debug_api.initialize(
        config_file=pathlib.Path(__file__).resolve().parent
        / "test_configs/tensor_manipulation_transformer_engine.yaml",
        feature_dirs=feature_dirs,
        log_dir="./log",
    )
    cfg_fc1 = ConfigManager.get_config_for_layer("decoder.1.mlp.fc1")["transformer_engine"]
    cfg_fc2 = ConfigManager.get_config_for_layer("decoder.1.mlp.fc2")["transformer_engine"]
    assert cfg_fc1 and cfg_fc2
    gemm_parsing = True
    tensor_parsing = True
    # Per tensor scaling set for dgrad, filter based on gemm
    ret, _ = TEConfigAPIMapper().parse_config_and_api(
        cfg_fc1["PerTensorScaling"],
        gemm_parsing=gemm_parsing,
        tensor_parsing=tensor_parsing,
        gemm="wgrad",
        tensor_name="activation",
    )
    assert not ret
    # per tensor scaling set for gradient, filter based on tensor name
    ret, _ = TEConfigAPIMapper().parse_config_and_api(
        cfg_fc1["PerTensorScaling"],
        gemm_parsing=gemm_parsing,
        tensor_parsing=tensor_parsing,
        gemm="dgrad",
        tensor_name="activation",
    )
    assert not ret
    ret, parsed_cfg_fc1 = TEConfigAPIMapper().parse_config_and_api(
        cfg_fc1["PerTensorScaling"],
        gemm_parsing=gemm_parsing,
        tensor_parsing=tensor_parsing,
        gemm="dgrad",
        tensor_name="gradient",
    )
    assert ret
    assert parsed_cfg_fc1 == {"gemm": "dgrad", "tensor": "gradient"}
    # Test tensor struct
    ret_act, parsed_cfg_fc1_act = TEConfigAPIMapper().parse_config_and_api(
        cfg_fc1["FakeQuant"],
        gemm_parsing=gemm_parsing,
        tensor_parsing=tensor_parsing,
        gemm="fprop",
        tensor_name="activation",
    )
    assert ret_act  # previously overwritten before being checked
    ret_wei, parsed_cfg_fc1_wei = TEConfigAPIMapper().parse_config_and_api(
        cfg_fc1["FakeQuant"],
        gemm_parsing=gemm_parsing,
        tensor_parsing=tensor_parsing,
        gemm="fprop",
        tensor_name="weight",
    )
    assert ret_wei
    assert parsed_cfg_fc1_act == {
        "gemm": "fprop",
        "tensor": "activation",
        "quant_format": "FP8E4M3",
    }
    assert parsed_cfg_fc1_wei == {
        "gemm": "fprop",
        "tensor": "weight",
        "quant_format": "FP8E4M3",
    }
    # Test gemms struct
    ret, parsed_cfg_fc2_grad = TEConfigAPIMapper().parse_config_and_api(
        cfg_fc2["FakeQuant"],
        gemm_parsing=gemm_parsing,
        tensor_parsing=tensor_parsing,
        gemm="dgrad",
        tensor_name="gradient",
    )
    assert ret
    assert parsed_cfg_fc2_grad == {"gemm": "dgrad", "tensor": "gradient", "quant_format": "FP8E5M2"}
    ret, parsed_cfg_fc2_wei = TEConfigAPIMapper().parse_config_and_api(
        cfg_fc2["FakeQuant"],
        gemm_parsing=gemm_parsing,
        tensor_parsing=tensor_parsing,
        gemm="dgrad",
        tensor_name="weight",
    )
    assert ret
    assert parsed_cfg_fc2_wei == {"gemm": "dgrad", "tensor": "weight", "quant_format": "FP8E5M2"}
    # Test gemm + tensor struct
    ret, parsed_cfg_fc2_fprop_act = TEConfigAPIMapper().parse_config_and_api(
        cfg_fc2["PerTensorScaling"],
        gemm_parsing=gemm_parsing,
        tensor_parsing=tensor_parsing,
        gemm="fprop",
        tensor_name="activation",
    )
    assert ret
    assert parsed_cfg_fc2_fprop_act == {"gemm": "fprop", "tensor": "activation"}
    ret, parsed_cfg_fc2_fprop_wei = TEConfigAPIMapper().parse_config_and_api(
        cfg_fc2["PerTensorScaling"],
        gemm_parsing=gemm_parsing,
        tensor_parsing=tensor_parsing,
        gemm="fprop",
        tensor_name="weight",
    )
    assert ret
    assert parsed_cfg_fc2_fprop_wei == {"gemm": "fprop", "tensor": "weight"}
    ret, parsed_cfg_fc2_wgrad_act = TEConfigAPIMapper().parse_config_and_api(
        cfg_fc2["PerTensorScaling"],
        gemm_parsing=gemm_parsing,
        tensor_parsing=tensor_parsing,
        gemm="wgrad",
        tensor_name="activation",
    )
    assert ret
    assert parsed_cfg_fc2_wgrad_act == {"gemm": "wgrad", "tensor": "activation"}
    ret, parsed_cfg_fc2_wgrad_grad = TEConfigAPIMapper().parse_config_and_api(
        cfg_fc2["PerTensorScaling"],
        gemm_parsing=gemm_parsing,
        tensor_parsing=tensor_parsing,
        gemm="wgrad",
        tensor_name="gradient",
    )
    assert ret
    assert parsed_cfg_fc2_wgrad_grad == {"gemm": "wgrad", "tensor": "gradient"}
    ConfigManager.reset()
# Debug config: for layers of type qkv and fc2, disable FP8 execution of the
# dgrad and wgrad GEMMs (fprop stays in FP8 by default).
test_disable_fp8_gemm_1:
  enabled: True
  layers:
    layer_types: [qkv, fc2]
  transformer_engine:
    DisableFP8GEMM:
      enabled: True
      gemms: [dgrad, wgrad]
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment