OpenDAS / TransformerEngine / Commits

Commit 2b05e121, authored Jun 17, 2025 by yuguo

Merge commit 'a69692ac' of https://github.com/NVIDIA/TransformerEngine

Parents: 0fd441c2, a69692ac
Changes: 245

Showing 20 changed files with 1790 additions and 362 deletions (+1790 -362)
  qa/L0_jax_distributed_unittest/test.sh                   +4   -4
  qa/L0_jax_unittest/test.sh                               +3   -3
  qa/L0_pytorch_debug_unittest/test.sh                     +26  -0
  qa/L0_pytorch_lint/test.sh                               +1   -1
  qa/L0_pytorch_unittest/test.sh                           +1   -0
  qa/L1_pytorch_distributed_unittest/test.sh               +14  -0
  qa/L2_jax_unittest/test.sh                               +6   -6
  setup.py                                                 +19  -48
  tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu       +1   -1
  tests/cpp/operator/test_normalization.h                  +2   -1
  tests/cpp/test_common.cu                                 +48  -30
  tests/cpp/test_common.h                                  +63  -17
  tests/jax/test_custom_call_compute.py                    +288 -230
  tests/jax/test_distributed_fused_attn.py                 +27  -1
  tests/jax/test_fused_attn.py                             +56  -20
  tests/pytorch/debug/conftest.py                          +27  -0
  tests/pytorch/debug/run_distributed.py                   +647 -0
  tests/pytorch/debug/test_api_features.py                 +398 -0
  tests/pytorch/debug/test_config.py                       +151 -0
  tests/pytorch/debug/test_configs/disable_fp8_gemms.yaml  +8   -0
qa/L0_jax_distributed_unittest/test.sh  (View file @ 2b05e121)

@@ -24,10 +24,10 @@ pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Fa
 # Make encoder tests to have run-to-run deterministic to have the stable CI results
 export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
-python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_multigpu_encoder.xml $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py"
-wait
-python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
-wait
+# python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_multigpu_encoder.xml $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py"
+# wait
+# python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
+# wait
 . $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh"
 if [ $RET -ne 0 ]; then
qa/L0_jax_unittest/test.sh  (View file @ 2b05e121)

@@ -27,9 +27,6 @@ mkdir -p "$XML_LOG_DIR"
 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_helper.py || test_fail "tests/jax/*not_distributed_*"

-# Test without custom calls
-NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_custom_call_compute.xml $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py"
-
 pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"

@@ -37,6 +34,9 @@ pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Fa
 # Make encoder tests to have run-to-run deterministic to have the stable CI results
 export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
+# Test without custom calls
+export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
+NVTE_JAX_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls"
 if [ $RET -ne 0 ]; then
   echo "Error: some sub-tests failed: $FAILED_CASES"
qa/L0_pytorch_debug_unittest/test.sh  (new file, mode 0 → 100644; View file @ 2b05e121)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

: ${TE_PATH:=/opt/transformerengine}
: ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
: ${NVTE_TEST_NVINSPECT_CONFIGS_DIR:=$TE_PATH/tests/pytorch/debug/test_configs/}

# Config with the dummy feature which prevents nvinspect from being disabled.
# Nvinspect will be disabled if no feature is active.
: ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml}

FAIL=0

pip install pytest==8.2.1

pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1

# standard numerics tests with initialized debug
NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1

exit $FAIL
qa/L0_pytorch_lint/test.sh  (View file @ 2b05e121)

@@ -20,5 +20,5 @@ if [ -z "${CPP_ONLY}" ]
 then
   cd $TE_PATH
   echo "Checking Python files"
-  python3 -m pylint --recursive=y transformer_engine/common transformer_engine/pytorch
+  python3 -m pylint --recursive=y transformer_engine/common transformer_engine/pytorch transformer_engine/debug
 fi
qa/L0_pytorch_unittest/test.sh  (View file @ 2b05e121)

@@ -47,6 +47,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entro
 NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
 if [ "$RET" -ne 0 ]; then
     echo "Error in the following test cases: $FAILED_CASES"
qa/L1_pytorch_distributed_unittest/test.sh  (View file @ 2b05e121)

@@ -20,6 +20,7 @@ FAILED_CASES=""
 : ${XML_LOG_DIR:=/logs}
 mkdir -p "$XML_LOG_DIR"

+pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py"

@@ -30,6 +31,19 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_use
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py"
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
+
+# debug tests
+
+# Config with the dummy feature which prevents nvinspect from being disabled.
+# Nvinspect will be disabled if no feature is active.
+: ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml}
+: ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
+
+pytest -v -s $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py"
+# standard numerics tests with initialized debug
+NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py"

 if [ "$RET" -ne 0 ]; then
     echo "Error in the following test cases: $FAILED_CASES"
     exit 1
qa/L2_jax_unittest/test.sh  (View file @ 2b05e121)

@@ -25,18 +25,18 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
 : ${XML_LOG_DIR:=/logs}
 mkdir -p "$XML_LOG_DIR"

-python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"
-# Test without custom calls
-NVTE_JAX_UNITTEST_LEVEL="L2" NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_custom_call_compute.xml $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py"
+NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"

 pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
-python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"
+NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"

 pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements"
 # Make encoder tests to have run-to-run deterministic to have the stable CI results
 export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
-python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
+NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
+# Test without custom calls
+export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
+NVTE_JAX_CUSTOM_CALLS_RE="" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"

 if [ $RET -ne 0 ]; then
   echo "Error: some sub-tests failed: $FAILED_CASES"
setup.py  (View file @ 2b05e121)

@@ -19,11 +19,7 @@ from build_tools.te_version import te_version
 from build_tools.utils import (
     rocm_build,
     cuda_archs,
-    found_cmake,
-    found_ninja,
-    found_pybind11,
     get_frameworks,
-    install_and_import,
     remove_dups,
 )

@@ -38,7 +34,6 @@ os.environ["NVTE_PROJECT_BUILDING"] = "1"
 if "pytorch" in frameworks:
     from torch.utils.cpp_extension import BuildExtension
 elif "jax" in frameworks:
-    install_and_import("pybind11[global]")
     from pybind11.setup_helpers import build_ext as BuildExtension

@@ -87,6 +82,11 @@ def setup_common_extension() -> CMakeExtension:
     if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
         cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")

+    # Add custom CMake arguments from environment variable
+    nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
+    if nvte_cmake_extra_args:
+        cmake_flags.extend(nvte_cmake_extra_args.split())
+
     # Project directory root
     root_path = Path(__file__).resolve().parent
     if rocm_build():

@@ -102,22 +102,13 @@ def setup_common_extension() -> CMakeExtension:
     )

-def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
+def setup_requirements() -> Tuple[List[str], List[str]]:
     """Setup Python dependencies
-    Returns dependencies for build, runtime, and testing.
+    Returns dependencies for runtime and testing.
     """
-    # Common requirements
-    setup_reqs: List[str] = [
-        "nvidia-cuda-runtime-cu12",
-        "nvidia-cublas-cu12",
-        "nvidia-cudnn-cu12",
-        "nvidia-cuda-cccl-cu12",
-        "nvidia-cuda-nvcc-cu12",
-        "nvidia-nvtx-cu12",
-        "nvidia-cuda-nvrtc-cu12",
-    ]
     install_reqs: List[str] = [
         "pydantic",
         "importlib-metadata>=1.0",

@@ -125,32 +116,20 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
     ]
     test_reqs: List[str] = ["pytest>=8.2.1"]

-    # Requirements that may be installed outside of Python
-    if not found_cmake():
-        setup_reqs.append("cmake>=3.21")
-    if not found_ninja():
-        setup_reqs.append("ninja")
-    if not found_pybind11():
-        setup_reqs.append("pybind11")
-
     # Framework-specific requirements
     if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
         if "pytorch" in frameworks:
-            setup_reqs.extend(["torch>=2.1"])
-            install_reqs.extend(["torch>=2.1"])
-            # install_reqs.append(
-            #     "nvdlfw-inspect @"
-            #     " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
-            # )
-            # Blackwell is not supported as of Triton 3.2.0, need custom internal build
-            # install_reqs.append("triton")
-            test_reqs.extend(["numpy", "torchvision"])
+            from build_tools.pytorch import install_requirements, test_requirements
+            install_reqs.extend(install_requirements())
+            test_reqs.extend(test_requirements())
         if "jax" in frameworks:
-            setup_reqs.extend(["jax[cuda12]", "flax>=0.7.1"])
-            install_reqs.extend(["jax", "flax>=0.7.1"])
-            test_reqs.extend(["numpy"])
+            from build_tools.jax import install_requirements, test_requirements
+            install_reqs.extend(install_requirements())
+            test_reqs.extend(test_requirements())

-    return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]]
+    return [remove_dups(reqs) for reqs in [install_reqs, test_reqs]]

 if __name__ == "__main__":

@@ -167,14 +146,13 @@ if __name__ == "__main__":
         ext_modules = []
         package_data = {}
         include_package_data = False
-        setup_requires = []
         install_requires = ([f"transformer_engine_cu12=={__version__}"],)
         extras_require = {
             "pytorch": [f"transformer_engine_torch=={__version__}"],
             "jax": [f"transformer_engine_jax=={__version__}"],
         }
     else:
-        setup_requires, install_requires, test_requires = setup_requirements()
+        install_requires, test_requires = setup_requirements()
         ext_modules = [setup_common_extension()]
         package_data = {"": ["VERSION.txt"]}
         include_package_data = True

@@ -219,15 +197,8 @@ if __name__ == "__main__":
         long_description_content_type="text/x-rst",
         ext_modules=ext_modules,
         cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
-        python_requires=">=3.8, <3.13",
-        classifiers=[
-            "Programming Language :: Python :: 3.8",
-            "Programming Language :: Python :: 3.9",
-            "Programming Language :: Python :: 3.10",
-            "Programming Language :: Python :: 3.11",
-            "Programming Language :: Python :: 3.12",
-        ],
-        setup_requires=setup_requires,
+        python_requires=">=3.8",
+        classifiers=["Programming Language :: Python :: 3"],
         install_requires=install_requires,
         license_files=("LICENSE",),
         include_package_data=include_package_data,
tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu  (View file @ 2b05e121)

@@ -375,7 +375,7 @@ std::vector<std::pair<size_t, size_t>> matrix_sizes = {
     {256, 256},
     {993, 512},
     {768, 1024},
-    {65536, 128},
+    {65504, 128},
     {16384, 1632},
 };
tests/cpp/operator/test_normalization.h  (View file @ 2b05e121)

@@ -71,7 +71,8 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const
   // Remove the use_cudnn check here when it is supported by both backends.
   const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype;

-  if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3>){
+  if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3> ||
+                std::is_same_v<InputType, fp4e2m1>){
     compute_t g = static_cast<compute_t>(gamma);
     if (zero_centered_gamma) {
       g += static_cast<compute_t>(1.f);
tests/cpp/test_common.cu  (View file @ 2b05e121)

@@ -45,7 +45,7 @@ bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2) {
   return true;
 }

-size_t typeToSize(DType type) {
+size_t typeToNumBits(DType type) {
   TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
   {
       return TypeInfo<T>::size;

@@ -62,7 +62,8 @@ const std::string &typeName(DType type) {
      {DType::kBFloat16, "bfloat16"},
      {DType::kFloat8E4M3, "float8e4m3"},
      {DType::kFloat8E5M2, "float8e5m2"},
-     {DType::kFloat8E8M0, "float8e8m0"}};
+     {DType::kFloat8E8M0, "float8e8m0"},
+     {DType::kFloat4E2M1, "float4e2m1"}};
   return name_map.at(type);
 }
@@ -109,9 +110,16 @@ size_t DIVUP(const size_t &x, const size_t &y){
 struct scale_inv_meta {
   std::vector<size_t> shape;
   DType type;
-  size_t type_size;
+  size_t type_size_bits;
+
+  size_t bytes() const noexcept { return (product(shape) * type_size_bits) / 8; }
 };

+size_t bytes(const NVTEShape& shape, const DType type) {
+  return (product(shape) * typeToNumBits(type)) / 8;
+}
+
 NVTEShape convertShape(const std::vector<size_t>& s) {
   return nvte_make_shape(s.data(), s.size());
 }
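Why bits instead of bytes: once FP4 is in the mix an element no longer occupies a whole byte, so any sizeof-based byte count over-allocates and miscomputes copy sizes. The hunk above stores type_size_bits and divides by 8 only after multiplying by the element count. A minimal standalone sketch of the same arithmetic (the bit widths mirror the patch; everything else is illustrative, not the TE API):

    // Standalone sketch (not TE code): byte sizing that stays exact for sub-byte types.
    #include <cassert>
    #include <cstddef>

    // Bits per element, mirroring the patch: 4 for FP4, 8 * sizeof(T) otherwise.
    constexpr size_t kFp32Bits = 32;
    constexpr size_t kFp4Bits = 4;

    // Same expression as the diff's (product(shape) * type_size_bits) / 8.
    constexpr size_t bytes_for(size_t element_count, size_t num_bits) {
      return (element_count * num_bits) / 8;
    }

    int main() {
      assert(bytes_for(128, kFp32Bits) == 512);  // 128 fp32 values -> 512 bytes
      assert(bytes_for(128, kFp4Bits) == 64);    // 128 FP4 values pack into 64 bytes
      // A sizeof-based count would have claimed 128 bytes for the FP4 case,
      // since sizeof can never report less than one byte per element.
      return 0;
    }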
@@ -122,7 +130,7 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
     scale_inv_meta ret;
     ret.shape = {1};
     ret.type = DType::kFloat32;
-    ret.type_size = sizeof(float);
+    ret.type_size_bits = typeToNumBits(DType::kFloat32);
     return {ret, ret};
   }

   if (scaling_mode == NVTE_MXFP8_1D_SCALING) {

@@ -152,8 +160,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
     }
     ret_rowwise.type = DType::kFloat8E8M0;
     ret_colwise.type = DType::kFloat8E8M0;
-    ret_rowwise.type_size = sizeof(uint8_t);
-    ret_colwise.type_size = sizeof(uint8_t);
+    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
+    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
     return {ret_rowwise, ret_colwise};
   }

@@ -179,8 +187,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
     }
     ret_rowwise.type = DType::kFloat32;
     ret_colwise.type = DType::kFloat32;
-    ret_rowwise.type_size = sizeof(float);
-    ret_colwise.type_size = sizeof(float);
+    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32);
+    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32);
     return {ret_rowwise, ret_colwise};
   }

@@ -205,8 +213,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
     }
     ret_rowwise.type = DType::kFloat32;
     ret_colwise.type = DType::kFloat32;
-    ret_rowwise.type_size = sizeof(float);
-    ret_colwise.type_size = sizeof(float);
+    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32);
+    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32);
     return {ret_rowwise, ret_colwise};
   }

@@ -222,8 +230,7 @@ Tensor::Tensor(const std::string& name,
   gen_.seed(seed);
   rowwise_ = rowwise;
   columnwise_ = columnwise;
-  size_t s = typeToSize(type);
-  size_t total_size = product(shape) * s;
+  size_t total_size = bytes(shape, type);
   void *dptr_rowwise = nullptr;
   void *dptr_columnwise = nullptr;
   cpu_data_rowwise_ = nullptr;

@@ -305,8 +312,8 @@ Tensor::Tensor(const std::string& name,
   } else {
     auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, tensor_.scaling_mode());
-    auto rowwise_scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
-    auto columnwise_scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size;
+    auto rowwise_scale_size = rowwise_scale_meta.bytes();
+    auto columnwise_scale_size = colwise_scale_meta.bytes();
     auto scale_shape = rowwise_scale_meta.shape;
     auto columnwise_scale_shape = colwise_scale_meta.shape;
     if (rowwise) {

@@ -331,7 +338,7 @@ Tensor::Tensor(const std::string& name,
 void Tensor::to_cpu() const {
   const NVTEShape s = tensor_.shape();
-  const size_t size = product(s) * typeToSize(tensor_.dtype());
+  const size_t size = bytes(s, tensor_.dtype());
   if (rowwise_) {
     cudaMemcpy(cpu_data_rowwise_.get(), tensor_.get_rowwise_data().data_ptr,

@@ -360,14 +367,14 @@ void Tensor::to_cpu() const {
   auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
   if (rowwise_) {
-    auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
+    auto scale_size = rowwise_scale_meta.bytes();
     cudaMemcpy(rowwise_scale_inv_cpu_data_.get(), tensor_.get_rowwise_scale_inv().data_ptr,
                scale_size, cudaMemcpyDeviceToHost);
   }
   if (columnwise_) {
-    auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size;
+    auto scale_size = colwise_scale_meta.bytes();
     cudaMemcpy(columnwise_scale_inv_cpu_data_.get(), tensor_.get_columnwise_scale_inv().data_ptr,
                scale_size,

@@ -378,34 +385,32 @@ void Tensor::to_cpu() const {
 void Tensor::from_cpu() const {
   const NVTEShape s = tensor_.shape();
-  const size_t size = product(s) * typeToSize(tensor_.dtype());
+  const size_t size = bytes(s, tensor_.dtype());
   if (rowwise_) {
     cudaMemcpy(tensor_.get_rowwise_data().data_ptr, cpu_data_rowwise_.get(), size, cudaMemcpyHostToDevice);
   }
   if (columnwise_) {
     cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size, cudaMemcpyHostToDevice);
   }
   if (isFp8Type(dtype())) {
     if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
       if (tensor_.amax() != nullptr){
         cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
       }
       cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
     }
     auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
     if (rowwise_) {
-      auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
+      auto scale_size = rowwise_scale_meta.bytes();
       cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr, rowwise_scale_inv_cpu_data_.get(),
                  scale_size, cudaMemcpyHostToDevice);
     }
     if (columnwise_) {
-      auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size;
+      auto scale_size = colwise_scale_meta.bytes();
       cudaMemcpy(tensor_.get_columnwise_scale_inv().data_ptr, columnwise_scale_inv_cpu_data_.get(),
                  scale_size, cudaMemcpyHostToDevice);
@@ -735,6 +740,19 @@ std::pair<double, double> getTolerances(const DType type) {
 template <typename T>
 void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
+  // Check how many RNG calls are required to generate one uniform random value
+  int rng_calls_per_val = 0;
+  {
+    std::mt19937 gen1 = *gen, gen2 = *gen;
+    std::uniform_real_distribution<> dis(-2.0, 1.0);
+    const float _ = dis(gen1);
+    while (gen2 != gen1) {
+      auto _ = gen2();
+      ++rng_calls_per_val;
+    }
+  }
+
   // Generate uniform random values in parallel
 #pragma omp parallel proc_bind(spread)
   {
     std::mt19937 gen_local = *gen;

@@ -743,7 +761,7 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
     const int chunk_size = (size + threads_num - 1) / threads_num;
     const int idx_min = chunk_size * thread_ID;
     const int idx_max = std::min(chunk_size * (thread_ID + 1), static_cast<int>(size));
-    gen_local.discard(idx_min);
+    gen_local.discard(idx_min * rng_calls_per_val);
     std::uniform_real_distribution<> dis(-2.0, 1.0);
     for (int i = idx_min; i < idx_max; ++i) {

@@ -754,7 +772,7 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
 #endif
     }
   }
-  gen->discard(size);
+  gen->discard(size * rng_calls_per_val);
 }

 void fillUniform(Tensor *t) {
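The discard fix above is subtle: std::uniform_real_distribution may consume more than one raw mt19937 draw per returned sample, so gen_local.discard(idx_min) skips idx_min raw draws, not idx_min samples, and the per-thread streams overlap. A self-contained sketch of the probe the patch runs before the parallel region (the (-2.0, 1.0) range is taken from the diff; the rest is illustrative):

    // Standalone sketch (not TE code) of the calls-per-sample probe.
    #include <cassert>
    #include <random>

    // Count how many raw engine invocations one distribution sample consumes.
    int rng_calls_per_value(const std::mt19937& gen) {
      std::mt19937 gen1 = gen, gen2 = gen;
      std::uniform_real_distribution<> dis(-2.0, 1.0);  // same range as the diff
      (void)dis(gen1);            // advance gen1 by exactly one sample
      int calls = 0;
      while (gen2 != gen1) {      // step gen2 raw until the engine states match
        (void)gen2();
        ++calls;
      }
      return calls;
    }

    int main() {
      std::mt19937 gen(12345);
      const int k = rng_calls_per_value(gen);
      assert(k >= 1);  // a double from a 32-bit engine typically consumes two draws
      // To start a worker at sample idx_min, skip idx_min * k raw draws;
      // discarding only idx_min desynchronizes the per-thread streams,
      // which is the bug the hunk above fixes.
      std::mt19937 worker = gen;
      worker.discard(static_cast<unsigned long long>(100) * k);
      return 0;
    }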
tests/cpp/test_common.h  (View file @ 2b05e121)

@@ -10,11 +10,18 @@
 #include <vector>
 #include <array>
 #include <random>

+#include <cudaTypedefs.h>
+
+#define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)
+
+#include <cuda_runtime_api.h>
 #include <cuda_bf16.h>
-#include <cuda_fp8.h>
 #include <cuda_fp16.h>
+#include <cuda_fp8.h>
+#if FP4_TYPE_SUPPORTED
+#include <cuda_fp4.h>
+#endif
-#include <cuda_runtime_api.h>

 #include <transformer_engine/transformer_engine.h>
 #include "util/logging.h"
@@ -56,20 +63,32 @@ using fp8e4m3 = __nv_fp8_e4m3;
 using fp8e5m2 = __nv_fp8_e5m2;
 using fp8e8m0 = uint8_t;
 using int8 = int8_t;
+#if FP4_TYPE_SUPPORTED
+using fp4e2m1 = __nv_fp4_e2m1;
+#endif
+
+template <typename T>
+struct BitsNumber;
+
+#if FP4_TYPE_SUPPORTED
+template <>
+struct BitsNumber<fp4e2m1> {
+  static constexpr size_t num_bits = 4;
+};
+#endif
+
+template <typename T>
+struct BitsNumber {
+  static constexpr size_t num_bits = 8 * sizeof(T);
+};

 template <typename T>
 struct TypeInfo {
-    using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, int8>;
+#if FP4_TYPE_SUPPORTED
+    using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, fp4e2m1>;
+#else
+    using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, int8>;
+#endif

     template <typename U, DType current>
     struct Helper {

@@ -96,7 +115,7 @@ struct TypeInfo{
     }

     constexpr static DType dtype = getType<T>();
-    constexpr static size_t size = sizeof(T);
+    constexpr static size_t size = BitsNumber<T>::num_bits;
 };

 class Tensor {
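A compressed view of the BitsNumber pattern above: the primary template answers 8 * sizeof(T) for ordinary types, and a full specialization pins the 4-bit width for FP4, which sizeof (never less than one byte) cannot express. The sketch below stubs the FP4 type so it compiles without cuda_fp4.h; it is an illustration, not the TE header:

    // Standalone sketch (not the TE header); fp4e2m1 is stubbed so this
    // compiles without cuda_fp4.h.
    #include <cstddef>

    struct fp4e2m1_stub {};  // stand-in for __nv_fp4_e2m1

    // Primary template: ordinary types occupy 8 * sizeof(T) bits.
    template <typename T>
    struct BitsNumber {
      static constexpr size_t num_bits = 8 * sizeof(T);
    };

    // Full specialization: a 4-bit type cannot be described via sizeof,
    // so its width is pinned explicitly.
    template <>
    struct BitsNumber<fp4e2m1_stub> {
      static constexpr size_t num_bits = 4;
    };

    static_assert(BitsNumber<float>::num_bits == 32, "fp32 is 32 bits");
    static_assert(BitsNumber<fp4e2m1_stub>::num_bits == 4, "FP4 is 4 bits");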
@@ -418,9 +437,10 @@ inline float dsilu(const float x) { return x * dsigmoid(x) + sigmoid(x); }
 inline float srelu(const float x) { return x > 0 ? x * x : 0; }
 inline float dsrelu(const float x) { return fmaxf(0, 2 * x); }

-size_t typeToSize(DType type);
+size_t typeToNumBits(DType type);
 size_t product(const NVTEShape &shape);
 size_t product(const std::vector<size_t> &shape);
+size_t bytes(const NVTEShape &shape, const DType type);
 size_t first_dimension(const std::vector<size_t> &shape);
 size_t last_dimension(const std::vector<size_t> &shape);
@@ -466,6 +486,16 @@ constexpr int32_t blackwellComputeCapability = 100;
 }  // namespace test

+#if FP4_TYPE_SUPPORTED
+#define SWITCH_FP4_TYPE_HANDLE(type, ...) \
+  case DType::kFloat4E2M1: { \
+    using type = fp4e2m1; \
+    { __VA_ARGS__ } \
+  } break;
+#else
+#define SWITCH_FP4_TYPE_HANDLE(type, ...)  // do nothing
+#endif
+
 #define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \
   switch (dtype) { \
     using namespace transformer_engine; \

@@ -517,8 +547,16 @@ constexpr int32_t blackwellComputeCapability = 100;
         {__VA_ARGS__} \
       } \
       break; \
+    case DType::kFloat8E8M0: \
+      { \
+        using type = fp8e8m0; \
+        {__VA_ARGS__} \
+      } \
+      break; \
+    SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \
     default: \
-      NVTE_ERROR("Invalid type."); \
+      printf("dtype: %d\n", static_cast<int>(dtype)); \
+      NVTE_ERROR("Invalid type MARKED TEST."); \
   }

 #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \

@@ -537,7 +575,15 @@ constexpr int32_t blackwellComputeCapability = 100;
       } \
       break; \
     default: \
-      NVTE_ERROR("Invalid type."); \
+      NVTE_ERROR("Invalid type MARKED TEST 2."); \
   }

+#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4_ONLY(dtype, type, ...) \
+  switch (dtype) { \
+    using namespace transformer_engine; \
+    SWITCH_FP4_HANDLE(type, __VA_ARGS__) \
+    default: \
+      NVTE_ERROR("Invalid type MARKED TEST 3."); \
+  }
+
 #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \

@@ -562,5 +608,5 @@ constexpr int32_t blackwellComputeCapability = 100;
     } \
     break; \
   default: \
-    NVTE_ERROR("Invalid type."); \
+    NVTE_ERROR("Invalid type MARKED TEST 4."); \
 }
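The SWITCH_FP4_TYPE_HANDLE indirection is the standard trick for conditionally extending this kind of type-dispatch switch: when FP4 is unavailable the helper macro expands to nothing, so the case label simply never exists. A toy version with stand-in types (not the real TE macros) showing how the optional case splices into the switch:

    // Toy version (not the real TE macros) of splicing an optional case
    // label into a type-dispatch switch.
    #include <cstdio>

    enum class DType { kFloat32, kFloat4E2M1 };
    using fp32 = float;
    struct fp4e2m1 {};  // stand-in; the real type needs cuda_fp4.h

    #define FP4_TYPE_SUPPORTED 1

    #if FP4_TYPE_SUPPORTED
    #define SWITCH_FP4_TYPE_HANDLE(type, ...) \
      case DType::kFloat4E2M1: {              \
        using type = fp4e2m1;                 \
        { __VA_ARGS__ }                       \
      } break;
    #else
    #define SWITCH_FP4_TYPE_HANDLE(type, ...)  // expands to nothing: no case emitted
    #endif

    #define TYPE_SWITCH_ALL(dtype, type, ...)     \
      switch (dtype) {                            \
        case DType::kFloat32: {                   \
          using type = fp32;                      \
          { __VA_ARGS__ }                         \
        } break;                                  \
        SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \
      }

    int main() {
      DType dt = DType::kFloat4E2M1;
      // With FP4_TYPE_SUPPORTED == 0 this case would simply not exist.
      TYPE_SWITCH_ALL(dt, T, printf("selected type occupies %zu byte(s)\n", sizeof(T));)
      return 0;
    }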
tests/jax/test_custom_call_compute.py  (View file @ 2b05e121)
@@ -4,15 +4,14 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
 import pytest
 from jax import jit, value_and_grad
 from functools import reduce
+from typing import Union
 import operator

 from utils import (
     assert_allclose,
     assert_tree_like_allclose,
     pytest_parametrize_wrapper,
 )
 from transformer_engine.jax.layernorm import layernorm
@@ -33,15 +32,18 @@ from transformer_engine.jax import cpp_extensions as tex
 from transformer_engine.jax.quantize import (
     DelayedScaleQuantizer,
     ScaledTensor,
+    ScaledTensor1x,
+    ScaledTensor2x,
+    GroupedScaledTensor1x,
     ScalingMode,
     QuantizerFactory,
     QuantizeLayout,
+    noop_quantizer_set,
 )
 from transformer_engine.jax.quantize import helper
 from transformer_engine.jax.activation import activation
-from transformer_engine.jax.dense import dense
+from transformer_engine.jax.dense import dense, grouped_dense
 from transformer_engine.jax.layernorm_dense import layernorm_dense
-from transformer_engine.jax.quantize import ScaledTensor1x, ScaledTensor2x

 GEMM_CASES = [
     (256, 256, 512),
@@ -53,8 +55,8 @@ GEMM_CASES = [
 FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
 LN_CASES = [(256, 128), (128, 256)]
 DTYPES = [jnp.bfloat16, jnp.float32]

-is_fp8_supported, reason = helper.is_fp8_available()
-is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
+is_fp8_supported, fp8_unsupported_reason = helper.is_fp8_available()
+is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

 supported_scaling_modes = []
 """ Find supported scaling modes"""
@@ -113,6 +115,38 @@ def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray):
         pytest.fail("a must be a ScaledTensor object")


+def assert_dequantized_grouped_scaled_tensor(
+    a: Union[GroupedScaledTensor1x, ScaledTensor2x], b: jnp.ndarray
+):
+    if isinstance(a, GroupedScaledTensor1x):
+        assert a.group_sizes.sum() == b.shape[0]
+        b = jnp.split(b, jnp.cumulative_sum(a.group_sizes)[:-1], axis=0)
+        dq_a = a.dequantize()
+        for dq_a_i, b_i in zip(dq_a, b):
+            if len(dq_a_i) == 0:
+                continue
+            if a.data_layout == "T":
+                data_ndim = len(a.original_shape)
+                flatten_axis = a.flatten_axis
+                if b_i.shape[0] == 1:
+                    b_i = jnp.transpose(
+                        b_i, (0, *range(flatten_axis, data_ndim), *range(1, flatten_axis))
+                    )
+                else:
+                    b_i = jnp.transpose(
+                        b_i, (*range(flatten_axis, data_ndim), *range(flatten_axis))
+                    )
+            dq_a_i = dq_a_i.reshape(b_i.shape)
+            assert_allclose(dq_a_i, b_i, dtype=a.data.dtype)
+    elif isinstance(a, ScaledTensor2x):
+        assert isinstance(a.get_rowwise_tensor(), GroupedScaledTensor1x)
+        assert isinstance(a.get_colwise_tensor(), GroupedScaledTensor1x)
+        assert_dequantized_grouped_scaled_tensor(a.get_rowwise_tensor(), b)
+        assert_dequantized_grouped_scaled_tensor(a.get_colwise_tensor(), b)
+    else:
+        pytest.fail("a must be a GroupedScaledTensor object")
+
+
 ALL_ACTIVATION_SHAPES = [(32, 64), (16, 128, 256)]
 ALL_ACTIVATION_TYPES = [
     ("gelu",),
@@ -173,7 +207,7 @@ class TestActivation:
         assert_allclose(prim_out, ref_out, dtype=x.dtype)
         assert_allclose(prim_grad, ref_grad, dtype=x.dtype)

-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
     @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
     @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])

@@ -204,7 +238,7 @@ class TestActivation:
         assert_allclose(prim_out, ref_out, dtype=output_type)
         assert_allclose(prim_grad, ref_grad, dtype=output_type)

-    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
     @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
     @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
     @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])

@@ -234,7 +268,7 @@ class TestActivation:
         assert_bitwise_scaled_tensors(te_output, jax_output)

-    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
     @pytest_parametrize_wrapper("shape", [(2, 64, 1, 256)])
     @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
     @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])

@@ -355,7 +389,7 @@ class TestNorm:
             n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer=None
         )

-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     # No Norm FWD E5M2 in TE backend
     @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
     @pytest_parametrize_wrapper(

@@ -470,7 +504,7 @@ class TestNorm:
         if norm_type == "layernorm":
             assert_allclose(mu, ref_mu, dtype=inp_dtype)

-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     # No Norm FWD E5M2 in TE backend
     @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
     @pytest_parametrize_wrapper(

@@ -506,7 +540,7 @@ class TestNorm:
             q_layout=q_layout,
         )

-    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
     @pytest.mark.parametrize("out_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
     def test_norm_forward_with_block_scaling_fp8(
         self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype

@@ -532,7 +566,7 @@ QUANTIZE_OUTPUT_DTYPES = {
 ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [
     ((32, 64), -1),
     ((2, 64, 32), -1),
     ((2, 64, 32), -2),
     ((64, 2, 32), -2),
     ((32, 256, 128), -1),
     ((32, 256, 128), -2),
     ((64, 32, 32, 256), -1),

@@ -544,7 +578,7 @@ QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = {
     "L0": [
         ((32, 64), -1),
         ((2, 64, 32), -1),
         ((2, 64, 32), -2),
         ((64, 2, 32), -2),
     ],
     "L2": ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES,
 }

@@ -555,7 +589,7 @@ QUANTIZATION_INPUT_DTYPE = {
 }

-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
 @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
 @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
 @pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@@ -577,9 +611,6 @@ class TestQuantize:
             q_dtype=q_dtype,
             q_layout=q_layout,
         )
-        # Adding dimension to test if padding is done correctly when flatten 3D to 2D
-        if flatten_axis == -2:
-            input_shape = input_shape[:-1] + (2,) + input_shape[-1:]
         n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
         for _ in range(n_iterations):

@@ -593,8 +624,6 @@ class TestQuantize:
     ):
         key = jax.random.PRNGKey(0)
-        if flatten_axis == -2:
-            input_shape = input_shape[:-1] + (2,) + input_shape[-1:]
         input = jax.random.uniform(key, input_shape, in_dtype)
         te_quantizer, jax_quantizer = QuantizerFactory.create(
@@ -607,10 +636,65 @@ class TestQuantize:
         assert_bitwise_scaled_tensors(te_output, jax_output)


+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
+@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
+@pytest_parametrize_wrapper("input_shape", [(8, 16, 32)])
+@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn])
+@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
+@pytest_parametrize_wrapper("flatten_axis", [-1])
+@pytest_parametrize_wrapper("with_group_sizes", [True, False])
+@pytest_parametrize_wrapper(
+    "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE]
+)
+class TestGroupedQuantize:
+    def test_grouped_qdq(
+        self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis, with_group_sizes
+    ):
+        n_groups, m, n = input_shape
+        key = jax.random.PRNGKey(0)
+        subkeys = jax.random.split(key, 2)
+
+        # *32 so that the input shapes works for MXFP8
+        input_shape = (m * 32, n)
+
+        if with_group_sizes:
+            group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
+            group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
+            group_sizes = jnp.diff(group_sizes)
+            assert group_sizes.sum() == m
+            assert jnp.any(group_sizes == 0)
+            # make sure that at least one group has 0 row
+            group_sizes = group_sizes * 32
+        else:
+            group_sizes = None
+            input_shape = (n_groups, input_shape[0] // n_groups, input_shape[1])
+
+        if flatten_axis == -2:
+            input_shape = input_shape[:-1] + (2,) + input_shape[-1:]
+
+        x = jax.random.uniform(subkeys[1], input_shape, in_dtype)
+        grouped_quantizer = QuantizerFactory.create(
+            scaling_mode=scaling_mode,
+            q_dtype=q_dtype,
+            q_layout=q_layout,
+            n_groups=n_groups,
+        )
+
+        # grouped_quantize does not work with cudaGraph yet, so the jitting will breaks
+        # To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to
+        # disable cudaGraph, then use the following jitted function
+        scaled_tensor = tex.grouped_quantize(
+            x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer
+        )
+        assert_dequantized_grouped_scaled_tensor(scaled_tensor, x)
+
+
 @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
 class TestFusedQuantize:
-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
     @pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
     @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
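The group_sizes construction in TestGroupedQuantize is the classic cut-point (stars-and-bars) trick: draw n_groups - 1 cut positions in [0, m], sort them, pad with 0 and m, and take adjacent differences; the result is n_groups non-negative sizes summing to m, with zero-sized groups possible, which is exactly what the test wants to exercise. The same idea as a standalone sketch (names are illustrative, not TE code):

    // Standalone sketch (not TE code): partition m rows into n_groups random
    // non-negative sizes that sum to m, mirroring the sort/concatenate/diff
    // sequence in the test above.
    #include <algorithm>
    #include <cassert>
    #include <numeric>
    #include <random>
    #include <vector>

    std::vector<int> random_group_sizes(int n_groups, int m, std::mt19937& gen) {
      std::uniform_int_distribution<int> dis(0, m);
      std::vector<int> cuts(n_groups + 1);
      cuts.front() = 0;   // pad with 0 ...
      cuts.back() = m;    // ... and m, like jnp.concatenate([[0], sizes, [m]])
      for (int i = 1; i < n_groups; ++i) cuts[i] = dis(gen);
      std::sort(cuts.begin(), cuts.end());
      std::vector<int> sizes(n_groups);
      for (int i = 0; i < n_groups; ++i) sizes[i] = cuts[i + 1] - cuts[i];  // jnp.diff
      return sizes;
    }

    int main() {
      std::mt19937 gen(0);
      auto sizes = random_group_sizes(5, 32, gen);
      assert(std::accumulate(sizes.begin(), sizes.end(), 0) == 32);
      return 0;
    }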
@@ -625,12 +709,6 @@ class TestFusedQuantize:
         ):
             pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")

-        if (flatten_axis < 0 and flatten_axis + len(input_shape) <= 0) or flatten_axis <= 0:
-            pytest.skip(
-                f"Flatten axis {flatten_axis} is not supported for input shape {input_shape}. There"
-                " must be at least one axis on either side of the flatten_axis split."
-            )
-
         key = jax.random.PRNGKey(0)
         input = jax.random.uniform(key, input_shape, in_dtype)
@@ -717,7 +795,7 @@ class TestFusedQuantize:
             q_layout=QuantizeLayout.ROWWISE,
         )

-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
     @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
     @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)

@@ -741,7 +819,7 @@ class TestFusedQuantize:
             q_layout=q_layout,
         )

-    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
     @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
     @pytest_parametrize_wrapper("input_shape", [s for s in ALL_ACTIVATION_SHAPES if is_shape_supported_by_mxfp8(s)])
@@ -810,7 +888,7 @@ class TestDense:
         assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
     @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)

@@ -852,7 +930,7 @@ class TestDense:
         assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16)
         assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16)

-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
     @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@@ -916,7 +994,7 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan
 class TestFusedDense:
-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
     @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
     @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)

@@ -1001,7 +1079,7 @@ class TestFusedDense:
         if beta is not None:
             assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype)

-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
     @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
     @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@@ -1129,24 +1207,6 @@ class TestFusedDense:
         assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)


-# This function is modified from transformer_engine/jax/cpp_extensions/gemm.py::_jax_gemm()
-def _quantize_gemm_pair(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer):
-    ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims
-    lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1
-    rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1
-    lhs_q = lhs_quantizer.quantize(
-        lhs,
-        is_rowwise=lhs_is_rowwise,
-        is_colwise=not lhs_is_rowwise,
-    )
-    rhs_q = rhs_quantizer.quantize(
-        rhs,
-        is_rowwise=rhs_is_rowwise,
-        is_colwise=not rhs_is_rowwise,
-    )
-    return lhs_q, rhs_q
-
-
 # E5M2 * E5M2 is not supported
 fwd_bwd_dtypes = [
     [jnp.float8_e4m3fn, jnp.float8_e4m3fn],
@@ -1154,219 +1214,217 @@ fwd_bwd_dtypes = [
     [jnp.float8_e5m2, jnp.float8_e4m3fn],
 ]

-@pytest_parametrize_wrapper(
-    "shape_list", [[(512, 128, 256), (256, 128, 256), (256, 128, 128), (512, 256, 128)]]
-)
-class TestGroupedDense:
-    def _ref_grouped_gemm_with_jnp_dot(self, lhs_list, rhs_list, contracting_dims_list):
-        ref_out_list = []
-        for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list):
-            dim_nums = (contracting_dims, ((), ()))
-            ref_out_list.append(jax.lax.dot_general(lhs, rhs, dim_nums))
-        return ref_out_list
-
-    def _generate_grouped_gemm_input(self, dtype, shape_list, layout_list):
-        key = jax.random.PRNGKey(0)
-        subkeys = jax.random.split(key, len(shape_list) * 2)
-        lhs_list, rhs_list, contracting_dims_list = [], [], []
-        for i, ((m, n, k), data_layout) in enumerate(zip(shape_list, layout_list)):
-            lhs = jax.random.uniform(
-                subkeys[2 * i],
-                (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
-                dtype=dtype,
-            )
-            rhs = jax.random.uniform(
-                subkeys[2 * i + 1],
-                (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
-                dtype=dtype,
-            )
-            lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
-            rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
-            contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
-            lhs_list.append(lhs)
-            rhs_list.append(rhs)
-            contracting_dims_list.append(contracting_dims)
-        return lhs_list, rhs_list, contracting_dims_list
+GROUPED_DENSE_INPUT_SHAPES = [
+    # (n_groups, m, n, k), the actual m will be multiplied by 32
+    (5, 32, 128, 64),
+    # Test the case where n_groups is not a multiple of 4
+    (8, 64, 32, 128),
+    (8, 64, 128, 256),
+]
+
+
+@pytest_parametrize_wrapper("input_shape", GROUPED_DENSE_INPUT_SHAPES)
+class TestGroupedDense:
+    def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims):
+        lhs_contract_dim, _ = contracting_dims
+        assert len(lhs_contract_dim) == 1 and lhs.ndim == 2 and rhs.ndim == 3
+        if bias is None:
+            bias = jnp.zeros((rhs.shape[0], rhs.shape[2]), dtype=lhs.dtype)
+        else:
+            assert bias.ndim == 2 and bias.shape == (rhs.shape[0], rhs.shape[2])
+        remaining_axis = (set(range(lhs.ndim)) - set(lhs_contract_dim)).pop()
+        lhs = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=remaining_axis)
+        rhs = jnp.split(rhs, rhs.shape[0], axis=0)
+        bias = jnp.split(bias, bias.shape[0], axis=0)
+        ref_out = []
+        dim_num = (contracting_dims, ((), ()))
+        for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias):
+            out_i = jax.lax.dot_general(lhs_i, rhs_i, dim_num) + jnp.expand_dims(bias_i, axis=0)
+            ref_out.append(jnp.squeeze(out_i))
+        return ref_out
+
+    def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False):
+        key = jax.random.PRNGKey(0)
+        subkeys = jax.random.split(key, 4)
+        n_groups, m, n, k = input_shape
+
+        group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
+        group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
+        group_sizes = jnp.diff(group_sizes)
+        assert group_sizes.sum() == m
+        # *32 to make sure that input shape works for MXFP8
+        group_sizes = group_sizes * 32
+        m = m * 32
+
+        lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m)
+        rhs_shape = (n_groups, k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k)
+        bias_shape = (n_groups, n)
+
+        lhs = jax.random.uniform(subkeys[1], lhs_shape, dtype=dtype)
+        rhs = jax.random.uniform(subkeys[2], rhs_shape, dtype=dtype)
+        bias = jax.random.uniform(subkeys[3], bias_shape, dtype=dtype) if with_bias else None
+
+        lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
+        rhs_contracting_dim = (1,) if data_layout[1] == "N" else (2,)
+        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
+
+        return lhs, rhs, group_sizes, contracting_dims, bias
+
+    def _assert_grouped_gemm_output(self, out, group_sizes, ref_list, dtype):
+        assert out.dtype == ref_list[0].dtype
+        out_list = jnp.split(out, jnp.cumulative_sum(group_sizes)[:-1], axis=0)
+        for i in range(len(ref_list)):
+            assert_allclose(out_list[i], ref_list[i], dtype=dtype)

     @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
-    @pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]])
-    def test_grouped_gemm_fp16(self, dtype, shape_list, layout_list):
-        lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input(
-            dtype, shape_list, layout_list
-        )
-        ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list)
-        primitive_out = tex.grouped_gemm(lhs_list, rhs_list, contracting_dims_list)
-        for i in range(len(shape_list)):
-            assert_allclose(primitive_out[i], ref_out[i], dtype=dtype)
+    @pytest_parametrize_wrapper("layout", ["NN"])
+    def test_grouped_gemm_fp16(self, dtype, input_shape, layout):
+        lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
+            dtype, input_shape, layout
+        )
+        ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
+        # grouped_gemm does not work with cudaGraph yet, so the jitting will breaks
+        # To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to
+        # disable cudaGraph, then use the following jitted function
+        # jitting grouped_gemm
+        # prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
+        #     lhs, rhs, group_sizes, contracting_dims,
+        # )
+        prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims)
+        self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)

-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
-    @pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]])
-    def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list, layout_list):
+    @pytest_parametrize_wrapper("layout", ["NN"])
+    def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout):
+        if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
+            pytest.skip("MXFP8 is not supported in grouped_gemm yet")
         fwd_dtype, bwd_dtype = fwd_bwd_dtype
         quantizer_set = QuantizerFactory.create_set(
-            scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=False
+            scaling_mode=scaling_mode,
+            fwd_dtype=fwd_dtype,
+            bwd_dtype=bwd_dtype,
+            is_2x2x=False,
+            n_groups=input_shape[0],
         )
+        # quantizer_set.{x, kernel} has fwd_dtype, while quantizer_set.grad has bwd_dtype
+        # We want to test E4M3 * E5M2, manually set the quantizer_set.kernel.q_dtype to bwd_dtype
+        quantizer_set.kernel.q_dtype = bwd_dtype
+        for quantizer in quantizer_set.kernel.quantizers:
+            quantizer.q_dtype = bwd_dtype
+
         out_dtype = jnp.bfloat16
-        lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input(
-            out_dtype, shape_list, layout_list
-        )
-        q_lhs_list = []
-        q_rhs_list = []
-        for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list):
-            # quantizer_set.x and quantizer_set.kernel have the same q_dtype, we want to
-            # test the case where lhs and rhs have different q_dtypes
-            q_lhs, q_rhs = _quantize_gemm_pair(
-                lhs, rhs, contracting_dims, quantizer_set.x, quantizer_set.dgrad
-            )
-            q_lhs_list.append(q_lhs)
-            q_rhs_list.append(q_rhs)
-        ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list)
-        primitive_out = tex.grouped_gemm(q_lhs_list, q_rhs_list, contracting_dims_list)
+        lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
+            out_dtype, input_shape, layout
+        )
+        ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
+        # jitting grouped_gemm
+        # prim_out = jax.jit(tex.grouped_gemm, static_argnames=('contracting_dims',))(
+        #     lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
+        # )
+        prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set)

         allclose_dtype = jnp.float8_e4m3fn
-        if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2:
+        if jnp.float8_e5m2 in fwd_bwd_dtype:
             allclose_dtype = jnp.float8_e5m2
-        for i in range(len(shape_list)):
-            assert_allclose(primitive_out[i], ref_out[i], dtype=allclose_dtype)
+        self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, allclose_dtype)

-    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
-    def test_grouped_dense_grad_fp16(self, dtype, shape_list):
-        group_size = len(shape_list)
-        layout_list = ["NN" for _ in range(group_size)]
-        x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input(
-            dtype, shape_list, layout_list
-        )
-        bias_list = []
-        key = jax.random.PRNGKey(1)
-        for shape in shape_list:
-            n = shape[1]
-            bias = jax.random.uniform(key, n, dtype=dtype)
-            bias_list.append(bias)
-
-        def ref_func(x_list, kernel_list, bias_list, contracting_dims_list):
-            out_list = []
-            for i in range(len(x_list)):
-                out_list.append(
-                    dense(
-                        x_list[i],
-                        kernel_list[i],
-                        bias_list[i],
-                        contracting_dims=contracting_dims_list[i],
-                    )
-                )
-            # Note: we use jnp.sum instead of jnp.mean to make the gradient larger
-            # and prevent them from being clamp to zero
-            out_sum_list = [jnp.sum(out) for out in out_list]
-            return jnp.sum(jnp.asarray(out_sum_list))
-
-        def primitive_func(x_list, kernel_list, bias_list, contracting_dims_list):
-            out_list = grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list)
-            out_sum_list = [jnp.sum(out) for out in out_list]
-            return jnp.sum(jnp.asarray(out_sum_list))
-
-        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
-        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
-
-        ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func(
-            x_list, kernel_list, bias_list, contracting_dims_list
-        )
-        primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = (
-            value_n_grad_primitive_func(x_list, kernel_list, bias_list, contracting_dims_list)
-        )
-
-        assert_allclose(primitive_out_mean, ref_out_mean, dtype=dtype)
-        for i in range(group_size):
-            assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=dtype)
-            assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=dtype)
-            assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=dtype)
+    def _ref_sum_grouped_dense(self, x, kernel, bias, group_sizes, contracting_dims):
+        out_list = self._ref_grouped_dense(x, kernel, bias, group_sizes, contracting_dims)
+        # Note: we use jnp.sum instead of jnp.mean to make the gradient larger
+        # and prevent them from being clamp to zero
+        out_sum_list = [jnp.sum(out) for out in out_list]
+        return jnp.sum(jnp.asarray(out_sum_list))
+
+    def _primitive_sum_grouped_dense(
+        self, x, kernel, bias, group_sizes, contracting_dims, quantizer_set=noop_quantizer_set
+    ):
+        out = grouped_dense(
+            x, kernel, group_sizes, contracting_dims, bias=bias, quantizer_set=quantizer_set
+        )
+        return jnp.sum(jnp.asarray(out))
+
+    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
+    def test_grouped_dense_grad_fp16(self, dtype, input_shape):
+        x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
+            dtype, input_shape, with_bias=True,
+        )
+        value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
+        # jitting the grouped_dense
+        # value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)),
+        #                              static_argnums=(4,))
+        value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))
+
+        ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
+            x, kernel, bias, group_sizes, contracting_dims
+        )
+        prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
+            x, kernel, bias, group_sizes, contracting_dims
+        )
+
+        assert_allclose(prim_out_sum, ref_out_sum, dtype=dtype)
+        assert_allclose(prim_dgrad, ref_dgrad, dtype=dtype)
+        assert_allclose(prim_wgrad, ref_wgrad, dtype=dtype)
+        assert_allclose(prim_dbias, ref_dbias, dtype=dtype)

-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-    @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
+    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
+    @pytest.mark.parametrize(
+        "fwd_bwd_dtype",
+        [(jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)],
+    )
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
-    def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list):
-        group_size = len(shape_list)
-        layout_list = ["NN" for _ in range(group_size)]
-        fwd_dtype, bwd_dtype = fwd_bwd_dtype
-        if fwd_dtype == jnp.float8_e5m2:
-            pytest.skip("We never use E5M2 for fwd_dtype in training")
-        # Question: should we use different quantizers for different groups?
-        ref_quantizer_set_list = []
-        quantizer_set_list = []
-        for _ in range(group_size):
-            ref_quantizer_set = QuantizerFactory.create_set(
-                scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True
-            )
-            ref_quantizer_set_list.append(ref_quantizer_set)
-            quantizer_set = QuantizerFactory.create_set(
-                scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True
-            )
-            quantizer_set_list.append(quantizer_set)
-        out_dtype = jnp.bfloat16
-        x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input(
-            out_dtype, shape_list, layout_list
-        )
-        bias_list = []
-        key = jax.random.PRNGKey(1)
-        for shape in shape_list:
-            n = shape[1]
-            bias = jax.random.uniform(key, n, dtype=out_dtype)
-            bias_list.append(bias)
-
-        def ref_func(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list):
-            out_list = []
-            for i in range(len(x_list)):
-                out_list.append(
-                    dense(
-                        x_list[i],
-                        kernel_list[i],
-                        bias_list[i],
-                        contracting_dims=contracting_dims_list[i],
-                        quantizer_set=quantizer_set_list[i],
-                    )
-                )
-            # Note: we use jnp.sum instead of jnp.mean to make the gradient larger
-            # and prevent them from being clamp to zero
-            out_sum_list = [jnp.sum(out) for out in out_list]
-            return jnp.sum(jnp.asarray(out_sum_list))
-
-        def primitive_func(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list):
-            out_list = grouped_dense(
-                x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
-            )
-            out_sum_list = [jnp.sum(out) for out in out_list]
-            return jnp.sum(jnp.asarray(out_sum_list))
+    def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
+        if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
+            pytest.skip("MXFP8 is not supported in grouped_dense yet")
+        fwd_dtype, bwd_dtype = fwd_bwd_dtype
+        dtype = jnp.bfloat16
+        x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
+            dtype, input_shape, with_bias=True,
+        )

-        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
-        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
-        ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func(
x_list, kernel_list, bias_list, contracting_dims_list, ref_quantizer_set_list
quantizer_set
=
QuantizerFactory
.
create_set
(
scaling_mode
=
scaling_mode
,
fwd_dtype
=
fwd_dtype
,
bwd_dtype
=
bwd_dtype
,
is_2x2x
=
True
,
n_groups
=
group_sizes
.
size
,
)
primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = (
value_n_grad_primitive_func(
x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
)
value_n_grad_ref_func
=
value_and_grad
(
self
.
_ref_sum_grouped_dense
,
(
0
,
1
,
2
))
# jitting the grouped_dense
# value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)),
# static_argnums=(4,))
value_n_grad_prim_func
=
value_and_grad
(
self
.
_primitive_sum_grouped_dense
,
(
0
,
1
,
2
))
ref_out_sum
,
(
ref_dgrad
,
ref_wgrad
,
ref_dbias
)
=
value_n_grad_ref_func
(
x
,
kernel
,
bias
,
group_sizes
,
contracting_dims
,
)
prim_out_sum
,
(
prim_dgrad
,
prim_wgrad
,
prim_dbias
)
=
value_n_grad_prim_func
(
x
,
kernel
,
bias
,
group_sizes
,
contracting_dims
,
quantizer_set
=
quantizer_set
)
allclose_dtype = jnp.float8_e4m3fn
if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2:
allclose_dtype = jnp.float8_e5m2
assert_allclose(primitive_out_mean, ref_out_mean, dtype=allclose_dtype)
for i in range(group_size):
assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=allclose_dtype)
assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=allclose_dtype)
assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=allclose_dtype)
"""
assert_allclose
(
prim_out_sum
,
ref_out_sum
,
dtype
=
fwd_dtype
)
assert_allclose
(
prim_dgrad
,
ref_dgrad
,
dtype
=
bwd_dtype
)
assert_allclose
(
prim_wgrad
,
ref_wgrad
,
dtype
=
bwd_dtype
)
assert_allclose
(
prim_dbias
,
ref_dbias
,
dtype
=
dtype
)
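The refactor above replaces the per-group Python lists with a single concatenated operand plus a group_sizes array. A minimal pure-JAX sketch of the grouped-dense semantics these tests check (shapes and names are assumptions inferred from the test code, not the TransformerEngine API):

import jax.numpy as jnp
import numpy as np

def ref_grouped_dense(x, kernel, bias, group_sizes):
    # x: (total_rows, k) holds the rows of all groups, concatenated;
    # kernel: (num_groups, k, n); bias: (num_groups, n)
    outs, start = [], 0
    for g, rows in enumerate(group_sizes):
        outs.append(x[start : start + rows] @ kernel[g] + bias[g])
        start += rows
    return jnp.concatenate(outs, axis=0)

x = jnp.ones((10, 4))
kernel = jnp.ones((2, 4, 3))
bias = jnp.zeros((2, 3))
out = ref_grouped_dense(x, kernel, bias, np.array([6, 4]))
print(out.shape)  # (10, 3)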
tests/jax/test_distributed_fused_attn.py View file @ 2b05e121
...
@@ -68,6 +68,7 @@ class TestDistributedSelfAttn:
         batch, seqlen, num_head, hidden = data_shape

         if not is_fused_attn_kernel_available(
+            is_training,
             dtype,
             dtype,
             QKVLayout.BS3HD,
...
@@ -79,6 +80,7 @@ class TestDistributedSelfAttn:
             seqlen,
             seqlen,
             hidden,
+            hidden,
             None,  # no window
         ):
             pytest.skip("No FusedAttn backend found")
...
@@ -98,6 +100,7 @@ class TestDistributedSelfAttn:
             num_head,
             num_head,
             hidden,
+            hidden,
             attn_bias_type,
             attn_mask_type,
             dropout_prob,
...
@@ -214,6 +217,7 @@ class TestDistributedCrossAttn:
         batch, seqlen, num_head, hidden = data_shape
         if not is_fused_attn_kernel_available(
+            is_training,
             dtype,
             dtype,
             QKVLayout.BSHD_BS2HD,
...
@@ -225,6 +229,7 @@ class TestDistributedCrossAttn:
             seqlen,
             seqlen,
             hidden,
+            hidden,
             None,  # no window
         ):
             pytest.skip("No FusedAttn backend found")
...
@@ -237,6 +242,7 @@ class TestDistributedCrossAttn:
             num_head,
             num_head,
             hidden,
+            hidden,
             attn_bias_type,
             attn_mask_type,
             dropout_prob,
...
@@ -289,6 +295,7 @@ class TestDistributedContextParallelSelfAttn:
         cp_strategy,
         use_shardy,
         use_scan_ring=False,
+        window_size=None,
    ):
        if qkv_layout.is_thd():
            if cp_strategy == CPStrategy.ALL_GATHER:
...
@@ -326,6 +333,7 @@ class TestDistributedContextParallelSelfAttn:
             num_head,
             num_kv_heads,
             hidden,
+            hidden,
             attn_bias_type,
             attn_mask_type,
             dropout_prob,
...
@@ -333,7 +341,7 @@ class TestDistributedContextParallelSelfAttn:
             is_training,
             qkv_layout,
             bias_shape,
-            None,
+            window_size,
             SeqDescFormat.SegmentIDs,
             number_of_devices=device_count,
             mesh_shape=mesh_shape,
...
@@ -345,6 +353,7 @@ class TestDistributedContextParallelSelfAttn:
         def check_has_backend_for_mask(mask_type):
             return is_fused_attn_kernel_available(
+                is_training,
                 dtype,
                 dtype,
                 qkv_layout,
...
@@ -356,6 +365,7 @@ class TestDistributedContextParallelSelfAttn:
                 seqlen,
                 seqlen,
                 hidden,
+                hidden,
                 None,
             )  # no SWA for CP
...
@@ -476,6 +486,13 @@ class TestDistributedContextParallelSelfAttn:
         "use_scan",
         [pytest.param(False, id="NO_SCAN"), pytest.param(True, id="USE_SCAN")],
     )
+    @pytest.mark.parametrize(
+        "window_size",
+        [
+            pytest.param((-1, -1), id="window_size(-1, -1)"),
+            pytest.param((20, 0), id="window_size(20, 0)"),
+        ],
+    )
     def test_context_parallel_ring_attn(
         self,
         device_count,
...
@@ -489,7 +506,15 @@ class TestDistributedContextParallelSelfAttn:
         qkv_layout,
         load_balanced,
         use_scan,
+        window_size,
     ):
+        if window_size != (-1, -1) and not qkv_layout.is_thd():
+            pytest.skip("Sliding window attention is only supported for THD layout")
+        if window_size != (-1, -1) and qkv_layout.is_thd() and use_scan:
+            pytest.skip(
+                "When context parallelism and sliding window attention are used, "
+                "scanloop is not supported"
+            )
         self.impl_test_context_parallel_attn(
             device_count,
             mesh_shape,
...
@@ -504,6 +529,7 @@ class TestDistributedContextParallelSelfAttn:
             CPStrategy.RING,
             use_shardy=False,
             use_scan_ring=use_scan,
+            window_size=window_size,
         )

     @pytest.mark.parametrize(
...
tests/jax/test_fused_attn.py View file @ 2b05e121
...
@@ -106,7 +106,8 @@ def general_dot_product_attention(
         softmax_out = softmax_out * multiplier

     context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)
-    context = jnp.reshape(context, query.shape)
+    context_shape = query.shape[:-1] + (value.shape[-1],)
+    context = jnp.reshape(context, context_shape)
     return context
...
@@ -294,7 +295,8 @@ class FusedAttnRunner:
     max_seqlen_kv: int
     num_heads_q: int
     num_heads_kv: int
-    head_dim: int
+    head_dim_qk: int
+    head_dim_v: int
     attn_bias_type: AttnBiasType
     attn_mask_type: AttnMaskType
     dropout_prob: float
...
@@ -346,7 +348,16 @@ class FusedAttnRunner:
                 "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
             )

+        # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
+        # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
+        if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate():
+            pytest.skip(
+                "For head_dim_qk != head_dim_v, it is necessary that the QKV layout "
+                "is either BSHD_BSHD_BSHD or THD_THD_THD"
+            )
+
         self.backend = FusedAttnHelper(
+            self.is_training,
             self.dtype,
             self.dtype,
             self.qkv_layout,
...
@@ -357,7 +368,8 @@ class FusedAttnRunner:
             self.num_heads_kv,
             self.max_seqlen_q,
             self.max_seqlen_kv,
-            self.head_dim,
+            self.head_dim_qk,
+            self.head_dim_v,
             (-1, -1) if self.window_size is None else self.window_size,
         ).get_fused_attn_backend()
         if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
...
@@ -390,13 +402,9 @@ class FusedAttnRunner:
         key = jax.random.PRNGKey(0)
         q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)

-        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
-        k_shape = v_shape = (
-            self.batch_size,
-            self.max_seqlen_kv,
-            self.num_heads_kv,
-            self.head_dim,
-        )
+        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim_qk)
+        k_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_qk)
+        v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_v)

         if self.attn_bias_type == AttnBiasType.NO_BIAS:
             bias_shape = None
...
@@ -615,7 +623,7 @@ class FusedAttnRunner:
                 raise ValueError(f"Unknown {self.seq_desc_format=}")

         self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
-        self.scaling_factor = 1.0 / sqrt(self.head_dim)
+        self.scaling_factor = 1.0 / sqrt(self.head_dim_qk)

-        # Setup distributed sharding specs
+        # Setup shardings for distributed tests
...
@@ -934,9 +942,31 @@ class FusedAttnRunner:
     ],
 )
 @pytest.mark.parametrize(
-    "b, s_q, s_kv, h_q, h_kv, d, dtype",
+    "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype",
     [
-        pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
+        pytest.param(
+            2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-64-BF16-SELF"
+        ),
+        pytest.param(
+            2, 2048, 1024, 12, 12, 64, 64, jnp.bfloat16, id="2-2048-1024-12-12-64-64-BF16-CROSS",
+        ),
+        pytest.param(
+            2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-64-BF16-GQA"
+        ),
+        pytest.param(
+            4, 128, 128, 16, 16, 64, 64, jnp.float16, id="4-128-128-16-16-64-64-FP16-SELF"
+        ),
+        pytest.param(
+            4, 128, 128, 16, 16, 64, 32, jnp.float16, id="4-128-128-16-16-64-32-FP16-SELF"
+        ),
         pytest.param(
             2,
             2048,
...
@@ -944,11 +974,13 @@ class FusedAttnRunner:
             12,
             12,
             64,
+            32,
             jnp.bfloat16,
-            id="2-2048-1024-12-12-64-BF16-CROSS",
+            id="2-2048-1024-12-12-64-32-BF16-CROSS",
         ),
+        pytest.param(
+            2, 2048, 2048, 12, 6, 128, 64, jnp.float16, id="2-2048-2048-12-6-128-64-FP16-GQA"
+        ),
-        pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
-        pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
     ],
 )
 @pytest.mark.parametrize(
...
@@ -1002,7 +1034,8 @@ class TestFusedAttn:
         s_kv,
         h_q,
         h_kv,
-        d,
+        d_qk,
+        d_v,
         attn_bias_type,
         attn_mask_type,
         dropout_prob,
...
@@ -1027,7 +1060,8 @@ class TestFusedAttn:
         s_kv,
         h_q,
         h_kv,
-        d,
+        d_qk,
+        d_v,
         attn_bias_type,
         attn_mask_type,
         dropout_prob,
...
@@ -1054,7 +1088,8 @@ class TestFusedAttn:
         s_kv,
         h_q,
         h_kv,
-        d,
+        d_qk,
+        d_v,
         attn_bias_type,
         attn_mask_type,
         dropout_prob,
...
@@ -1076,7 +1111,8 @@ class TestFusedAttn:
         s_kv,
         h_q,
         h_kv,
-        d,
+        d_qk,
+        d_v,
         attn_bias_type,
         attn_mask_type,
         dropout_prob,
...
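The head-dim changes above allow MLA-style attention, where Q and K use a different head dim than V; the attention output then takes its last axis from V rather than Q, which is exactly what the context_shape line in general_dot_product_attention encodes. A small illustration (the shapes are hypothetical):

import jax.numpy as jnp

q = jnp.zeros((2, 8, 4, 64))  # (batch, seqlen_q, heads, head_dim_qk)
v = jnp.zeros((2, 8, 4, 32))  # (batch, seqlen_kv, heads, head_dim_v)
context_shape = q.shape[:-1] + (v.shape[-1],)
print(context_shape)  # (2, 8, 4, 32)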
tests/pytorch/debug/conftest.py 0 → 100644 View file @ 2b05e121

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest


def pytest_addoption(parser):
    parser.addoption(
        "--feature_dirs", nargs="+", action="store", default="", help="List of feature directories"
    )
    parser.addoption(
        "--configs_dir",
        action="store",
        default="",
        type=str,
        help="Path to the directory with configs.",
    )


@pytest.fixture
def feature_dirs(request):
    return request.config.getoption("--feature_dirs")


@pytest.fixture
def configs_dir(request):
    return request.config.getoption("--configs_dir")
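These fixtures hand the command-line options through to the debug tests. A minimal sketch of a consuming test (the config file name here is a placeholder, not a file shipped with these tests):

import nvdlfw_inspect.api as debug_api

def test_example(configs_dir, feature_dirs):
    # configs_dir and feature_dirs come from the pytest command line;
    # "my_config.yaml" is hypothetical.
    debug_api.initialize(configs_dir + "my_config.yaml", feature_dirs=feature_dirs)
    try:
        ...  # exercise debug_api here
    finally:
        debug_api.end_debug()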
tests/pytorch/debug/run_distributed.py 0 → 100644 View file @ 2b05e121

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import tempfile
import functools
import os
import itertools
import random
import argparse
import re

import torch
import torch.distributed as dist

import transformer_engine
import transformer_engine_torch as tex
import nvdlfw_inspect.api as debug_api
from transformer_engine.debug import set_weight_tensor_tp_group_reduce

from test_numerics import (
    _emulate_linear,
    _init_debug,
    disable_fp8_gemms_create_config,
    DISABLE_FP8_LAYER_CONFIG,
    _cmp,
    IN_SIZE,
    OUT_SIZE,
    _init_model,
    SEED,
    SEQ_LEN,
    BATCH_SIZE,
    FP8_RECIPE,
    fake_quant_fp8_create_config,
    _get_current_scale,
    _prepare_per_tensor_scaling_config,
    AMAX_HISTORY_LEN,
    set_scaling_factors,
    set_current_scaling_factors,
)
WORLD_RANK, WORLD_SIZE = None, None
NCCL_WORLD = None
FEATURE_DIRS = None

all_boolean = [True, False]

TEST_NR = 0


def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None):
    if tp_size is None:
        tp_size = WORLD_SIZE
        tp_rank = WORLD_RANK
    torch.manual_seed(weight_seed)
    weight = torch.randn((OUT_SIZE, IN_SIZE)).cuda()
    torch.manual_seed(data_seed)
    in_split_size = IN_SIZE // tp_size
    out_split_size = OUT_SIZE // tp_size
    x = torch.randn((SEQ_LEN * BATCH_SIZE, IN_SIZE), requires_grad=True).cuda()
    if parallel_mode == "row":
        x = x[:, tp_rank * in_split_size : (tp_rank + 1) * in_split_size]
    x.retain_grad()
    with torch.no_grad():
        if parallel_mode == "column":
            weight = weight[tp_rank * out_split_size : (tp_rank + 1) * out_split_size, :]
        else:
            weight = weight[:, tp_rank * in_split_size : (tp_rank + 1) * in_split_size]
    return x, weight.contiguous()


# Overrides the _init_model imported from test_numerics to add TP support.
def _init_model(weight, parallel_mode=None, tp_group=None, name="linear"):
    model = transformer_engine.pytorch.Linear(
        IN_SIZE,
        OUT_SIZE,
        name=name,
        parallel_mode=parallel_mode,
        tp_group=(tp_group or NCCL_WORLD if parallel_mode else None),
    )
    with torch.no_grad():
        model.weight.copy_(weight)
    return model


class AllGather(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, dim, group=None):
        if group is None:
            world_size = torch.distributed.get_world_size()
            rank = torch.distributed.get_rank()
        else:
            world_size = torch.distributed.get_world_size(group=group)
            rank = torch.distributed.get_rank(group=group)
        dist.barrier()

        # Create a list to gather tensors from all processes
        y_list = [torch.zeros_like(tensor) for _ in range(world_size)]
        torch.distributed.all_gather(y_list, tensor, group=group)

        # Save the world size and rank for backward computation
        ctx.world_size = world_size
        ctx.rank = rank
        ctx.dim = dim

        # Concatenate the gathered tensors along the feature dimension
        y_full = torch.cat(y_list, dim=dim)

        return y_full

    @staticmethod
    def backward(ctx, grad_output):
        # Split the gradient output and return the portion corresponding to this rank
        grad_input = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)[ctx.rank]
        return grad_input, None, None


def _run_forward_backward(x, model, parallel_mode=None, group=None):
    with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
        y = model(x)
    y.requires_grad_(True)
    y.retain_grad()
    if parallel_mode == "column":
        y = AllGather.apply(y, -1, group)
        y.requires_grad_(True)
        y.retain_grad()
        l = y.sum()
        l.backward()
    elif parallel_mode == "row":
        l = y.sum()
        l.backward()
    debug_api.step()
    return y
def _emulate_linear_distributed(*args, parallel_mode=None, **kwargs):
    assert parallel_mode in ["column", "row"]

    def split(gradient):
        split_size = OUT_SIZE // WORLD_SIZE
        gradient = gradient[:, WORLD_RANK * split_size : (WORLD_RANK + 1) * split_size]
        return gradient

    activation_sync = None
    gradient_sync = None
    if parallel_mode == "column":
        activation_sync = lambda x: AllGather.apply(x, -1)
        gradient_sync = split
    else:
        activation_sync = (
            lambda activation: dist.all_reduce(activation, op=dist.ReduceOp.SUM) or activation
        )

    output = _emulate_linear(
        *args, activation_sync=activation_sync, gradient_sync=gradient_sync, **kwargs
    )

    if parallel_mode == "column":
        dist.all_reduce(output["dgrad"], op=dist.ReduceOp.SUM)

    return output


def check_debug_log(msg):
    with open(f"log/debug_logs/debug_log_globalrank-{WORLD_RANK}.log", "r") as f:
        for line in f.readlines():
            if msg in line:
                return True
    return False
def run_debug_test(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank = dist.get_rank()
        temp_file_name = None
        temp_logdir_name = None
        if rank == 0:
            with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
                temp_file_name = temp_file.name
            temp_dir_obj = tempfile.TemporaryDirectory()
            temp_logdir_name = temp_dir_obj.name
            # Store the TemporaryDirectory object to prevent it from being deleted
            wrapper.temp_dir_obj = temp_dir_obj

        temp_file_name_list = [temp_file_name]
        temp_logdir_name_list = [temp_logdir_name]

        # Broadcast the temporary file and directory names to all processes
        dist.broadcast_object_list(temp_file_name_list, src=0)
        dist.broadcast_object_list(temp_logdir_name_list, src=0)
        temp_file_name = temp_file_name_list[0]
        temp_logdir_name = temp_logdir_name_list[0]
        dist.barrier()

        config_file = open(temp_file_name, mode="r+", buffering=1)

        try:
            kwargs["config_file"] = config_file
            kwargs["log_dir"] = temp_logdir_name
            if rank == 0:
                global TEST_NR
                print(f"Running test {TEST_NR} {func.__name__} with args = {args}.")
                TEST_NR += 1
            func(*args, **kwargs)
        finally:
            if rank == 0 and temp_file_name is not None:
                os.unlink(temp_file_name)
            debug_api.end_debug()
            if rank == 0 and hasattr(wrapper, "temp_dir_obj"):
                wrapper.temp_dir_obj.cleanup()

    return wrapper
CONFIG_LOG_TEST_DISTRIBUTED = """log_distributed:
  layers:
    layer_types: [linear]
  enabled:
    True
  transformer_engine:
    LogTensorStats:
      enabled: True
      tensors: [activation, gradient, weight, output, wgrad, dgrad]
      stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
      start_step : 0
      end_step: 1
    LogFp8TensorStats:
      enabled: True
      tensors: [activation, gradient, weight]
      stats: [underflows%]
      start_step : 0
      end_step: 1
"""


def _prepare_config_test_log_distributed(config_file):
    if WORLD_RANK != 0:
        return
    config_file.write(CONFIG_LOG_TEST_DISTRIBUTED)
    config_file.flush()


def _compute_dynamic_range(tensor):
    tensor_abs = tensor.abs()
    tensor_abs = tensor_abs[tensor_abs != 0]
    if tensor_abs.any():
        amin = tensor_abs.min().float()
    else:
        amin = torch.tensor(1, device=tensor.device).to(torch.float)
    amax = tensor_abs.max().float()
    if not amax.all():
        amax = torch.tensor(1, device=tensor.device).to(torch.float)
    dynamic_range = torch.log2(amax) - torch.log2(amin)
    return dynamic_range
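
# Worked example for _compute_dynamic_range (illustrative values, not part of
# the test flow): for the nonzero magnitudes {0.25, 8.0} the statistic is
# log2(8.0) - log2(0.25) = 3 - (-2) = 5, i.e. the tensor spans 5 powers of two.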
@run_debug_test
def test_log_distributed(parallel_mode, gather_weight, **kwargs):
    _prepare_config_test_log_distributed(kwargs["config_file"])
    _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
    set_weight_tensor_tp_group_reduce(gather_weight)
    if WORLD_SIZE % 2 != 0:
        return  # skip
    TP_SIZE = WORLD_SIZE // 2
    DP_SIZE = 2
    TP_RANK = WORLD_RANK % TP_SIZE
    DP_RANK = (WORLD_RANK - TP_RANK) // TP_SIZE
    debug_api.set_tensor_reduction_group(NCCL_WORLD)

    x, weight = _get_tensors(
        parallel_mode,
        weight_seed=TP_RANK * 1234,
        data_seed=DP_RANK * 1234,
        tp_size=TP_SIZE,
        tp_rank=TP_RANK,
    )

    tp_group_ranks = [i for i in range(DP_RANK * TP_SIZE, (DP_RANK + 1) * TP_SIZE)]
    tp_group = dist.new_group(ranks=tp_group_ranks)
    dp_group_ranks = [i for i in range(TP_RANK, WORLD_SIZE, TP_SIZE)]
    dp_group = dist.new_group(ranks=dp_group_ranks)

    model = _init_model(weight, parallel_mode=parallel_mode, tp_group=tp_group)
    output = _run_forward_backward(x, model, parallel_mode=parallel_mode, group=tp_group)

    gathered_activation = AllGather.apply(x.contiguous(), 0)
    gathered_weight = AllGather.apply(weight.contiguous(), 0, tp_group)
    gathered_gradient = AllGather.apply(output.grad.contiguous(), 0, dp_group)
    if parallel_mode == "row":
        gathered_gradient = AllGather.apply(gathered_gradient, 0, tp_group)

    log_file = kwargs["log_dir"] + "/nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log"

    dist.barrier()
    if WORLD_RANK != 0:
        return  # stats are gathered on node 0

    with open(log_file) as f:
        content = f.read()

    def get_stat(tensor, stat):
        regex = r".*_{tensor}_{stat}\s+.*iteration=(\d+)\s+.*value=([-+]?\d*\.?\d+)".format(
            tensor=tensor, stat=stat
        )
        for line in content.splitlines():
            match = re.search(regex, line)
            if match:
                value = float(match.group(2))
                return value

    rf = lambda x: round(float(x), 4)

    tensors = {
        "activation": gathered_activation,
        "weight": gathered_weight if gather_weight else weight,
        "gradient": gathered_gradient,
    }
    stats = {
        "min": torch.min,
        "max": torch.max,
        "mean": torch.mean,
        "std": torch.std,
        "l1_norm": lambda x: torch.norm(x, p=1),
        "l2_norm": lambda x: torch.norm(x, p=2),
        "cur_amax": lambda x: x.abs().max(),
        "dynamic_range": _compute_dynamic_range,
    }
    for stat_key in stats.keys():
        for tensor_key in tensors.keys():
            torch.testing.assert_close(
                get_stat(tensor_key, stat_key),
                rf(stats[stat_key](tensors[tensor_key])),
                atol=0.0001,
                rtol=0.0001,
            )
    set_weight_tensor_tp_group_reduce(True)  # reset
@run_debug_test
def test_log_expert_parallel(**kwargs):
    """
    This test covers the scenario in which one of the data-parallel ranks does not invoke the debug layer.
    This occurs naturally with expert parallelism, when an expert receives no input on one rank
    but does receive input on the other ranks. If there were an all_gather inside forward(), this would deadlock.
    """
    _prepare_config_test_log_distributed(kwargs["config_file"])
    _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
    debug_api.set_tensor_reduction_group(NCCL_WORLD)
    x, weight = _get_tensors(
        "row", weight_seed=WORLD_RANK * 1234, data_seed=WORLD_RANK * 1234, tp_size=1, tp_rank=0
    )  # data parallel
    model = _init_model(weight, parallel_mode=None, name="linear1")
    model1 = _init_model(weight, parallel_mode=None, name="linear2")
    with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
        y1 = model(x)
        y2 = model1(x)
    y = y1 + y2
    y.sum().backward()
    debug_api.step()
    with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
        y = model(x)
        if WORLD_RANK != 0:
            y = y + model1(x)
    y.sum().backward()
@run_debug_test
def test_disable_fp8_gemms(fprop_fp8, dgrad_fp8, wgrad_fp8, parallel_mode, **kwargs):
    disable_fp8_gemms_create_config(fprop_fp8, dgrad_fp8, wgrad_fp8, kwargs["config_file"])
    fp8_kwargs = {
        "fprop_fp8": fprop_fp8,
        "dgrad_fp8": dgrad_fp8,
        "wgrad_fp8": wgrad_fp8,
    }
    _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
    x, weight = _get_tensors(parallel_mode)
    model = _init_model(weight, parallel_mode=parallel_mode)
    y = _run_forward_backward(x, model, parallel_mode=parallel_mode)

    output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()}
    x.grad.zero_()
    ground_truth = _emulate_linear_distributed(x, weight, parallel_mode=parallel_mode, **fp8_kwargs)
    _cmp(ground_truth, output)


@run_debug_test
def test_disable_fp8_layer(parallel_mode, **kwargs):
    if WORLD_RANK == 0:
        kwargs["config_file"].write(DISABLE_FP8_LAYER_CONFIG)
        kwargs["config_file"].flush()
    dist.barrier()

    x, weight = _get_tensors(parallel_mode)
    ground_truth = _emulate_linear_distributed(x, weight, parallel_mode=parallel_mode)
    x.grad.zero_()
    _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
    model = _init_model(weight, parallel_mode)
    y = _run_forward_backward(x, model, parallel_mode)
    output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()}
    _cmp(ground_truth, output)
@run_debug_test
def test_per_tensor_scaling(
    fprop_inp,
    fprop_weight,
    dgrad_weight,
    dgrad_grad,
    wgrad_input,
    wgrad_grad,
    parallel_mode,
    **kwargs,
):
    """
    Runs a test to validate per-tensor (current) scaling in FP8 computations.

    The function performs warm-up iterations to populate the amax buffer of the model and compute scaling factors based on delayed scaling.
    Subsequently, weights and inputs are switched so that their current scaling factors differ from those based on delayed scaling;
    similarly, the loss is multiplied by a large factor to alter the gradient's magnitude,
    creating a discrepancy between the original (delayed) and per-tensor (current) scaling factors.
    Finally, a linear pass is emulated, and the results are compared.
    """
    input_kwargs = {
        "fprop_inp": fprop_inp,
        "fprop_weight": fprop_weight,
        "dgrad_weight": dgrad_weight,
        "dgrad_grad": dgrad_grad,
        "wgrad_input": wgrad_input,
        "wgrad_grad": wgrad_grad,
    }
    fp8_kwargs = {
        "fprop_fp8": True,
        "dgrad_fp8": True,
        "wgrad_fp8": True,
    }
    _prepare_per_tensor_scaling_config(
        fprop_inp,
        fprop_weight,
        dgrad_weight,
        dgrad_grad,
        wgrad_input,
        wgrad_grad,
        kwargs["config_file"],
    )
    _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)

    warmup_input, warmup_weight = _get_tensors(parallel_mode=parallel_mode)
    model = _init_model(warmup_weight, parallel_mode=parallel_mode)

    # Warmup run to set up amax and scaling factors.
    for _ in range(AMAX_HISTORY_LEN):
        _run_forward_backward(warmup_input, model, parallel_mode=parallel_mode)

    x, weight = _get_tensors(
        parallel_mode=parallel_mode, weight_seed=WORLD_RANK * 2137, data_seed=WORLD_RANK * 2137
    )

    model.weight.data = weight.data
    x.retain_grad()

    # Delayed scaling factors
    # need to be collected before the forward pass with the test data,
    # because this forward pass changes the scaling factors.
    set_scaling_factors(model, input_kwargs, fp8_kwargs)

    LOSS_MULTIPLIER = 100

    with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
        y = model(x)
        model.zero_grad()

    if parallel_mode == "column":
        y = AllGather.apply(y, -1)
    y.retain_grad()
    # Loss multiplication to change the gradient's order of magnitude
    (LOSS_MULTIPLIER * y.sum()).backward()

    output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()}

    # Per-tensor (current) scaling factors
    # need to be collected after the forward pass with the test data,
    # because the gradient (y.grad) cannot be accessed before the forward pass,
    # but it needs to be collected.
    set_current_scaling_factors(x, weight, y, input_kwargs, fp8_kwargs)

    ground_truth = _emulate_linear_distributed(
        x, weight, parallel_mode=parallel_mode, loss_multiplier=LOSS_MULTIPLIER, **fp8_kwargs
    )

    _cmp(ground_truth, output)
@run_debug_test
def test_fake_quant_fp8(
    fprop_inp,
    fprop_weight,
    dgrad_weight,
    dgrad_grad,
    wgrad_input,
    wgrad_grad,
    parallel_mode,
    **kwargs,
):
    fp8_kwargs = {
        "fprop_input_fake_quant": fprop_inp,
        "fprop_weight_fake_quant": fprop_weight,
        "dgrad_gradient_fake_quant": dgrad_grad,
        "dgrad_weight_fake_quant": dgrad_weight,
        "wgrad_gradient_fake_quant": wgrad_grad,
        "wgrad_input_fake_quant": wgrad_input,
        "fprop_fp8": not (fprop_inp or fprop_weight),
        "dgrad_fp8": not (dgrad_weight or dgrad_grad),
        "wgrad_fp8": not (wgrad_grad or wgrad_input),
    }
    if WORLD_RANK == 0:
        fake_quant_fp8_create_config(
            fprop_inp,
            fprop_weight,
            dgrad_weight,
            dgrad_grad,
            wgrad_input,
            wgrad_grad,
            kwargs["config_file"],
        )
    dist.barrier()

    _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
    x, weight = _get_tensors(parallel_mode)
    model = _init_model(weight, parallel_mode)
    y = _run_forward_backward(x, model, parallel_mode)
    output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()}

    fp8_kwargs["fprop_input_scale"] = (
        _get_current_scale(x, fprop_inp) if not fp8_kwargs["fprop_fp8"] else None
    )
    fp8_kwargs["fprop_weight_scale"] = (
        _get_current_scale(weight, fprop_weight) if not fp8_kwargs["fprop_fp8"] else None
    )
    fp8_kwargs["dgrad_gradient_scale"] = (
        _get_current_scale(y.grad, dgrad_grad) if not fp8_kwargs["dgrad_fp8"] else None
    )
    fp8_kwargs["dgrad_weight_scale"] = (
        _get_current_scale(weight, dgrad_weight) if not fp8_kwargs["dgrad_fp8"] else None
    )
    fp8_kwargs["wgrad_gradient_scale"] = (
        _get_current_scale(y.grad, wgrad_grad) if not fp8_kwargs["wgrad_fp8"] else None
    )
    fp8_kwargs["wgrad_input_scale"] = (
        _get_current_scale(x, wgrad_input) if not fp8_kwargs["wgrad_fp8"] else None
    )
    ground_truth = _emulate_linear_distributed(x, weight, parallel_mode=parallel_mode, **fp8_kwargs)
    _cmp(ground_truth, output)
def _init_distributed():
    global WORLD_RANK, WORLD_SIZE, NCCL_WORLD, FP8

    WORLD_RANK = int(os.getenv("RANK", "0"))
    WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
    LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
    LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
    assert WORLD_SIZE == LOCAL_SIZE  # this test supports only 1 node
    assert LOCAL_SIZE <= torch.cuda.device_count()
    dist_init_kwargs = {
        "backend": "nccl",
        "rank": WORLD_RANK,
        "world_size": WORLD_SIZE,
    }
    dist_init_kwargs["init_method"] = "env://"
    dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
    assert dist.is_nccl_available()
    torch.cuda.set_device(LOCAL_RANK)
    dist.init_process_group(**dist_init_kwargs)

    NCCL_WORLD = dist.new_group(backend="nccl")
    WORLD_SIZE = dist.get_world_size()


def _run_test_with_combinations(
    test_function, values_list, num_repeat, extra_args, sample_size=None
):
    combinations = itertools.product(values_list, repeat=num_repeat)
    total_combinations = itertools.product(combinations, extra_args)
    if sample_size is not None:
        total_combinations = random.sample(list(total_combinations), sample_size)
    for comb, arg in total_combinations:
        test_function(*comb, arg)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--feature_dirs", type=str)
    args = parser.parse_args()

    FEATURE_DIRS = args.feature_dirs

    random.seed(SEED)

    _init_distributed()

    test_log_expert_parallel()
    for parallel_mode in ["column", "row"]:
        for gather_weight in [True, False]:
            test_log_distributed(parallel_mode, gather_weight)

    for parallel_mode in ["row", "column"]:
        test_disable_fp8_layer(parallel_mode)

    # test_disable_fp8_gemms
    _run_test_with_combinations(
        test_disable_fp8_gemms, all_boolean, num_repeat=3, extra_args=["column", "row"]
    )

    # test_fake_quant_fp8
    dtype_options = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2, None]
    _run_test_with_combinations(
        test_fake_quant_fp8,
        dtype_options,
        num_repeat=6,
        extra_args=["column", "row"],
        sample_size=20,
    )

    _run_test_with_combinations(
        test_per_tensor_scaling,
        all_boolean,
        num_repeat=6,
        extra_args=["column"],
        sample_size=20,
    )
tests/pytorch/debug/test_api_features.py 0 → 100644 View file @ 2b05e121

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import torch

from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer

import nvdlfw_inspect.api as debug_api

try:
    import transformer_engine
    import transformer_engine_torch as tex
except (ImportError, ModuleNotFoundError):
    print("Could not find TransformerEngine package.")
    exit(1)


def test_transformer_engine_no_config(feature_dirs):
    debug_api.initialize("", feature_dirs=feature_dirs)
    try:
        tensor = torch.rand(24, 2046).cuda()

        # FP8 enabled - True by default
        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="fprop", iteration=0
        )
        # modify_tensor_enabled - False by default
        assert not debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
        )
        # inspect_tensor_enabled - False by default
        assert not debug_api.transformer_engine.inspect_tensor_enabled(
            "decoder.1.attn.qkv", tensor_name="activation", iteration=0
        )
        # inspect_tensor_postquantize - False by default
        assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
            "decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
        )
    finally:
        debug_api.end_debug()
def test_disable_fp8_gemm(configs_dir, feature_dirs):
    try:
        debug_api.initialize(configs_dir + "disable_fp8_gemms.yaml", feature_dirs=feature_dirs)

        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="fprop", iteration=0
        )
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="dgrad", iteration=0
        )
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="wgrad", iteration=0
        )

        # caching
        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="fprop", iteration=0
        )
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="dgrad", iteration=0
        )
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="wgrad", iteration=0
        )
    finally:
        debug_api.end_debug()


def test_disable_fp8_layer(configs_dir, feature_dirs):
    try:
        debug_api.initialize(configs_dir + "disable_fp8_layer.yaml", feature_dirs=feature_dirs)

        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.mlp.fc1", gemm="fprop", iteration=0
        )
        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.mlp.fc1", gemm="wgrad", iteration=0
        )
        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.mlp.fc1", gemm="dgrad", iteration=0
        )

        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="fprop", iteration=0
        )
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="wgrad", iteration=0
        )
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="dgrad", iteration=0
        )
    finally:
        debug_api.end_debug()
def test_per_tensor_scaling(configs_dir, feature_dirs):
    try:
        debug_api.initialize(configs_dir + "per_tensor_scaling.yaml", feature_dirs=feature_dirs)

        tensor = torch.rand(24, 2046).cuda()

        # check modify_tensor_enabled
        assert debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0
        )
        assert debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="fprop", tensor_name="weight", iteration=0
        )
        assert debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0
        )
        assert not debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="weight", iteration=0
        )
        assert not debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="wgrad", tensor_name="gradient", iteration=0
        )
        assert not debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="wgrad", tensor_name="activation", iteration=0
        )

        # check modify_tensor
        default_quantizer1 = Float8Quantizer(
            scale=torch.tensor([1]).cuda(),
            amax=torch.tensor([0]).cuda(),
            fp8_dtype=tex.DType.kFloat8E4M3,
        )
        default_quantizer2 = Float8Quantizer(
            scale=torch.tensor([1]).cuda(),
            amax=torch.tensor([0]).cuda(),
            fp8_dtype=tex.DType.kFloat8E5M2,
        )

        output1 = debug_api.transformer_engine.modify_tensor(
            layer_name="decoder.1.mlp.fc1",
            gemm="fprop",
            tensor_name="activation",
            default_quantizer=default_quantizer1,
            iteration=0,
            tensor=tensor,
        )
        assert type(output1) == Float8Tensor
        assert output1._fp8_dtype == tex.DType.kFloat8E4M3

        output2 = debug_api.transformer_engine.modify_tensor(
            "decoder.1.mlp.fc1",
            gemm="dgrad",
            tensor=tensor,
            tensor_name="gradient",
            default_quantizer=default_quantizer2,
            iteration=0,
        )
        assert type(output2) == Float8Tensor
        assert output2._fp8_dtype == tex.DType.kFloat8E5M2

        assert not debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1",
            gemm="wgrad",
            tensor_name="gradient",
            iteration=0,
        )
        assert not debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc4",
            gemm="fprop",
            tensor_name="activation",
            iteration=0,
        )
    finally:
        debug_api.end_debug()
def test_fake_quant(configs_dir, feature_dirs):
    try:
        debug_api.initialize(
            configs_dir + "fake_quantization_config.yaml", feature_dirs=feature_dirs
        )

        tensor = torch.rand(24, 2046).cuda()

        # modify_tensor_enabled
        assert debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0
        )
        assert debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0
        )

        # modify_tensor
        debug_api.transformer_engine.modify_tensor(
            "decoder.1.mlp.fc1",
            gemm="fprop",
            tensor=tensor,
            tensor_name="activation",
            iteration=0,
            default_quantizer=None,
        )
        debug_api.transformer_engine.modify_tensor(
            "decoder.1.mlp.fc1",
            gemm="dgrad",
            tensor=tensor,
            tensor_name="gradient",
            iteration=0,
            default_quantizer=None,
        )

        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.fc2", gemm="wgrad", iteration=0
        )
        # caching
        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.fc2", gemm="wgrad", iteration=0
        )
    finally:
        debug_api.end_debug()
def test_statistics_collection(configs_dir, feature_dirs):
    try:
        debug_api.initialize(
            config_file=configs_dir + "stats_collection_test_config.yaml",
            feature_dirs=feature_dirs,
            default_logging_enabled=False,
        )

        tensor = torch.randn((100, 100, 5)).cuda()
        tensor_fp8 = Float8Tensor(
            data=tensor.to(torch.uint8).cuda(),
            fp8_scale_inv=torch.full([1], 1.0).cuda(),
            fp8_dtype=tex.DType.kFloat8E4M3,
            shape=tensor.shape,
            dtype=torch.float32,
        )

        def log():
            from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS

            return STATS_BUFFERS.log_stats()

        def assert_empty():
            stats = log()
            assert len(stats) == 0

        # TE tensor stats --
        debug_api.transformer_engine.inspect_tensor(
            "decoder.1.mlp.fc1",
            tensor=tensor,
            tensor_name="activation",
            iteration=200,
            tp_group=None,
        )
        stats = log()
        assert stats[("decoder.1.mlp.fc1", "activation", "cur_amax", 200)] == tensor.abs().max()

        assert not debug_api.transformer_engine.inspect_tensor_enabled(
            "decoder.1.mlp.fc1", tensor_name="activation", iteration=201
        )
        assert not debug_api.transformer_engine.inspect_tensor_enabled(
            "decoder.2.mlp.fc1", tensor_name="activation", iteration=200
        )
        assert not debug_api.transformer_engine.inspect_tensor_enabled(
            "decoder.1.mlp.fc1", tensor_name="gradient", iteration=200
        )

        expected_underflows = (tensor_fp8._data == 0).sum() * 100 / (100 * 100 * 5)
        expected_overflows = (tensor_fp8._data == 126).sum() * 100 / (100 * 100 * 5)

        # TE FP8 tensor stats --
        assert debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
            "decoder.1.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200
        )
        debug_api.transformer_engine.inspect_tensor_postquantize(
            "decoder.1.mlp.fc1",
            tensor=tensor_fp8,
            tensor_name="gradient",
            iteration=200,
            rowwise=True,
            tp_group=None,
        )
        stats = log()
        torch.testing.assert_close(
            stats[("decoder.1.mlp.fc1", "gradient", "underflows%", 200)], expected_underflows
        )

        assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
            "decoder.1.mlp.fc1", tensor_name="activation", gemm="fprop", iteration=201
        )
        assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
            "decoder.2.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200
        )

        # Second config in same yaml
        tensor = torch.rand((100, 100, 5))
        debug_api.transformer_engine.inspect_tensor(
            "decoder.6.mlp.fc1",
            tensor=tensor,
            tensor_name="activation",
            iteration=200,
            tp_group=None,
        )
        stats = log()
        stats_names = [x[3] for x in stats.keys()]
        all(s in stats_names for s in ["cur_amax", "dynamic_range", "mean", "std", "l1_norm"])
        assert stats[("decoder.6.mlp.fc1", "activation", "mean", 200)] == tensor.mean()

        debug_api.transformer_engine.inspect_tensor(
            "decoder.7.mlp.fc1",
            tensor=tensor,
            tensor_name="weight",
            iteration=200,
            tp_group=None,
        )
        stats = log()
        stats_names = [x[3] for x in stats.keys()]
        all(s in stats_names for s in ["mean", "std", "l1_norm", "min", "max"])
        assert stats[("decoder.7.mlp.fc1", "weight", "max", 200)] == tensor.max()

        assert not debug_api.transformer_engine.inspect_tensor_enabled(
            "decoder.7.mlp.fc1", tensor_name="weight", iteration=201
        )
        assert_empty()
    finally:
        debug_api.end_debug()
def test_statistics_multi_run(configs_dir, feature_dirs):
    try:
        debug_api.initialize(
            config_file=configs_dir + "stats_collection_test_config.yaml",
            feature_dirs=feature_dirs,
            default_logging_enabled=False,
        )

        def feed(tensor, tensor_fp8):
            debug_api.transformer_engine.inspect_tensor(
                "decoder.5.mlp.fc1",
                tensor=tensor,
                tensor_name="activation",
                iteration=1,
                tp_group=None,
            )
            debug_api.transformer_engine.inspect_tensor_postquantize(
                "decoder.5.mlp.fc1",
                tensor=tensor_fp8,
                tensor_name="activation",
                iteration=1,
                rowwise=True,
                tp_group=None,
            )

        def log_stats():
            from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS

            return STATS_BUFFERS.log_stats()

        def fp8_tensor(t):
            return Float8Tensor(
                data=t.to(torch.uint8).cuda(),
                fp8_scale_inv=torch.ones([1]).cuda(),
                fp8_dtype=tex.DType.kFloat8E4M3,
                shape=t.shape,
                dtype=torch.float32,
            )

        shape = [1024, 1024]

        tensors = [torch.randn(shape) for _ in range(2)]
        tensors_fp8 = [fp8_tensor(tensors[i]) for i in range(2)]

        feed(tensors[0], tensors_fp8[0])
        feed(tensors[1], tensors_fp8[1])
        stats1 = log_stats()

        tensor2 = torch.cat((tensors[0], tensors[1])).cuda()
        fp8tensor2 = fp8_tensor(tensor2)
        feed(tensor2, fp8tensor2)
        stats2 = log_stats()

        assert len(stats1.keys()) > 0
        for k in stats1.keys():
            torch.testing.assert_close(stats1[k], stats2[k])
    finally:
        debug_api.end_debug()


if __name__ == "__main__":
    pass
tests/pytorch/debug/test_config.py 0 → 100644 View file @ 2b05e121

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pathlib, os

from nvdlfw_inspect.config_manager import ConfigManager

import nvdlfw_inspect.api as debug_api

try:
    import transformer_engine
    from transformer_engine.debug.features.api import TEConfigAPIMapper
except (ImportError, ModuleNotFoundError):
    print("Could not find TransformerEngine debug module.")
    exit(1)


def test_transformer_engine_config_parsing(feature_dirs):
    debug_api.initialize(
        config_file=pathlib.Path(__file__).resolve().parent
        / "test_configs/tensor_manipulation_transformer_engine.yaml",
        feature_dirs=feature_dirs,
        log_dir="./log",
    )

    cfg_fc1 = ConfigManager.get_config_for_layer("decoder.1.mlp.fc1")["transformer_engine"]
    cfg_fc2 = ConfigManager.get_config_for_layer("decoder.1.mlp.fc2")["transformer_engine"]
    assert cfg_fc1 and cfg_fc2

    gemm_parsing = True
    tensor_parsing = True

    # Per-tensor scaling is set for dgrad; filter based on gemm
    ret, _ = TEConfigAPIMapper().parse_config_and_api(
        cfg_fc1["PerTensorScaling"],
        gemm_parsing=gemm_parsing,
        tensor_parsing=tensor_parsing,
        gemm="wgrad",
        tensor_name="activation",
    )
    assert not ret

    # Per-tensor scaling is set for gradient; filter based on tensor name
    ret, _ = TEConfigAPIMapper().parse_config_and_api(
        cfg_fc1["PerTensorScaling"],
        gemm_parsing=gemm_parsing,
        tensor_parsing=tensor_parsing,
        gemm="dgrad",
        tensor_name="activation",
    )
    assert not ret

    ret, parsed_cfg_fc1 = TEConfigAPIMapper().parse_config_and_api(
        cfg_fc1["PerTensorScaling"],
        gemm_parsing=gemm_parsing,
        tensor_parsing=tensor_parsing,
        gemm="dgrad",
        tensor_name="gradient",
    )
    assert ret
    assert parsed_cfg_fc1 == {"gemm": "dgrad", "tensor": "gradient"}

    # Test tensor struct
    ret, parsed_cfg_fc1_act = TEConfigAPIMapper().parse_config_and_api(
        cfg_fc1["FakeQuant"],
        gemm_parsing=gemm_parsing,
        tensor_parsing=tensor_parsing,
        gemm="fprop",
        tensor_name="activation",
    )
    ret, parsed_cfg_fc1_wei = TEConfigAPIMapper().parse_config_and_api(
        cfg_fc1["FakeQuant"],
        gemm_parsing=gemm_parsing,
        tensor_parsing=tensor_parsing,
        gemm="fprop",
        tensor_name="weight",
    )
    assert ret
    assert parsed_cfg_fc1_act == {
        "gemm": "fprop",
        "tensor": "activation",
        "quant_format": "FP8E4M3",
    }
    assert parsed_cfg_fc1_wei == {
        "gemm": "fprop",
        "tensor": "weight",
        "quant_format": "FP8E4M3",
    }

    # Test gemms struct
    ret, parsed_cfg_fc2_grad = TEConfigAPIMapper().parse_config_and_api(
        cfg_fc2["FakeQuant"],
        gemm_parsing=gemm_parsing,
        tensor_parsing=tensor_parsing,
        gemm="dgrad",
        tensor_name="gradient",
    )
    assert ret
    assert parsed_cfg_fc2_grad == {"gemm": "dgrad", "tensor": "gradient", "quant_format": "FP8E5M2"}

    ret, parsed_cfg_fc2_wei = TEConfigAPIMapper().parse_config_and_api(
        cfg_fc2["FakeQuant"],
        gemm_parsing=gemm_parsing,
        tensor_parsing=tensor_parsing,
        gemm="dgrad",
        tensor_name="weight",
    )
    assert ret
    assert parsed_cfg_fc2_wei == {"gemm": "dgrad", "tensor": "weight", "quant_format": "FP8E5M2"}

    # Test gemm + tensor struct
    ret, parsed_cfg_fc2_fprop_act = TEConfigAPIMapper().parse_config_and_api(
        cfg_fc2["PerTensorScaling"],
        gemm_parsing=gemm_parsing,
        tensor_parsing=tensor_parsing,
        gemm="fprop",
        tensor_name="activation",
    )
    assert ret
    assert parsed_cfg_fc2_fprop_act == {"gemm": "fprop", "tensor": "activation"}

    ret, parsed_cfg_fc2_fprop_wei = TEConfigAPIMapper().parse_config_and_api(
        cfg_fc2["PerTensorScaling"],
        gemm_parsing=gemm_parsing,
        tensor_parsing=tensor_parsing,
        gemm="fprop",
        tensor_name="weight",
    )
    assert ret
    assert parsed_cfg_fc2_fprop_wei == {"gemm": "fprop", "tensor": "weight"}

    ret, parsed_cfg_fc2_wgrad_act = TEConfigAPIMapper().parse_config_and_api(
        cfg_fc2["PerTensorScaling"],
        gemm_parsing=gemm_parsing,
        tensor_parsing=tensor_parsing,
        gemm="wgrad",
        tensor_name="activation",
    )
    assert ret
    assert parsed_cfg_fc2_wgrad_act == {"gemm": "wgrad", "tensor": "activation"}

    ret, parsed_cfg_fc2_wgrad_grad = TEConfigAPIMapper().parse_config_and_api(
        cfg_fc2["PerTensorScaling"],
        gemm_parsing=gemm_parsing,
        tensor_parsing=tensor_parsing,
        gemm="wgrad",
        tensor_name="gradient",
    )
    assert ret
    assert parsed_cfg_fc2_wgrad_grad == {"gemm": "wgrad", "tensor": "gradient"}

    ConfigManager.reset()
tests/pytorch/debug/test_configs/disable_fp8_gemms.yaml 0 → 100644 View file @ 2b05e121

test_disable_fp8_gemm_1:
  enabled: True
  layers:
    layer_types: [qkv, fc2]
  transformer_engine:
    DisableFP8GEMM:
      enabled: True
      gemms: [dgrad, wgrad]
\ No newline at end of file
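
This config leaves fprop in FP8 but disables the dgrad and wgrad GEMMs for layers whose type matches qkv or fc2. A minimal sketch of checking that effect through the debug API, mirroring test_disable_fp8_gemm above (the two directory paths are placeholders):

import nvdlfw_inspect.api as debug_api

debug_api.initialize("./configs/disable_fp8_gemms.yaml", feature_dirs=["./features/"])
try:
    # fprop is untouched by the config, so it stays in FP8 ...
    assert debug_api.transformer_engine.fp8_gemm_enabled(
        "decoder.1.attn.qkv", gemm="fprop", iteration=0
    )
    # ... while dgrad and wgrad are forced to high precision for qkv/fc2 layers.
    assert not debug_api.transformer_engine.fp8_gemm_enabled(
        "decoder.1.attn.qkv", gemm="dgrad", iteration=0
    )
finally:
    debug_api.end_debug()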