Commit 2b05e121 authored by yuguo's avatar yuguo

Merge commit 'a69692ac' of https://github.com/NVIDIA/TransformerEngine
parents 0fd441c2 a69692ac
......@@ -24,10 +24,10 @@ pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Fa
# Make encoder tests run-to-run deterministic for stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_multigpu_encoder.xml $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py"
wait
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
wait
# python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_multigpu_encoder.xml $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py"
# wait
# python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
# wait
. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh"
if [ $RET -ne 0 ]; then
......
......@@ -27,9 +27,6 @@ mkdir -p "$XML_LOG_DIR"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_helper.py || test_fail "tests/jax/*not_distributed_*"
# Test without custom calls
NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_custom_call_compute.xml $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py"
pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"
......@@ -37,6 +34,9 @@ pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Fa
# Make encoder tests run-to-run deterministic for stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
# Test without custom calls
NVTE_JAX_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder_no_custom_calls.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls"
if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
: ${TE_PATH:=/opt/transformerengine}
: ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
: ${NVTE_TEST_NVINSPECT_CONFIGS_DIR:=$TE_PATH/tests/pytorch/debug/test_configs/}
# Config with the dummy feature which prevents nvinspect from being disabled.
# Nvinspect will be disabled if no feature is active.
: ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml}
FAIL=0
pip install pytest==8.2.1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
# standard numerics tests with initialized debug
NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1
exit $FAIL
......@@ -20,5 +20,5 @@ if [ -z "${CPP_ONLY}" ]
then
cd $TE_PATH
echo "Checking Python files"
python3 -m pylint --recursive=y transformer_engine/common transformer_engine/pytorch
python3 -m pylint --recursive=y transformer_engine/common transformer_engine/pytorch transformer_engine/debug
fi
......@@ -47,6 +47,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entro
NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
......
......@@ -20,6 +20,7 @@ FAILED_CASES=""
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py"
......@@ -30,6 +31,19 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_use
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
# debug tests
# Config with the dummy feature which prevents nvinspect from being disabled.
# Nvinspect will be disabled if no feature is active.
: ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml}
: ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
pytest -v -s $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py"
# standard numerics tests with initialized debug
NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py"
if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
exit 1
......
......@@ -25,18 +25,18 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"
# Test without custom calls
NVTE_JAX_UNITTEST_LEVEL="L2" NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_custom_call_compute.xml $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py"
NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"
pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"
NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"
pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements"
# Make encoder tests run-to-run deterministic for stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
# Test without custom calls
NVTE_JAX_CUSTOM_CALLS_RE="" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder_no_custom_calls.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls"
if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
......
......@@ -19,11 +19,7 @@ from build_tools.te_version import te_version
from build_tools.utils import (
rocm_build,
cuda_archs,
found_cmake,
found_ninja,
found_pybind11,
get_frameworks,
install_and_import,
remove_dups,
)
......@@ -38,7 +34,6 @@ os.environ["NVTE_PROJECT_BUILDING"] = "1"
if "pytorch" in frameworks:
from torch.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
install_and_import("pybind11[global]")
from pybind11.setup_helpers import build_ext as BuildExtension
......@@ -87,6 +82,11 @@ def setup_common_extension() -> CMakeExtension:
if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")
# Add custom CMake arguments from environment variable
nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
if nvte_cmake_extra_args:
cmake_flags.extend(nvte_cmake_extra_args.split())
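For illustration, a minimal sketch of the flag plumbing above with hypothetical values: NVTE_CMAKE_EXTRA_ARGS is whitespace-split, so each token becomes one CMake argument.

import os

# Hypothetical flag values for demonstration only.
os.environ["NVTE_CMAKE_EXTRA_ARGS"] = "-DCMAKE_BUILD_TYPE=Release -DNVTE_FOO=ON"
cmake_flags = []
extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
if extra_args:
    cmake_flags.extend(extra_args.split())
assert cmake_flags == ["-DCMAKE_BUILD_TYPE=Release", "-DNVTE_FOO=ON"]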
# Project directory root
root_path = Path(__file__).resolve().parent
if rocm_build():
......@@ -102,22 +102,13 @@ def setup_common_extension() -> CMakeExtension:
)
def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
def setup_requirements() -> Tuple[List[str], List[str]]:
"""Setup Python dependencies
Returns dependencies for build, runtime, and testing.
Returns dependencies for runtime and testing.
"""
# Common requirements
setup_reqs: List[str] = [
"nvidia-cuda-runtime-cu12",
"nvidia-cublas-cu12",
"nvidia-cudnn-cu12",
"nvidia-cuda-cccl-cu12",
"nvidia-cuda-nvcc-cu12",
"nvidia-nvtx-cu12",
"nvidia-cuda-nvrtc-cu12",
]
install_reqs: List[str] = [
"pydantic",
"importlib-metadata>=1.0",
......@@ -125,32 +116,20 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
]
test_reqs: List[str] = ["pytest>=8.2.1"]
# Requirements that may be installed outside of Python
if not found_cmake():
setup_reqs.append("cmake>=3.21")
if not found_ninja():
setup_reqs.append("ninja")
if not found_pybind11():
setup_reqs.append("pybind11")
# Framework-specific requirements
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks:
setup_reqs.extend(["torch>=2.1"])
install_reqs.extend(["torch>=2.1"])
# install_reqs.append(
# "nvdlfw-inspect @"
# " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
# )
# Blackwell is not supported as of Triton 3.2.0, need custom internal build
# install_reqs.append("triton")
test_reqs.extend(["numpy", "torchvision"])
from build_tools.pytorch import install_requirements, test_requirements
install_reqs.extend(install_requirements())
test_reqs.extend(test_requirements())
if "jax" in frameworks:
setup_reqs.extend(["jax[cuda12]", "flax>=0.7.1"])
install_reqs.extend(["jax", "flax>=0.7.1"])
test_reqs.extend(["numpy"])
from build_tools.jax import install_requirements, test_requirements
install_reqs.extend(install_requirements())
test_reqs.extend(test_requirements())
return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]]
return [remove_dups(reqs) for reqs in [install_reqs, test_reqs]]
if __name__ == "__main__":
......@@ -167,14 +146,13 @@ if __name__ == "__main__":
ext_modules = []
package_data = {}
include_package_data = False
setup_requires = []
install_requires = ([f"transformer_engine_cu12=={__version__}"],)
extras_require = {
"pytorch": [f"transformer_engine_torch=={__version__}"],
"jax": [f"transformer_engine_jax=={__version__}"],
}
else:
setup_requires, install_requires, test_requires = setup_requirements()
install_requires, test_requires = setup_requirements()
ext_modules = [setup_common_extension()]
package_data = {"": ["VERSION.txt"]}
include_package_data = True
......@@ -219,15 +197,8 @@ if __name__ == "__main__":
long_description_content_type="text/x-rst",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
python_requires=">=3.8, <3.13",
classifiers=[
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
],
setup_requires=setup_requires,
python_requires=">=3.8",
classifiers=["Programming Language :: Python :: 3"],
install_requires=install_requires,
license_files=("LICENSE",),
include_package_data=include_package_data,
......
......@@ -375,7 +375,7 @@ std::vector<std::pair<size_t, size_t>> matrix_sizes = {
{256, 256},
{993, 512},
{768, 1024},
{65536, 128},
{65504, 128},
{16384, 1632},
};
......
......@@ -71,7 +71,8 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const
// Remove the use_cudnn check here when it is supported by both backends.
const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype;
if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3>){
if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3> ||
std::is_same_v<InputType, fp4e2m1>){
compute_t g = static_cast<compute_t>(gamma);
if (zero_centered_gamma) {
g += static_cast<compute_t>(1.f);
......
......@@ -45,7 +45,7 @@ bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2) {
return true;
}
size_t typeToSize(DType type) {
size_t typeToNumBits(DType type) {
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
{
return TypeInfo<T>::size;
......@@ -62,7 +62,8 @@ const std::string &typeName(DType type) {
{DType::kBFloat16, "bfloat16"},
{DType::kFloat8E4M3, "float8e4m3"},
{DType::kFloat8E5M2, "float8e5m2"},
{DType::kFloat8E8M0, "float8e8m0"}};
{DType::kFloat8E8M0, "float8e8m0"},
{DType::kFloat4E2M1, "float4e2m1"}};
return name_map.at(type);
}
......@@ -109,9 +110,16 @@ size_t DIVUP(const size_t &x, const size_t &y){
struct scale_inv_meta {
std::vector<size_t> shape;
DType type;
size_t type_size;
size_t type_size_bits;
size_t bytes() const noexcept {
return (product(shape) * type_size_bits) / 8;
}
};
size_t bytes(const NVTEShape& shape, const DType type) {
return (product(shape) * typeToNumBits(type)) / 8;
}
NVTEShape convertShape(const std::vector<size_t>& s) {
return nvte_make_shape(s.data(), s.size());
}
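A quick Python sketch of the bits-based sizing above (hypothetical dtype names): tracking element widths in bits instead of bytes keeps 4-bit FP4 elements from rounding down to zero bytes per element.

from math import prod

# Hypothetical bit widths keyed by dtype name; fp4e2m1 is 4 bits wide.
TYPE_BITS = {"fp32": 32, "bf16": 16, "fp8e4m3": 8, "fp4e2m1": 4}

def tensor_bytes(shape, dtype):
    # Mirror of bytes(): multiply in bits first, divide by 8 last.
    return (prod(shape) * TYPE_BITS[dtype]) // 8

assert tensor_bytes((128, 64), "fp4e2m1") == 4096  # two FP4 values per byte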
......@@ -122,7 +130,7 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
scale_inv_meta ret;
ret.shape = {1};
ret.type = DType::kFloat32;
ret.type_size = sizeof(float);
ret.type_size_bits = typeToNumBits(DType::kFloat32);
return {ret, ret};
}
if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
......@@ -152,8 +160,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
}
ret_rowwise.type = DType::kFloat8E8M0;
ret_colwise.type = DType::kFloat8E8M0;
ret_rowwise.type_size = sizeof(uint8_t);
ret_colwise.type_size = sizeof(uint8_t);
ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
return {ret_rowwise, ret_colwise};
}
......@@ -179,8 +187,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
}
ret_rowwise.type = DType::kFloat32;
ret_colwise.type = DType::kFloat32;
ret_rowwise.type_size = sizeof(float);
ret_colwise.type_size = sizeof(float);
ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32);
ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32);
return {ret_rowwise, ret_colwise};
}
......@@ -205,8 +213,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
}
ret_rowwise.type = DType::kFloat32;
ret_colwise.type = DType::kFloat32;
ret_rowwise.type_size = sizeof(float);
ret_colwise.type_size = sizeof(float);
ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32);
ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32);
return {ret_rowwise, ret_colwise};
}
......@@ -222,8 +230,7 @@ Tensor::Tensor(const std::string& name,
gen_.seed(seed);
rowwise_ = rowwise;
columnwise_ = columnwise;
size_t s = typeToSize(type);
size_t total_size = product(shape) * s;
size_t total_size = bytes(shape, type);
void *dptr_rowwise = nullptr;
void *dptr_columnwise = nullptr;
cpu_data_rowwise_ = nullptr;
......@@ -305,8 +312,8 @@ Tensor::Tensor(const std::string& name,
} else {
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(normalized_shape, tensor_.scaling_mode());
auto rowwise_scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
auto columnwise_scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size;
auto rowwise_scale_size = rowwise_scale_meta.bytes();
auto columnwise_scale_size = colwise_scale_meta.bytes();
auto scale_shape = rowwise_scale_meta.shape;
auto columnwise_scale_shape = colwise_scale_meta.shape;
if (rowwise) {
......@@ -331,7 +338,7 @@ Tensor::Tensor(const std::string& name,
void Tensor::to_cpu() const {
const NVTEShape s = tensor_.shape();
const size_t size = product(s) * typeToSize(tensor_.dtype());
const size_t size = bytes(s, tensor_.dtype());
if (rowwise_) {
cudaMemcpy(cpu_data_rowwise_.get(),
tensor_.get_rowwise_data().data_ptr,
......@@ -360,14 +367,14 @@ void Tensor::to_cpu() const {
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(s, tensor_.scaling_mode());
if (rowwise_) {
auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
auto scale_size = rowwise_scale_meta.bytes();
cudaMemcpy(rowwise_scale_inv_cpu_data_.get(),
tensor_.get_rowwise_scale_inv().data_ptr,
scale_size,
cudaMemcpyDeviceToHost);
}
if (columnwise_) {
auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size;
auto scale_size = colwise_scale_meta.bytes();
cudaMemcpy(columnwise_scale_inv_cpu_data_.get(),
tensor_.get_columnwise_scale_inv().data_ptr,
scale_size,
......@@ -378,34 +385,32 @@ void Tensor::to_cpu() const {
void Tensor::from_cpu() const {
const NVTEShape s = tensor_.shape();
const size_t size = product(s) * typeToSize(tensor_.dtype());
const size_t size = bytes(s, tensor_.dtype());
if (rowwise_) {
cudaMemcpy(tensor_.get_rowwise_data().data_ptr,
cpu_data_rowwise_.get(), size, cudaMemcpyHostToDevice);
cudaMemcpy(tensor_.get_rowwise_data().data_ptr, cpu_data_rowwise_.get(), size,
cudaMemcpyHostToDevice);
}
if (columnwise_) {
cudaMemcpy(tensor_.get_columnwise_data().data_ptr,
cpu_data_columnwise_.get(), size, cudaMemcpyHostToDevice);
cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size,
cudaMemcpyHostToDevice);
}
if (isFp8Type(dtype())) {
if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
if (tensor_.amax() != nullptr){
cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float),
cudaMemcpyHostToDevice);
cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
}
cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float),
cudaMemcpyHostToDevice);
cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
}
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(s, tensor_.scaling_mode());
if (rowwise_) {
auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
auto scale_size = rowwise_scale_meta.bytes();
cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr,
rowwise_scale_inv_cpu_data_.get(), scale_size,
cudaMemcpyHostToDevice);
}
if (columnwise_) {
auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size;
auto scale_size = colwise_scale_meta.bytes();
cudaMemcpy(tensor_.get_columnwise_scale_inv().data_ptr,
columnwise_scale_inv_cpu_data_.get(), scale_size,
cudaMemcpyHostToDevice);
......@@ -735,6 +740,19 @@ std::pair<double, double> getTolerances(const DType type) {
template <typename T>
void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
// Check how many RNG calls are required to generate one uniform random value
int rng_calls_per_val = 0;
{
std::mt19937 gen1 = *gen, gen2 = *gen;
std::uniform_real_distribution<> dis(-2.0, 1.0);
const float _ = dis(gen1);
while (gen2 != gen1) {
auto _ = gen2();
++rng_calls_per_val;
}
}
// Generate uniform random values in parallel
#pragma omp parallel proc_bind(spread)
{
std::mt19937 gen_local = *gen;
......@@ -743,7 +761,7 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
const int chunk_size = (size + threads_num - 1) / threads_num;
const int idx_min = chunk_size * thread_ID;
const int idx_max = std::min(chunk_size * (thread_ID + 1), static_cast<int>(size));
gen_local.discard(idx_min);
gen_local.discard(idx_min * rng_calls_per_val);
std::uniform_real_distribution<> dis(-2.0, 1.0);
for (int i = idx_min; i < idx_max; ++i) {
......@@ -754,7 +772,7 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
#endif
}
}
gen->discard(size);
gen->discard(size * rng_calls_per_val);
}
void fillUniform(Tensor *t) {
......
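The discard fix above relies on knowing how many raw engine draws one distribution sample consumes, which the probe measures by cloning the generator and stepping a copy until the states match again. A rough Python analogue with NumPy's PCG64 (an assumption for illustration; the C++ code uses std::mt19937), where advance() stands in for discard():

import numpy as np

seed = 1234
probe = np.random.PCG64(seed)
ref = np.random.PCG64(seed)

# Draw one uniform value, then count the raw engine steps it consumed by
# advancing a pristine copy until the two bit-generator states match again.
np.random.Generator(probe).uniform(-2.0, 1.0)
calls_per_val = 0
while ref.state != probe.state:
    ref.advance(1)
    calls_per_val += 1

# A parallel worker assigned values [idx_min, idx_max) can now skip ahead
# deterministically, matching gen_local.discard(idx_min * rng_calls_per_val).
idx_min = 1000
worker = np.random.Generator(np.random.PCG64(seed))
worker.bit_generator.advance(idx_min * calls_per_val)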
......@@ -10,11 +10,18 @@
#include <vector>
#include <array>
#include <random>
#include <cudaTypedefs.h>
#define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)
#include <cuda_runtime_api.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#if FP4_TYPE_SUPPORTED
#include <cuda_fp4.h>
#endif
#include <cuda_runtime_api.h>
#include <transformer_engine/transformer_engine.h>
#include "util/logging.h"
......@@ -56,20 +63,32 @@ using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
using fp8e8m0 = uint8_t;
using int8 = int8_t;
#if FP4_TYPE_SUPPORTED
using fp4e2m1 = __nv_fp4_e2m1;
#endif
template <typename T>
struct BitsNumber;
#if FP4_TYPE_SUPPORTED
template <>
struct BitsNumber<fp4e2m1> {
static constexpr size_t num_bits = 4;
};
#endif
template <typename T>
struct BitsNumber {
static constexpr size_t num_bits = 8 * sizeof(T);
};
template <typename T>
struct TypeInfo{
using types = std::tuple<byte,
int16,
int32,
int64,
fp32,
fp16,
bf16,
fp8e4m3,
fp8e5m2,
fp8e8m0,
int8>;
struct TypeInfo {
#if FP4_TYPE_SUPPORTED
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, fp4e2m1>;
#else
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, int8>;
#endif
template <typename U, DType current>
struct Helper {
......@@ -96,7 +115,7 @@ struct TypeInfo{
}
constexpr static DType dtype = getType<T>();
constexpr static size_t size = sizeof(T);
constexpr static size_t size = BitsNumber<T>::num_bits;
};
class Tensor {
......@@ -418,9 +437,10 @@ inline float dsilu(const float x) { return x * dsigmoid(x) + sigmoid(x); }
inline float srelu(const float x) { return x > 0 ? x * x : 0; }
inline float dsrelu(const float x) { return fmaxf(0, 2 * x); }
size_t typeToSize(DType type);
size_t typeToNumBits(DType type);
size_t product(const NVTEShape &shape);
size_t product(const std::vector<size_t> &shape);
size_t bytes(const NVTEShape& shape, const DType type);
size_t first_dimension(const std::vector<size_t> &shape);
size_t last_dimension(const std::vector<size_t> &shape);
......@@ -466,6 +486,16 @@ constexpr int32_t blackwellComputeCapability = 100;
} // namespace test
#if FP4_TYPE_SUPPORTED
#define SWITCH_FP4_TYPE_HANDLE(type, ...) \
case DType::kFloat4E2M1: { \
using type = fp4e2m1; \
{ __VA_ARGS__ } \
} break;
#else
#define SWITCH_FP4_TYPE_HANDLE(type, ...) // do nothing
#endif
#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
......@@ -517,8 +547,16 @@ constexpr int32_t blackwellComputeCapability = 100;
{__VA_ARGS__} \
} \
break; \
case DType::kFloat8E8M0: \
{ \
using type = fp8e8m0; \
{__VA_ARGS__} \
} \
break; \
SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \
default: \
NVTE_ERROR("Invalid type."); \
printf("dtype: %d\n", static_cast<int>(dtype)); \
NVTE_ERROR("Invalid type MARKED TEST."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \
......@@ -537,7 +575,15 @@ constexpr int32_t blackwellComputeCapability = 100;
} \
break; \
default: \
NVTE_ERROR("Invalid type."); \
NVTE_ERROR("Invalid type MARKED TEST 2."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4_ONLY(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \
default: \
NVTE_ERROR("Invalid type MARKED TEST 3."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \
......@@ -562,5 +608,5 @@ constexpr int32_t blackwellComputeCapability = 100;
} \
break; \
default: \
NVTE_ERROR("Invalid type."); \
NVTE_ERROR("Invalid type MARKED TEST 4."); \
}
......@@ -4,15 +4,14 @@
import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax import jit, value_and_grad
from functools import reduce
from typing import Union
import operator
from utils import (
assert_allclose,
assert_tree_like_allclose,
pytest_parametrize_wrapper,
)
from transformer_engine.jax.layernorm import layernorm
......@@ -33,15 +32,18 @@ from transformer_engine.jax import cpp_extensions as tex
from transformer_engine.jax.quantize import (
DelayedScaleQuantizer,
ScaledTensor,
ScaledTensor1x,
ScaledTensor2x,
GroupedScaledTensor1x,
ScalingMode,
QuantizerFactory,
QuantizeLayout,
noop_quantizer_set,
)
from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation
from transformer_engine.jax.dense import dense
from transformer_engine.jax.dense import dense, grouped_dense
from transformer_engine.jax.layernorm_dense import layernorm_dense
from transformer_engine.jax.quantize import ScaledTensor1x, ScaledTensor2x
GEMM_CASES = [
(256, 256, 512),
......@@ -53,8 +55,8 @@ GEMM_CASES = [
FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
LN_CASES = [(256, 128), (128, 256)]
DTYPES = [jnp.bfloat16, jnp.float32]
is_fp8_supported, reason = helper.is_fp8_available()
is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
is_fp8_supported, fp8_unsupported_reason = helper.is_fp8_available()
is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
supported_scaling_modes = []
""" Find supported scaling modes"""
......@@ -113,6 +115,38 @@ def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray):
pytest.fail("a must be a ScaledTensor object")
def assert_dequantized_grouped_scaled_tensor(
a: Union[GroupedScaledTensor1x, ScaledTensor2x], b: jnp.ndarray
):
if isinstance(a, GroupedScaledTensor1x):
assert a.group_sizes.sum() == b.shape[0]
b = jnp.split(b, jnp.cumulative_sum(a.group_sizes)[:-1], axis=0)
dq_a = a.dequantize()
for dq_a_i, b_i in zip(dq_a, b):
if len(dq_a_i) == 0:
continue
if a.data_layout == "T":
data_ndim = len(a.original_shape)
flatten_axis = a.flatten_axis
if b_i.shape[0] == 1:
b_i = jnp.transpose(
b_i, (0, *range(flatten_axis, data_ndim), *range(1, flatten_axis))
)
else:
b_i = jnp.transpose(
b_i, (*range(flatten_axis, data_ndim), *range(flatten_axis))
)
dq_a_i = dq_a_i.reshape(b_i.shape)
assert_allclose(dq_a_i, b_i, dtype=a.data.dtype)
elif isinstance(a, ScaledTensor2x):
assert isinstance(a.get_rowwise_tensor(), GroupedScaledTensor1x)
assert isinstance(a.get_colwise_tensor(), GroupedScaledTensor1x)
assert_dequantized_grouped_scaled_tensor(a.get_rowwise_tensor(), b)
assert_dequantized_grouped_scaled_tensor(a.get_colwise_tensor(), b)
else:
pytest.fail("a must be a GroupedScaledTensor object")
ALL_ACTIVATION_SHAPES = [(32, 64), (16, 128, 256)]
ALL_ACTIVATION_TYPES = [
("gelu",),
......@@ -173,7 +207,7 @@ class TestActivation:
assert_allclose(prim_out, ref_out, dtype=x.dtype)
assert_allclose(prim_grad, ref_grad, dtype=x.dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
......@@ -204,7 +238,7 @@ class TestActivation:
assert_allclose(prim_out, ref_out, dtype=output_type)
assert_allclose(prim_grad, ref_grad, dtype=output_type)
@pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
@pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
@pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
......@@ -234,7 +268,7 @@ class TestActivation:
assert_bitwise_scaled_tensors(te_output, jax_output)
@pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
@pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
@pytest_parametrize_wrapper("shape", [(2, 64, 1, 256)])
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
......@@ -355,7 +389,7 @@ class TestNorm:
n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer=None
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
# No Norm FWD E5M2 in TE backend
@pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
@pytest_parametrize_wrapper(
......@@ -470,7 +504,7 @@ class TestNorm:
if norm_type == "layernorm":
assert_allclose(mu, ref_mu, dtype=inp_dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
# No Norm FWD E5M2 in TE backend
@pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
@pytest_parametrize_wrapper(
......@@ -506,7 +540,7 @@ class TestNorm:
q_layout=q_layout,
)
@pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
@pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
@pytest.mark.parametrize("out_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
def test_norm_forward_with_block_scaling_fp8(
self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype
......@@ -532,7 +566,7 @@ QUANTIZE_OUTPUT_DTYPES = {
ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [
((32, 64), -1),
((2, 64, 32), -1),
((2, 64, 32), -2),
((64, 2, 32), -2),
((32, 256, 128), -1),
((32, 256, 128), -2),
((64, 32, 32, 256), -1),
......@@ -544,7 +578,7 @@ QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = {
"L0": [
((32, 64), -1),
((2, 64, 32), -1),
((2, 64, 32), -2),
((64, 2, 32), -2),
],
"L2": ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES,
}
......@@ -555,7 +589,7 @@ QUANTIZATION_INPUT_DTYPE = {
}
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
......@@ -577,9 +611,6 @@ class TestQuantize:
q_dtype=q_dtype,
q_layout=q_layout,
)
# Adding dimension to test if padding is done correctly when flatten 3D to 2D
if flatten_axis == -2:
input_shape = input_shape[:-1] + (2,) + input_shape[-1:]
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations):
......@@ -593,8 +624,6 @@ class TestQuantize:
):
key = jax.random.PRNGKey(0)
if flatten_axis == -2:
input_shape = input_shape[:-1] + (2,) + input_shape[-1:]
input = jax.random.uniform(key, input_shape, in_dtype)
te_quantizer, jax_quantizer = QuantizerFactory.create(
......@@ -607,10 +636,65 @@ class TestQuantize:
assert_bitwise_scaled_tensors(te_output, jax_output)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("input_shape", [(8, 16, 32)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("flatten_axis", [-1])
@pytest_parametrize_wrapper("with_group_sizes", [True, False])
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE]
)
class TestGroupedQuantize:
def test_grouped_qdq(
self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis, with_group_sizes
):
n_groups, m, n = input_shape
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
# *32 so that the input shapes work for MXFP8
input_shape = (m * 32, n)
if with_group_sizes:
group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
group_sizes = jnp.diff(group_sizes)
assert group_sizes.sum() == m
assert jnp.any(group_sizes == 0)  # make sure that at least one group has 0 rows
group_sizes = group_sizes * 32
else:
group_sizes = None
input_shape = (n_groups, input_shape[0] // n_groups, input_shape[1])
if flatten_axis == -2:
input_shape = input_shape[:-1] + (2,) + input_shape[-1:]
x = jax.random.uniform(subkeys[1], input_shape, in_dtype)
grouped_quantizer = QuantizerFactory.create(
scaling_mode=scaling_mode,
q_dtype=q_dtype,
q_layout=q_layout,
n_groups=n_groups,
)
# grouped_quantize does not work with cudaGraph yet, so jitting would break.
# To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to
# disable cudaGraph, then jit the call below.
scaled_tensor = tex.grouped_quantize(
x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer
)
assert_dequantized_grouped_scaled_tensor(scaled_tensor, x)
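For reference, the group_sizes construction used above is a random partition of m rows: sort n_groups - 1 random cut points, pad with the endpoints 0 and m, and take adjacent differences. A standalone sketch with hypothetical sizes:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
n_groups, m = 8, 64
# Sorted random cut points, padded with the endpoints, then differenced.
cuts = jnp.sort(jax.random.randint(key, (n_groups - 1,), 0, m))
cuts = jnp.concatenate([jnp.array([0]), cuts, jnp.array([m])])
group_sizes = jnp.diff(cuts)
assert group_sizes.sum() == m  # a valid partition; some groups may be empty
# The "* 32" in the test keeps every group's row count a multiple of the
# MXFP8 block size.
group_sizes = group_sizes * 32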
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
class TestFusedQuantize:
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
......@@ -625,12 +709,6 @@ class TestFusedQuantize:
):
pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")
if (flatten_axis < 0 and flatten_axis + len(input_shape) <= 0) or flatten_axis <= 0:
pytest.skip(
f"Flatten axis {flatten_axis} is not supported for input shape {input_shape}. There"
" must be at least one axis on either side of the flatten_axis split."
)
key = jax.random.PRNGKey(0)
input = jax.random.uniform(key, input_shape, in_dtype)
......@@ -717,7 +795,7 @@ class TestFusedQuantize:
q_layout=QuantizeLayout.ROWWISE,
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
......@@ -741,7 +819,7 @@ class TestFusedQuantize:
q_layout=q_layout,
)
@pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
@pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper(
"input_shape", [s for s in ALL_ACTIVATION_SHAPES if is_shape_supported_by_mxfp8(s)]
......@@ -810,7 +888,7 @@ class TestDense:
assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
......@@ -852,7 +930,7 @@ class TestDense:
assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16)
assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
......@@ -916,7 +994,7 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan
class TestFusedDense:
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
@pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
......@@ -1001,7 +1079,7 @@ class TestFusedDense:
if beta is not None:
assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
@pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
@pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
......@@ -1129,24 +1207,6 @@ class TestFusedDense:
assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)
# This function is modified from transformer_engine/jax/cpp_extensions/gemm.py::_jax_gemm()
def _quantize_gemm_pair(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer):
((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims
lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1
rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1
lhs_q = lhs_quantizer.quantize(
lhs,
is_rowwise=lhs_is_rowwise,
is_colwise=not lhs_is_rowwise,
)
rhs_q = rhs_quantizer.quantize(
rhs,
is_rowwise=rhs_is_rowwise,
is_colwise=not rhs_is_rowwise,
)
return lhs_q, rhs_q
# E5M2 * E5M2 is not supported
fwd_bwd_dtypes = [
[jnp.float8_e4m3fn, jnp.float8_e4m3fn],
......@@ -1154,219 +1214,217 @@ fwd_bwd_dtypes = [
[jnp.float8_e5m2, jnp.float8_e4m3fn],
]
"""
@pytest_parametrize_wrapper(
"shape_list", [[(512, 128, 256), (256, 128, 256), (256, 128, 128), (512, 256, 128)]]
)
GROUPED_DENSE_INPUT_SHAPES = [
# (n_groups, m, n, k), the actual m will be multiplied by 32
(5, 32, 128, 64), # Test the case where n_groups is not a multiple of 4
(8, 64, 32, 128),
(8, 64, 128, 256),
]
@pytest_parametrize_wrapper("input_shape", GROUPED_DENSE_INPUT_SHAPES)
class TestGroupedDense:
def _ref_grouped_gemm_with_jnp_dot(self, lhs_list, rhs_list, contracting_dims_list):
ref_out_list = []
for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list):
dim_nums = (contracting_dims, ((), ()))
ref_out_list.append(jax.lax.dot_general(lhs, rhs, dim_nums))
return ref_out_list
def _generate_grouped_gemm_input(self, dtype, shape_list, layout_list):
def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims):
lhs_contract_dim, _ = contracting_dims
assert len(lhs_contract_dim) == 1 and lhs.ndim == 2 and rhs.ndim == 3
if bias is None:
bias = jnp.zeros((rhs.shape[0], rhs.shape[2]), dtype=lhs.dtype)
else:
assert bias.ndim == 2 and bias.shape == (rhs.shape[0], rhs.shape[2])
remaining_axis = (set(range(lhs.ndim)) - set(lhs_contract_dim)).pop()
lhs = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=remaining_axis)
rhs = jnp.split(rhs, rhs.shape[0], axis=0)
bias = jnp.split(bias, bias.shape[0], axis=0)
ref_out = []
dim_num = (contracting_dims, ((), ()))
for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias):
out_i = jax.lax.dot_general(lhs_i, rhs_i, dim_num) + jnp.expand_dims(bias_i, axis=0)
ref_out.append(jnp.squeeze(out_i))
return ref_out
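The reference above uses the standard ragged-to-dense trick: split the stacked LHS rows at the running group offsets, then contract each chunk with its own kernel slice. A self-contained sketch with hypothetical sizes:

import jax
import jax.numpy as jnp

m, k, n, n_groups = 96, 16, 8, 3
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
lhs = jax.random.normal(key1, (m, k))            # rows of all groups, stacked
rhs = jax.random.normal(key2, (n_groups, k, n))  # one kernel per group
group_sizes = jnp.array([32, 0, 64])             # zero-sized groups are legal

# Split at the running offsets; the final cumulative sum equals m and is dropped.
chunks = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=0)
outs = [chunk @ rhs[i] for i, chunk in enumerate(chunks)]  # per-group GEMMs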
def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, len(shape_list) * 2)
lhs_list, rhs_list, contracting_dims_list = [], [], []
for i, ((m, n, k), data_layout) in enumerate(zip(shape_list, layout_list)):
lhs = jax.random.uniform(
subkeys[2 * i],
(m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
dtype=dtype,
)
rhs = jax.random.uniform(
subkeys[2 * i + 1],
(k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
dtype=dtype,
)
lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
subkeys = jax.random.split(key, 4)
n_groups, m, n, k = input_shape
group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
group_sizes = jnp.diff(group_sizes)
assert group_sizes.sum() == m
# *32 to make sure that the input shape works for MXFP8
group_sizes = group_sizes * 32
m = m * 32
lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m)
rhs_shape = (n_groups, k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k)
bias_shape = (n_groups, n)
lhs_list.append(lhs)
rhs_list.append(rhs)
contracting_dims_list.append(contracting_dims)
lhs = jax.random.uniform(subkeys[1], lhs_shape, dtype=dtype)
rhs = jax.random.uniform(subkeys[2], rhs_shape, dtype=dtype)
bias = jax.random.uniform(subkeys[3], bias_shape, dtype=dtype) if with_bias else None
lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
rhs_contracting_dim = (1,) if data_layout[1] == "N" else (2,)
contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
return lhs_list, rhs_list, contracting_dims_list
return lhs, rhs, group_sizes, contracting_dims, bias
def _assert_grouped_gemm_output(self, out, group_sizes, ref_list, dtype):
assert out.dtype == ref_list[0].dtype
out_list = jnp.split(out, jnp.cumulative_sum(group_sizes)[:-1], axis=0)
for i in range(len(ref_list)):
assert_allclose(out_list[i], ref_list[i], dtype=dtype)
@pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
@pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]])
def test_grouped_gemm_fp16(self, dtype, shape_list, layout_list):
lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input(
dtype, shape_list, layout_list
@pytest_parametrize_wrapper("layout", ["NN"])
def test_grouped_gemm_fp16(self, dtype, input_shape, layout):
lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
dtype, input_shape, layout
)
ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list)
primitive_out = tex.grouped_gemm(lhs_list, rhs_list, contracting_dims_list)
for i in range(len(shape_list)):
assert_allclose(primitive_out[i], ref_out[i], dtype=dtype)
ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
# grouped_gemm does not work with cudaGraph yet, so jitting would break.
# To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to
# disable cudaGraph, then use the jitted call below.
# jitting grouped_gemm
# prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
# lhs, rhs, group_sizes, contracting_dims,
# )
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims)
self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]])
def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list, layout_list):
@pytest_parametrize_wrapper("layout", ["NN"])
def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout):
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
pytest.skip("MXFP8 is not supported in grouped_gemm yet")
fwd_dtype, bwd_dtype = fwd_bwd_dtype
quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=False
scaling_mode=scaling_mode,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
is_2x2x=False,
n_groups=input_shape[0],
)
# quantizer_set.{x, kernel} has fwd_dtype, while quantizer_set.grad has bwd_dtype
# To test E4M3 * E5M2, we manually set quantizer_set.kernel.q_dtype to bwd_dtype
quantizer_set.kernel.q_dtype = bwd_dtype
for quantizer in quantizer_set.kernel.quantizers:
quantizer.q_dtype = bwd_dtype
out_dtype = jnp.bfloat16
lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input(
out_dtype, shape_list, layout_list
lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
out_dtype, input_shape, layout
)
q_lhs_list = []
q_rhs_list = []
for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list):
# quantizer_set.x and quantizer_set.kernel have the same q_dtype; here we
# test the case where lhs and rhs have different q_dtypes
q_lhs, q_rhs = _quantize_gemm_pair(
lhs, rhs, contracting_dims, quantizer_set.x, quantizer_set.dgrad
)
q_lhs_list.append(q_lhs)
q_rhs_list.append(q_rhs)
ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
# jitting grouped_gemm
# prim_out = jax.jit(tex.grouped_gemm, static_argnames=('contracting_dims',))(
# lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
# )
ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list)
primitive_out = tex.grouped_gemm(q_lhs_list, q_rhs_list, contracting_dims_list)
prim_out = tex.grouped_gemm(
lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
)
allclose_dtype = jnp.float8_e4m3fn
if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2:
if jnp.float8_e5m2 in fwd_bwd_dtype:
allclose_dtype = jnp.float8_e5m2
for i in range(len(shape_list)):
assert_allclose(primitive_out[i], ref_out[i], dtype=allclose_dtype)
@pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
def test_grouped_dense_grad_fp16(self, dtype, shape_list):
group_size = len(shape_list)
layout_list = ["NN" for _ in range(group_size)]
self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, allclose_dtype)
def _ref_sum_grouped_dense(self, x, kernel, bias, group_sizes, contracting_dims):
out_list = self._ref_grouped_dense(x, kernel, bias, group_sizes, contracting_dims)
# Note: we use jnp.sum instead of jnp.mean to make the gradients larger
# and prevent them from being clamped to zero
out_sum_list = [jnp.sum(out) for out in out_list]
return jnp.sum(jnp.asarray(out_sum_list))
x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input(
dtype, shape_list, layout_list
def _primitive_sum_grouped_dense(
self, x, kernel, bias, group_sizes, contracting_dims, quantizer_set=noop_quantizer_set
):
out = grouped_dense(
x, kernel, group_sizes, contracting_dims, bias=bias, quantizer_set=quantizer_set
)
bias_list = []
key = jax.random.PRNGKey(1)
for shape in shape_list:
n = shape[1]
bias = jax.random.uniform(key, n, dtype=dtype)
bias_list.append(bias)
def ref_func(x_list, kernel_list, bias_list, contracting_dims_list):
out_list = []
for i in range(len(x_list)):
out_list.append(
dense(
x_list[i],
kernel_list[i],
bias_list[i],
contracting_dims=contracting_dims_list[i],
)
)
# Note: we use jnp.sum instead of jnp.mean to make the gradients larger
# and prevent them from being clamped to zero
out_sum_list = [jnp.sum(out) for out in out_list]
return jnp.sum(jnp.asarray(out_sum_list))
return jnp.sum(jnp.asarray(out))
def primitive_func(x_list, kernel_list, bias_list, contracting_dims_list):
out_list = grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list)
out_sum_list = [jnp.sum(out) for out in out_list]
return jnp.sum(jnp.asarray(out_sum_list))
@pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
def test_grouped_dense_grad_fp16(self, dtype, input_shape):
x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
dtype,
input_shape,
with_bias=True,
)
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
# jitting the grouped_dense
# value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)),
# static_argnums=(4,))
value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))
ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func(
x_list, kernel_list, bias_list, contracting_dims_list
ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
x, kernel, bias, group_sizes, contracting_dims
)
primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = (
value_n_grad_primitive_func(x_list, kernel_list, bias_list, contracting_dims_list)
prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
x, kernel, bias, group_sizes, contracting_dims
)
assert_allclose(primitive_out_mean, ref_out_mean, dtype=dtype)
for i in range(group_size):
assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=dtype)
assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=dtype)
assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=dtype)
assert_allclose(prim_out_sum, ref_out_sum, dtype=dtype)
assert_allclose(prim_dgrad, ref_dgrad, dtype=dtype)
assert_allclose(prim_wgrad, ref_wgrad, dtype=dtype)
assert_allclose(prim_dbias, ref_dbias, dtype=dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize(
"fwd_bwd_dtype",
[(jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)],
)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list):
group_size = len(shape_list)
layout_list = ["NN" for _ in range(group_size)]
fwd_dtype, bwd_dtype = fwd_bwd_dtype
if fwd_dtype == jnp.float8_e5m2:
pytest.skip("We never use E5M2 for fwd_dtype in training")
# Question: should we use different quantizers for different groups?
ref_quantizer_set_list = []
quantizer_set_list = []
for _ in range(group_size):
ref_quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True
)
ref_quantizer_set_list.append(ref_quantizer_set)
quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True
)
quantizer_set_list.append(quantizer_set)
def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
pytest.skip("MXFP8 is not supported in grouped_dense yet")
out_dtype = jnp.bfloat16
x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input(
out_dtype, shape_list, layout_list
fwd_dtype, bwd_dtype = fwd_bwd_dtype
dtype = jnp.bfloat16
x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
dtype,
input_shape,
with_bias=True,
)
bias_list = []
key = jax.random.PRNGKey(1)
for shape in shape_list:
n = shape[1]
bias = jax.random.uniform(key, n, dtype=out_dtype)
bias_list.append(bias)
def ref_func(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list):
out_list = []
for i in range(len(x_list)):
out_list.append(
dense(
x_list[i],
kernel_list[i],
bias_list[i],
contracting_dims=contracting_dims_list[i],
quantizer_set=quantizer_set_list[i],
)
)
# Note: we use jnp.sum instead of jnp.mean to make the gradient larger
# and prevent them from being clamp to zero
out_sum_list = [jnp.sum(out) for out in out_list]
return jnp.sum(jnp.asarray(out_sum_list))
def primitive_func(
x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
):
out_list = grouped_dense(
x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
)
out_sum_list = [jnp.sum(out) for out in out_list]
return jnp.sum(jnp.asarray(out_sum_list))
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func(
x_list, kernel_list, bias_list, contracting_dims_list, ref_quantizer_set_list
quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
is_2x2x=True,
n_groups=group_sizes.size,
)
primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = (
value_n_grad_primitive_func(
x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
)
value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
# jitting the grouped_dense
# value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)),
# static_argnums=(4,))
value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))
ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
x,
kernel,
bias,
group_sizes,
contracting_dims,
)
prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
x, kernel, bias, group_sizes, contracting_dims, quantizer_set=quantizer_set
)
allclose_dtype = jnp.float8_e4m3fn
if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2:
allclose_dtype = jnp.float8_e5m2
assert_allclose(primitive_out_mean, ref_out_mean, dtype=allclose_dtype)
for i in range(group_size):
assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=allclose_dtype)
assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=allclose_dtype)
assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=allclose_dtype)
"""
assert_allclose(prim_out_sum, ref_out_sum, dtype=fwd_dtype)
assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype)
assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype)
assert_allclose(prim_dbias, ref_dbias, dtype=dtype)
......@@ -68,6 +68,7 @@ class TestDistributedSelfAttn:
batch, seqlen, num_head, hidden = data_shape
if not is_fused_attn_kernel_available(
is_training,
dtype,
dtype,
QKVLayout.BS3HD,
......@@ -79,6 +80,7 @@ class TestDistributedSelfAttn:
seqlen,
seqlen,
hidden,
hidden,
None, # no window
):
pytest.skip("No FusedAttn backend found")
......@@ -98,6 +100,7 @@ class TestDistributedSelfAttn:
num_head,
num_head,
hidden,
hidden,
attn_bias_type,
attn_mask_type,
dropout_prob,
......@@ -214,6 +217,7 @@ class TestDistributedCrossAttn:
batch, seqlen, num_head, hidden = data_shape
if not is_fused_attn_kernel_available(
is_training,
dtype,
dtype,
QKVLayout.BSHD_BS2HD,
......@@ -225,6 +229,7 @@ class TestDistributedCrossAttn:
seqlen,
seqlen,
hidden,
hidden,
None, # no window
):
pytest.skip("No FusedAttn backend found")
......@@ -237,6 +242,7 @@ class TestDistributedCrossAttn:
num_head,
num_head,
hidden,
hidden,
attn_bias_type,
attn_mask_type,
dropout_prob,
......@@ -289,6 +295,7 @@ class TestDistributedContextParallelSelfAttn:
cp_strategy,
use_shardy,
use_scan_ring=False,
window_size=None,
):
if qkv_layout.is_thd():
if cp_strategy == CPStrategy.ALL_GATHER:
......@@ -326,6 +333,7 @@ class TestDistributedContextParallelSelfAttn:
num_head,
num_kv_heads,
hidden,
hidden,
attn_bias_type,
attn_mask_type,
dropout_prob,
......@@ -333,7 +341,7 @@ class TestDistributedContextParallelSelfAttn:
is_training,
qkv_layout,
bias_shape,
None,
window_size,
SeqDescFormat.SegmentIDs,
number_of_devices=device_count,
mesh_shape=mesh_shape,
......@@ -345,6 +353,7 @@ class TestDistributedContextParallelSelfAttn:
def check_has_backend_for_mask(mask_type):
return is_fused_attn_kernel_available(
is_training,
dtype,
dtype,
qkv_layout,
......@@ -356,6 +365,7 @@ class TestDistributedContextParallelSelfAttn:
seqlen,
seqlen,
hidden,
hidden,
None,
) # no SWA for CP
......@@ -476,6 +486,13 @@ class TestDistributedContextParallelSelfAttn:
"use_scan",
[pytest.param(False, id="NO_SCAN"), pytest.param(True, id="USE_SCAN")],
)
@pytest.mark.parametrize(
"window_size",
[
pytest.param((-1, -1), id="window_size(-1, -1)"),
pytest.param((20, 0), id="window_size(20, 0)"),
],
)
def test_context_parallel_ring_attn(
self,
device_count,
......@@ -489,7 +506,15 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout,
load_balanced,
use_scan,
window_size,
):
if window_size != (-1, -1) and not qkv_layout.is_thd():
pytest.skip("Sliding window attention is only supported for THD layout")
if window_size != (-1, -1) and qkv_layout.is_thd() and use_scan:
pytest.skip(
"When context parallelism and sliding window attention are used, "
"scanloop is not supported"
)
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
......@@ -504,6 +529,7 @@ class TestDistributedContextParallelSelfAttn:
CPStrategy.RING,
use_shardy=False,
use_scan_ring=use_scan,
window_size=window_size,
)
@pytest.mark.parametrize(
......
......@@ -106,7 +106,8 @@ def general_dot_product_attention(
softmax_out = softmax_out * multiplier
context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)
context = jnp.reshape(context, query.shape)
context_shape = query.shape[:-1] + (value.shape[-1],)
context = jnp.reshape(context, context_shape)
return context
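# Shape check (illustrative, not part of the change): with query (..., s_q, h, d_qk) and
# value (..., s_kv, h_kv, d_v), the einsum yields trailing dims (s_q, h, g, d_v), so the
# context must be reshaped to query.shape[:-1] + (value.shape[-1],); reshaping to
# query.shape only works in the special case d_qk == d_v.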
......@@ -294,7 +295,8 @@ class FusedAttnRunner:
max_seqlen_kv: int
num_heads_q: int
num_heads_kv: int
head_dim: int
head_dim_qk: int
head_dim_v: int
attn_bias_type: AttnBiasType
attn_mask_type: AttnMaskType
dropout_prob: float
......@@ -346,7 +348,16 @@ class FusedAttnRunner:
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
)
# Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
# are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate():
pytest.skip(
"For head_dim_qk != head_dim_v, it is necessary that the QKV layout "
"is either BSHD_BSHD_BSHD or THD_THD_THD"
)
self.backend = FusedAttnHelper(
self.is_training,
self.dtype,
self.dtype,
self.qkv_layout,
......@@ -357,7 +368,8 @@ class FusedAttnRunner:
self.num_heads_kv,
self.max_seqlen_q,
self.max_seqlen_kv,
self.head_dim,
self.head_dim_qk,
self.head_dim_v,
(-1, -1) if self.window_size is None else self.window_size,
).get_fused_attn_backend()
if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
......@@ -390,13 +402,9 @@ class FusedAttnRunner:
key = jax.random.PRNGKey(0)
q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)
q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
k_shape = v_shape = (
self.batch_size,
self.max_seqlen_kv,
self.num_heads_kv,
self.head_dim,
)
q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim_qk)
k_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_qk)
v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_v)
if self.attn_bias_type == AttnBiasType.NO_BIAS:
bias_shape = None
......@@ -615,7 +623,7 @@ class FusedAttnRunner:
raise ValueError(f"Unknown {self.seq_desc_format=}")
self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
self.scaling_factor = 1.0 / sqrt(self.head_dim)
self.scaling_factor = 1.0 / sqrt(self.head_dim_qk)
# Setup distributed sharding specs
# Setup shardings for distributed tests
......@@ -934,9 +942,31 @@ class FusedAttnRunner:
],
)
@pytest.mark.parametrize(
"b, s_q, s_kv, h_q, h_kv, d, dtype",
"b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype",
[
pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
pytest.param(
2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-64-BF16-SELF"
),
pytest.param(
2,
2048,
1024,
12,
12,
64,
64,
jnp.bfloat16,
id="2-2048-1024-12-12-64-64-BF16-CROSS",
),
pytest.param(
2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-64-BF16-GQA"
),
pytest.param(
4, 128, 128, 16, 16, 64, 64, jnp.float16, id="4-128-128-16-16-64-64-FP16-SELF"
),
pytest.param(
4, 128, 128, 16, 16, 64, 32, jnp.float16, id="4-128-128-16-16-64-32-FP16-SELF"
),
pytest.param(
2,
2048,
......@@ -944,11 +974,13 @@ class FusedAttnRunner:
12,
12,
64,
32,
jnp.bfloat16,
id="2-2048-1024-12-12-64-BF16-CROSS",
id="2-2048-1024-12-12-64-32-BF16-CROSS",
),
pytest.param(
2, 2048, 2048, 12, 6, 128, 64, jnp.float16, id="2-2048-2048-12-6-128-64-FP16-GQA"
),
pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
],
)
@pytest.mark.parametrize(
......@@ -1002,7 +1034,8 @@ class TestFusedAttn:
s_kv,
h_q,
h_kv,
d,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
dropout_prob,
......@@ -1027,7 +1060,8 @@ class TestFusedAttn:
s_kv,
h_q,
h_kv,
d,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
dropout_prob,
......@@ -1054,7 +1088,8 @@ class TestFusedAttn:
s_kv,
h_q,
h_kv,
d,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
dropout_prob,
......@@ -1076,7 +1111,8 @@ class TestFusedAttn:
s_kv,
h_q,
h_kv,
d,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
dropout_prob,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
def pytest_addoption(parser):
parser.addoption(
"--feature_dirs", nargs="+", action="store", default="", help="List of feature directories"
)
parser.addoption(
"--configs_dir",
action="store",
default="",
type=str,
help="Path to the directory with configs.",
)
@pytest.fixture
def feature_dirs(request):
return request.config.getoption("--feature_dirs")
@pytest.fixture
def configs_dir(request):
return request.config.getoption("--configs_dir")
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import tempfile
import functools
import os
import itertools
import random
import argparse
import re
import torch
import torch.distributed as dist
import transformer_engine
import transformer_engine_torch as tex
import nvdlfw_inspect.api as debug_api
from transformer_engine.debug import set_weight_tensor_tp_group_reduce
from test_numerics import (
_emulate_linear,
_init_debug,
disable_fp8_gemms_create_config,
DISABLE_FP8_LAYER_CONFIG,
_cmp,
IN_SIZE,
OUT_SIZE,
_init_model,
SEED,
SEQ_LEN,
BATCH_SIZE,
FP8_RECIPE,
fake_quant_fp8_create_config,
_get_current_scale,
_prepare_per_tensor_scaling_config,
AMAX_HISTORY_LEN,
set_scaling_factors,
set_current_scaling_factors,
)
WORLD_RANK, WORLD_SIZE = None, None
NCCL_WORLD = None
FEATURE_DIRS = None
all_boolean = [True, False]
TEST_NR = 0
def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None):
if tp_size is None:
tp_size = WORLD_SIZE
tp_rank = WORLD_RANK
torch.manual_seed(weight_seed)
weight = torch.randn((OUT_SIZE, IN_SIZE)).cuda()
torch.manual_seed(data_seed)
in_split_size = IN_SIZE // tp_size
out_split_size = OUT_SIZE // tp_size
x = torch.randn((SEQ_LEN * BATCH_SIZE, IN_SIZE), requires_grad=True).cuda()
if parallel_mode == "row":
x = x[:, tp_rank * in_split_size : (tp_rank + 1) * in_split_size]
x.retain_grad()
with torch.no_grad():
if parallel_mode == "column":
weight = weight[tp_rank * out_split_size : (tp_rank + 1) * out_split_size, :]
else:
weight = weight[:, tp_rank * in_split_size : (tp_rank + 1) * in_split_size]
return x, weight.contiguous()
def _init_model(weight, parallel_mode=None, tp_group=None, name="linear"):
model = transformer_engine.pytorch.Linear(
IN_SIZE,
OUT_SIZE,
name=name,
parallel_mode=parallel_mode,
tp_group=(tp_group or NCCL_WORLD if parallel_mode else None),
)
with torch.no_grad():
model.weight.copy_(weight)
return model
class AllGather(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor, dim, group=None):
if group is None:
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
else:
world_size = torch.distributed.get_world_size(group=group)
rank = torch.distributed.get_rank(group=group)
dist.barrier()
# Create a list to gather tensors from all processes
y_list = [torch.zeros_like(tensor) for _ in range(world_size)]
torch.distributed.all_gather(y_list, tensor, group=group)
# Save the world size and rank for backward computation
ctx.world_size = world_size
ctx.rank = rank
ctx.dim = dim
# Concatenate the gathered tensors along the feature dimension
y_full = torch.cat(y_list, dim=dim)
return y_full
@staticmethod
def backward(ctx, grad_output):
# Split the gradient output and return the portion corresponding to this rank
grad_input = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)[ctx.rank]
return grad_input, None, None
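# Usage sketch for AllGather (illustrative; assumes an initialized NCCL process group
# and one CUDA device per rank):
#   local = torch.randn(4, 8, device="cuda", requires_grad=True)
#   full = AllGather.apply(local, -1)  # shape (4, 8 * world_size), gathered on last dim
#   full.sum().backward()              # local.grad is this rank's (4, 8) slice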
def _run_forward_backward(x, model, parallel_mode=None, group=None):
with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
y = model(x)
y.requires_grad_(True)
y.retain_grad()
if parallel_mode == "column":
y = AllGather.apply(y, -1, group)
y.requires_grad_(True)
y.retain_grad()
l = y.sum()
l.backward()
elif parallel_mode == "row":
l = y.sum()
l.backward()
debug_api.step()
return y
def _emulate_linear_distributed(*args, parallel_mode=None, **kwargs):
assert parallel_mode in ["column", "row"]
def split(gradient):
split_size = OUT_SIZE // WORLD_SIZE
gradient = gradient[:, WORLD_RANK * split_size : (WORLD_RANK + 1) * split_size]
return gradient
activation_sync = None
gradient_sync = None
if parallel_mode == "column":
activation_sync = lambda x: AllGather.apply(x, -1)
gradient_sync = split
else:
activation_sync = (
lambda activation: dist.all_reduce(activation, op=dist.ReduceOp.SUM) or activation
)
output = _emulate_linear(
*args, activation_sync=activation_sync, gradient_sync=gradient_sync, **kwargs
)
if parallel_mode == "column":
dist.all_reduce(output["dgrad"], op=dist.ReduceOp.SUM)
return output
def check_debug_log(msg):
with open(f"log/debug_logs/debug_log_globalrank-{WORLD_RANK}.log", "r") as f:
for line in f.readlines():
if msg in line:
return True
return False
def run_debug_test(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
rank = dist.get_rank()
temp_file_name = None
temp_logdir_name = None
if rank == 0:
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
temp_file_name = temp_file.name
temp_dir_obj = tempfile.TemporaryDirectory()
temp_logdir_name = temp_dir_obj.name
# Store the TemporaryDirectory object to prevent it from being deleted
wrapper.temp_dir_obj = temp_dir_obj
temp_file_name_list = [temp_file_name]
temp_logdir_name_list = [temp_logdir_name]
# Broadcast the temporary file and directory names to all processes
dist.broadcast_object_list(temp_file_name_list, src=0)
dist.broadcast_object_list(temp_logdir_name_list, src=0)
temp_file_name = temp_file_name_list[0]
temp_logdir_name = temp_logdir_name_list[0]
dist.barrier()
config_file = open(temp_file_name, mode="r+", buffering=1)
try:
kwargs["config_file"] = config_file
kwargs["log_dir"] = temp_logdir_name
if rank == 0:
global TEST_NR
print(f"Running test {TEST_NR} {func.__name__} with args = {args}.")
TEST_NR += 1
func(*args, **kwargs)
finally:
if rank == 0 and temp_file_name is not None:
os.unlink(temp_file_name)
debug_api.end_debug()
if rank == 0 and hasattr(wrapper, "temp_dir_obj"):
wrapper.temp_dir_obj.cleanup()
return wrapper
CONFIG_LOG_TEST_DISTRIBUTED = """log_distributed:
layers:
layer_types: [linear]
enabled:
True
transformer_engine:
LogTensorStats:
enabled: True
tensors: [activation, gradient, weight, output, wgrad, dgrad]
stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
start_step : 0
end_step: 1
LogFp8TensorStats:
enabled: True
tensors: [activation, gradient, weight]
stats: [underflows%]
start_step : 0
end_step: 1
"""
def _prepare_config_test_log_distributed(config_file):
if WORLD_RANK != 0:
return
config_file.write(CONFIG_LOG_TEST_DISTRIBUTED)
config_file.flush()
def _compute_dynamic_range(tensor):
tensor_abs = tensor.abs()
tensor_abs = tensor_abs[tensor_abs != 0]
if tensor_abs.any():
amin = tensor_abs.min().float()
else:
amin = torch.tensor(1, device=tensor.device).to(torch.float)
amax = tensor_abs.max().float()
if not amax.all():
amax = torch.tensor(1, device=tensor.device).to(torch.float)
dynamic_range = torch.log2(amax) - torch.log2(amin)
return dynamic_range
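# Worked example (illustrative): for a tensor whose nonzero magnitudes span [0.25, 8.0],
# dynamic_range = log2(8.0) - log2(0.25) = 3.0 - (-2.0) = 5.0, i.e. the values cover
# roughly five binades.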
@run_debug_test
def test_log_distributed(parallel_mode, gather_weight, **kwargs):
_prepare_config_test_log_distributed(kwargs["config_file"])
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
set_weight_tensor_tp_group_reduce(gather_weight)
if WORLD_SIZE % 2 != 0:
return # skip
TP_SIZE = WORLD_SIZE // 2
DP_SIZE = 2
TP_RANK = WORLD_RANK % TP_SIZE
DP_RANK = (WORLD_RANK - TP_RANK) // TP_SIZE
debug_api.set_tensor_reduction_group(NCCL_WORLD)
x, weight = _get_tensors(
parallel_mode,
weight_seed=TP_RANK * 1234,
data_seed=DP_RANK * 1234,
tp_size=TP_SIZE,
tp_rank=TP_RANK,
)
tp_group_ranks = [i for i in range(DP_RANK * TP_SIZE, (DP_RANK + 1) * TP_SIZE)]
tp_group = dist.new_group(ranks=tp_group_ranks)
dp_group_ranks = [i for i in range(TP_RANK, WORLD_SIZE, TP_SIZE)]
dp_group = dist.new_group(ranks=dp_group_ranks)
model = _init_model(weight, parallel_mode=parallel_mode, tp_group=tp_group)
output = _run_forward_backward(x, model, parallel_mode=parallel_mode, group=tp_group)
gathered_activation = AllGather.apply(x.contiguous(), 0)
gathered_weight = AllGather.apply(weight.contiguous(), 0, tp_group)
gathered_gradient = AllGather.apply(output.grad.contiguous(), 0, dp_group)
if parallel_mode == "row":
gathered_gradient = AllGather.apply(gathered_gradient, 0, tp_group)
log_file = kwargs["log_dir"] + "/nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log"
dist.barrier()
if WORLD_RANK != 0:
return # stats are gathered on node 0
with open(log_file) as f:
content = f.read()
def get_stat(tensor, stat):
regex = r".*_{tensor}_{stat}\s+.*iteration=(\d+)\s+.*value=([-+]?\d*\.?\d+)".format(
tensor=tensor, stat=stat
)
for line in content.splitlines():
match = re.search(regex, line)
if match:
value = float(match.group(2))
return value
rf = lambda x: round(float(x), 4)
tensors = {
"activation": gathered_activation,
"weight": gathered_weight if gather_weight else weight,
"gradient": gathered_gradient,
}
stats = {
"min": torch.min,
"max": torch.max,
"mean": torch.mean,
"std": torch.std,
"l1_norm": lambda x: torch.norm(x, p=1),
"l2_norm": lambda x: torch.norm(x, p=2),
"cur_amax": lambda x: x.abs().max(),
"dynamic_range": _compute_dynamic_range,
}
for stat_key in stats.keys():
for tensor_key in tensors.keys():
torch.testing.assert_close(
get_stat(tensor_key, stat_key),
rf(stats[stat_key](tensors[tensor_key])),
atol=0.0001,
rtol=0.0001,
)
set_weight_tensor_tp_group_reduce(True) # reset
@run_debug_test
def test_log_expert_parallel(**kwargs):
"""
This test covers the scenario in which one of the data-parallel ranks does not invoke the debug layer.
This naturally occurs with expert parallelism, when an expert receives no input on one rank
but does on the others. If there were an all_gather inside forward(), this would result in a deadlock.
"""
_prepare_config_test_log_distributed(kwargs["config_file"])
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
debug_api.set_tensor_reduction_group(NCCL_WORLD)
x, weight = _get_tensors(
"row", weight_seed=WORLD_RANK * 1234, data_seed=WORLD_RANK * 1234, tp_size=1, tp_rank=0
) # data parallel
model = _init_model(weight, parallel_mode=None, name="linear1")
model1 = _init_model(weight, parallel_mode=None, name="linear2")
with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
y1 = model(x)
y2 = model1(x)
y = y1 + y2
y.sum().backward()
debug_api.step()
with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
y = model(x)
if WORLD_RANK != 0:
y = y + model1(x)
y.sum().backward()
@run_debug_test
def test_disable_fp8_gemms(fprop_fp8, dgrad_fp8, wgrad_fp8, parallel_mode, **kwargs):
disable_fp8_gemms_create_config(fprop_fp8, dgrad_fp8, wgrad_fp8, kwargs["config_file"])
fp8_kwargs = {
"fprop_fp8": fprop_fp8,
"dgrad_fp8": dgrad_fp8,
"wgrad_fp8": wgrad_fp8,
}
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
x, weight = _get_tensors(parallel_mode)
model = _init_model(weight, parallel_mode=parallel_mode)
y = _run_forward_backward(x, model, parallel_mode=parallel_mode)
output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()}
x.grad.zero_()
ground_truth = _emulate_linear_distributed(x, weight, parallel_mode=parallel_mode, **fp8_kwargs)
_cmp(ground_truth, output)
@run_debug_test
def test_disable_fp8_layer(parallel_mode, **kwargs):
if WORLD_RANK == 0:
kwargs["config_file"].write(DISABLE_FP8_LAYER_CONFIG)
kwargs["config_file"].flush()
dist.barrier()
x, weight = _get_tensors(parallel_mode)
ground_truth = _emulate_linear_distributed(x, weight, parallel_mode=parallel_mode)
x.grad.zero_()
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
model = _init_model(weight, parallel_mode)
y = _run_forward_backward(x, model, parallel_mode)
output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()}
_cmp(ground_truth, output)
@run_debug_test
def test_per_tensor_scaling(
fprop_inp,
fprop_weight,
dgrad_weight,
dgrad_grad,
wgrad_input,
wgrad_grad,
parallel_mode,
**kwargs,
):
    """
    Runs a test to validate per-tensor (current) scaling in FP8 computations.
    The function performs warm-up iterations to populate the model's amax buffer and compute
    scaling factors based on delayed scaling. Subsequently, the weights and inputs are swapped
    so that their current scaling factors differ from the delayed-scaling ones; similarly, the
    loss is multiplied by a large factor to alter the gradient's magnitude, creating a
    discrepancy between the original (delayed) and per-tensor (current) scaling factors.
    Finally, a linear pass is emulated and the results are compared.
    """
    input_kwargs = {
        "fprop_inp": fprop_inp,
        "fprop_weight": fprop_weight,
        "dgrad_weight": dgrad_weight,
        "dgrad_grad": dgrad_grad,
        "wgrad_input": wgrad_input,
        "wgrad_grad": wgrad_grad,
    }
    fp8_kwargs = {
        "fprop_fp8": True,
        "dgrad_fp8": True,
        "wgrad_fp8": True,
    }
_prepare_per_tensor_scaling_config(
fprop_inp,
fprop_weight,
dgrad_weight,
dgrad_grad,
wgrad_input,
wgrad_grad,
kwargs["config_file"],
)
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
warmup_input, warmup_weight = _get_tensors(parallel_mode=parallel_mode)
model = _init_model(warmup_weight, parallel_mode=parallel_mode)
# Warmup run to setup amax and scaling factors.
for _ in range(AMAX_HISTORY_LEN):
_run_forward_backward(warmup_input, model, parallel_mode=parallel_mode)
x, weight = _get_tensors(
parallel_mode=parallel_mode, weight_seed=WORLD_RANK * 2137, data_seed=WORLD_RANK * 2137
)
model.weight.data = weight.data
x.retain_grad()
# delayed scaling factor
# need to be collected before forward pass with test data,
# because this forward pass changes scaling factors
set_scaling_factors(model, input_kwargs, fp8_kwargs)
LOSS_MULTIPLIER = 100
with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
y = model(x)
model.zero_grad()
if parallel_mode == "column":
y = AllGather.apply(y, -1)
y.retain_grad()
(
LOSS_MULTIPLIER * y.sum()
).backward()  # Loss multiplication to change the gradient's order of magnitude
output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()}
# per tensor - current - scaling factors
# need to be collected after forward pass with test data,
# because gradient(y.grad) cannot be accessed before forward,
# but it needs to be collected.
set_current_scaling_factors(x, weight, y, input_kwargs, fp8_kwargs)
ground_truth = _emulate_linear_distributed(
x, weight, parallel_mode=parallel_mode, loss_multiplier=LOSS_MULTIPLIER, **fp8_kwargs
)
_cmp(ground_truth, output)
@run_debug_test
def test_fake_quant_fp8(
fprop_inp,
fprop_weight,
dgrad_weight,
dgrad_grad,
wgrad_input,
wgrad_grad,
parallel_mode,
**kwargs,
):
fp8_kwargs = {
"fprop_input_fake_quant": fprop_inp,
"fprop_weight_fake_quant": fprop_weight,
"dgrad_gradient_fake_quant": dgrad_grad,
"dgrad_weight_fake_quant": dgrad_weight,
"wgrad_gradient_fake_quant": wgrad_grad,
"wgrad_input_fake_quant": wgrad_input,
"fprop_fp8": not (fprop_inp or fprop_weight),
"dgrad_fp8": not (dgrad_weight or dgrad_grad),
"wgrad_fp8": not (wgrad_grad or wgrad_input),
}
if WORLD_RANK == 0:
fake_quant_fp8_create_config(
fprop_inp,
fprop_weight,
dgrad_weight,
dgrad_grad,
wgrad_input,
wgrad_grad,
kwargs["config_file"],
)
dist.barrier()
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
x, weight = _get_tensors(parallel_mode)
model = _init_model(weight, parallel_mode)
y = _run_forward_backward(x, model, parallel_mode)
output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()}
fp8_kwargs["fprop_input_scale"] = (
_get_current_scale(x, fprop_inp) if not fp8_kwargs["fprop_fp8"] else None
)
fp8_kwargs["fprop_weight_scale"] = (
_get_current_scale(weight, fprop_weight) if not fp8_kwargs["fprop_fp8"] else None
)
fp8_kwargs["dgrad_gradient_scale"] = (
_get_current_scale(y.grad, dgrad_grad) if not fp8_kwargs["dgrad_fp8"] else None
)
fp8_kwargs["dgrad_weight_scale"] = (
_get_current_scale(weight, dgrad_weight) if not fp8_kwargs["dgrad_fp8"] else None
)
fp8_kwargs["wgrad_gradient_scale"] = (
_get_current_scale(y.grad, wgrad_grad) if not fp8_kwargs["wgrad_fp8"] else None
)
fp8_kwargs["wgrad_input_scale"] = (
_get_current_scale(x, wgrad_input) if not fp8_kwargs["wgrad_fp8"] else None
)
ground_truth = _emulate_linear_distributed(x, weight, parallel_mode=parallel_mode, **fp8_kwargs)
_cmp(ground_truth, output)
def _init_distributed():
global WORLD_RANK, WORLD_SIZE, NCCL_WORLD, FP8
WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node
assert LOCAL_SIZE <= torch.cuda.device_count()
dist_init_kwargs = {
"backend": "nccl",
"rank": WORLD_RANK,
"world_size": WORLD_SIZE,
}
dist_init_kwargs["init_method"] = "env://"
dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
assert dist.is_nccl_available()
torch.cuda.set_device(LOCAL_RANK)
dist.init_process_group(**dist_init_kwargs)
NCCL_WORLD = dist.new_group(backend="nccl")
WORLD_SIZE = dist.get_world_size()
def _run_test_with_combinations(
test_function, values_list, num_repeat, extra_args, sample_size=None
):
combinations = itertools.product(values_list, repeat=num_repeat)
total_combinations = itertools.product(combinations, extra_args)
if sample_size is not None:
total_combinations = random.sample(list(total_combinations), sample_size)
for comb, arg in total_combinations:
test_function(*comb, arg)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--feature_dirs", type=str)
args = parser.parse_args()
FEATURE_DIRS = args.feature_dirs
random.seed(SEED)
_init_distributed()
test_log_expert_parallel()
for parallel_mode in ["column", "row"]:
for gather_weight in [True, False]:
test_log_distributed(parallel_mode, gather_weight)
for parallel_mode in ["row", "column"]:
test_disable_fp8_layer(parallel_mode)
# test_disable_fp8_gemms
_run_test_with_combinations(
test_disable_fp8_gemms, all_boolean, num_repeat=3, extra_args=["column", "row"]
)
# test_fake_quant_fp8
dtype_options = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2, None]
_run_test_with_combinations(
test_fake_quant_fp8,
dtype_options,
num_repeat=6,
extra_args=["column", "row"],
sample_size=20,
)
_run_test_with_combinations(
test_per_tensor_scaling,
all_boolean,
num_repeat=6,
extra_args=["column"],
sample_size=20,
)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
import nvdlfw_inspect.api as debug_api
try:
import transformer_engine
import transformer_engine_torch as tex
except (ImportError, ModuleNotFoundError):
print("Could not find TransformerEngine package.")
exit(1)
def test_transformer_engine_no_config(feature_dirs):
debug_api.initialize("", feature_dirs=feature_dirs)
try:
tensor = torch.rand(24, 2046).cuda()
# FP8 enabled - true by the default
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0
)
# modify_tensor_enabled - False by default
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
)
# inspect_tensor_enabled - False by default
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.attn.qkv", tensor_name="activation", iteration=0
)
# inspect_tensor_postquantize - False by default
assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
)
finally:
debug_api.end_debug()
def test_disable_fp8_gemm(configs_dir, feature_dirs):
try:
debug_api.initialize(configs_dir + "disable_fp8_gemms.yaml", feature_dirs=feature_dirs)
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0
)
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="dgrad", iteration=0
)
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="wgrad", iteration=0
)
# caching
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0
)
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="dgrad", iteration=0
)
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="wgrad", iteration=0
)
finally:
debug_api.end_debug()
def test_disable_fp8_layer(configs_dir, feature_dirs):
try:
debug_api.initialize(configs_dir + "disable_fp8_layer.yaml", feature_dirs=feature_dirs)
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.mlp.fc1", gemm="fprop", iteration=0
)
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.mlp.fc1", gemm="wgrad", iteration=0
)
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", iteration=0
)
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0
)
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="wgrad", iteration=0
)
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="dgrad", iteration=0
)
finally:
debug_api.end_debug()
def test_per_tensor_scaling(configs_dir, feature_dirs):
try:
debug_api.initialize(configs_dir + "per_tensor_scaling.yaml", feature_dirs=feature_dirs)
tensor = torch.rand(24, 2046).cuda()
# check modify_tensor_enabled
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0
)
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="fprop", tensor_name="weight", iteration=0
)
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0
)
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", tensor_name="weight", iteration=0
)
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="wgrad", tensor_name="gradient", iteration=0
)
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="wgrad", tensor_name="activation", iteration=0
)
# check modify_tensor
default_quantizer1 = Float8Quantizer(
scale=torch.tensor([1]).cuda(),
amax=torch.tensor([0]).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
)
default_quantizer2 = Float8Quantizer(
scale=torch.tensor([1]).cuda(),
amax=torch.tensor([0]).cuda(),
fp8_dtype=tex.DType.kFloat8E5M2,
)
output1 = debug_api.transformer_engine.modify_tensor(
layer_name="decoder.1.mlp.fc1",
gemm="fprop",
tensor_name="activation",
default_quantizer=default_quantizer1,
iteration=0,
tensor=tensor,
)
assert type(output1) == Float8Tensor
assert output1._fp8_dtype == tex.DType.kFloat8E4M3
output2 = debug_api.transformer_engine.modify_tensor(
"decoder.1.mlp.fc1",
gemm="dgrad",
tensor=tensor,
tensor_name="gradient",
default_quantizer=default_quantizer2,
iteration=0,
)
assert type(output2) == Float8Tensor
assert output2._fp8_dtype == tex.DType.kFloat8E5M2
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1",
gemm="wgrad",
tensor_name="gradient",
iteration=0,
)
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc4",
gemm="fprop",
tensor_name="activation",
iteration=0,
)
finally:
debug_api.end_debug()
def test_fake_quant(configs_dir, feature_dirs):
try:
debug_api.initialize(
configs_dir + "fake_quantization_config.yaml", feature_dirs=feature_dirs
)
tensor = torch.rand(24, 2046).cuda()
# modify_tensor_enabled
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0
)
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0
)
# modify_tensor
debug_api.transformer_engine.modify_tensor(
"decoder.1.mlp.fc1",
gemm="fprop",
tensor=tensor,
tensor_name="activation",
iteration=0,
default_quantizer=None,
)
debug_api.transformer_engine.modify_tensor(
"decoder.1.mlp.fc1",
gemm="dgrad",
tensor=tensor,
tensor_name="gradient",
iteration=0,
default_quantizer=None,
)
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.fc2", gemm="wgrad", iteration=0
)
# caching
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.fc2", gemm="wgrad", iteration=0
)
finally:
debug_api.end_debug()
def test_statistics_collection(configs_dir, feature_dirs):
try:
debug_api.initialize(
config_file=configs_dir + "stats_collection_test_config.yaml",
feature_dirs=feature_dirs,
default_logging_enabled=False,
)
tensor = torch.randn((100, 100, 5)).cuda()
tensor_fp8 = Float8Tensor(
data=tensor.to(torch.uint8).cuda(),
fp8_scale_inv=torch.full([1], 1.0).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
shape=tensor.shape,
dtype=torch.float32,
)
def log():
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
return STATS_BUFFERS.log_stats()
def assert_empty():
stats = log()
assert len(stats) == 0
# TE tensor stats --
debug_api.transformer_engine.inspect_tensor(
"decoder.1.mlp.fc1",
tensor=tensor,
tensor_name="activation",
iteration=200,
tp_group=None,
)
stats = log()
assert stats[("decoder.1.mlp.fc1", "activation", "cur_amax", 200)] == tensor.abs().max()
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.mlp.fc1", tensor_name="activation", iteration=201
)
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.2.mlp.fc1", tensor_name="activation", iteration=200
)
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.mlp.fc1", tensor_name="gradient", iteration=200
)
expected_underflows = (tensor_fp8._data == 0).sum() * 100 / (100 * 100 * 5)
expected_overflows = (tensor_fp8._data == 126).sum() * 100 / (100 * 100 * 5)
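# Interpretation (illustrative): underflows% is the share of FP8 payload bytes equal to the
# zero bit pattern, and overflows% the share equal to 126 (0x7E, the largest finite E4M3
# value, i.e. 448), each expressed as a percentage of all 100 * 100 * 5 elements.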
# TE FP8 tensor stats --
assert debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.1.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200
)
debug_api.transformer_engine.inspect_tensor_postquantize(
"decoder.1.mlp.fc1",
tensor=tensor_fp8,
tensor_name="gradient",
iteration=200,
rowwise=True,
tp_group=None,
)
stats = log()
torch.testing.assert_close(
stats[("decoder.1.mlp.fc1", "gradient", "underflows%", 200)], expected_underflows
)
assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.1.mlp.fc1", tensor_name="activation", gemm="fprop", iteration=201
)
assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.2.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200
)
# Second config in same yaml
tensor = torch.rand((100, 100, 5))
debug_api.transformer_engine.inspect_tensor(
"decoder.6.mlp.fc1",
tensor=tensor,
tensor_name="activation",
iteration=200,
tp_group=None,
)
stats = log()
stats_names = [x[3] for x in stats.keys()]
assert all(s in stats_names for s in ["cur_amax", "dynamic_range", "mean", "std", "l1_norm"])
assert stats[("decoder.6.mlp.fc1", "activation", "mean", 200)] == tensor.mean()
debug_api.transformer_engine.inspect_tensor(
"decoder.7.mlp.fc1",
tensor=tensor,
tensor_name="weight",
iteration=200,
tp_group=None,
)
stats = log()
stats_names = [x[3] for x in stats.keys()]
assert all(s in stats_names for s in ["mean", "std", "l1_norm", "min", "max"])
assert stats[("decoder.7.mlp.fc1", "weight", "max", 200)] == tensor.max()
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.7.mlp.fc1", tensor_name="weight", iteration=201
)
assert_empty()
finally:
debug_api.end_debug()
def test_statistics_multi_run(configs_dir, feature_dirs):
try:
debug_api.initialize(
config_file=configs_dir + "stats_collection_test_config.yaml",
feature_dirs=feature_dirs,
default_logging_enabled=False,
)
def feed(tensor, tensor_fp8):
debug_api.transformer_engine.inspect_tensor(
"decoder.5.mlp.fc1",
tensor=tensor,
tensor_name="activation",
iteration=1,
tp_group=None,
)
debug_api.transformer_engine.inspect_tensor_postquantize(
"decoder.5.mlp.fc1",
tensor=tensor_fp8,
tensor_name="activation",
iteration=1,
rowwise=True,
tp_group=None,
)
def log_stats():
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
return STATS_BUFFERS.log_stats()
def fp8_tensor(t):
return Float8Tensor(
data=t.to(torch.uint8).cuda(),
fp8_scale_inv=torch.ones([1]).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
shape=t.shape,
dtype=torch.float32,
)
shape = [1024, 1024]
tensors = [torch.randn(shape) for _ in range(2)]
tensors_fp8 = [fp8_tensor(tensors[i]) for i in range(2)]
feed(tensors[0], tensors_fp8[0])
feed(tensors[1], tensors_fp8[1])
stats1 = log_stats()
tensor2 = torch.cat((tensors[0], tensors[1])).cuda()
fp8tensor2 = fp8_tensor(tensor2)
feed(tensor2, fp8tensor2)
stats2 = log_stats()
assert len(stats1.keys()) > 0
for k in stats1.keys():
torch.testing.assert_close(stats1[k], stats2[k])
finally:
debug_api.end_debug()
if __name__ == "__main__":
pass
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pathlib, os
from nvdlfw_inspect.config_manager import ConfigManager
import nvdlfw_inspect.api as debug_api
try:
import transformer_engine
from transformer_engine.debug.features.api import TEConfigAPIMapper
except (ImportError, ModuleNotFoundError):
print("Could not find TransformerEngine debug module.")
exit(1)
def test_transformer_engine_config_parsing(feature_dirs):
debug_api.initialize(
config_file=pathlib.Path(__file__).resolve().parent
/ "test_configs/tensor_manipulation_transformer_engine.yaml",
feature_dirs=feature_dirs,
log_dir="./log",
)
cfg_fc1 = ConfigManager.get_config_for_layer("decoder.1.mlp.fc1")["transformer_engine"]
cfg_fc2 = ConfigManager.get_config_for_layer("decoder.1.mlp.fc2")["transformer_engine"]
assert cfg_fc1 and cfg_fc2
gemm_parsing = True
tensor_parsing = True
# Per tensor scaling set for dgrad, filter based on gemm
ret, _ = TEConfigAPIMapper().parse_config_and_api(
cfg_fc1["PerTensorScaling"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="wgrad",
tensor_name="activation",
)
assert not ret
# per tensor scaling set for gradient, filter based on tensor name
ret, _ = TEConfigAPIMapper().parse_config_and_api(
cfg_fc1["PerTensorScaling"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="dgrad",
tensor_name="activation",
)
assert not ret
ret, parsed_cfg_fc1 = TEConfigAPIMapper().parse_config_and_api(
cfg_fc1["PerTensorScaling"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="dgrad",
tensor_name="gradient",
)
assert ret
assert parsed_cfg_fc1 == {"gemm": "dgrad", "tensor": "gradient"}
# Test tensor struct
ret, parsed_cfg_fc1_act = TEConfigAPIMapper().parse_config_and_api(
cfg_fc1["FakeQuant"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="fprop",
tensor_name="activation",
)
ret, parsed_cfg_fc1_wei = TEConfigAPIMapper().parse_config_and_api(
cfg_fc1["FakeQuant"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="fprop",
tensor_name="weight",
)
assert ret
assert parsed_cfg_fc1_act == {
"gemm": "fprop",
"tensor": "activation",
"quant_format": "FP8E4M3",
}
assert parsed_cfg_fc1_wei == {
"gemm": "fprop",
"tensor": "weight",
"quant_format": "FP8E4M3",
}
# Test gemms struct
ret, parsed_cfg_fc2_grad = TEConfigAPIMapper().parse_config_and_api(
cfg_fc2["FakeQuant"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="dgrad",
tensor_name="gradient",
)
assert ret
assert parsed_cfg_fc2_grad == {"gemm": "dgrad", "tensor": "gradient", "quant_format": "FP8E5M2"}
ret, parsed_cfg_fc2_wei = TEConfigAPIMapper().parse_config_and_api(
cfg_fc2["FakeQuant"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="dgrad",
tensor_name="weight",
)
assert ret
assert parsed_cfg_fc2_wei == {"gemm": "dgrad", "tensor": "weight", "quant_format": "FP8E5M2"}
# Test gemm + tensor struct
ret, parsed_cfg_fc2_fprop_act = TEConfigAPIMapper().parse_config_and_api(
cfg_fc2["PerTensorScaling"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="fprop",
tensor_name="activation",
)
assert ret
assert parsed_cfg_fc2_fprop_act == {"gemm": "fprop", "tensor": "activation"}
ret, parsed_cfg_fc2_fprop_wei = TEConfigAPIMapper().parse_config_and_api(
cfg_fc2["PerTensorScaling"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="fprop",
tensor_name="weight",
)
assert ret
assert parsed_cfg_fc2_fprop_wei == {"gemm": "fprop", "tensor": "weight"}
ret, parsed_cfg_fc2_wgrad_act = TEConfigAPIMapper().parse_config_and_api(
cfg_fc2["PerTensorScaling"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="wgrad",
tensor_name="activation",
)
assert ret
assert parsed_cfg_fc2_wgrad_act == {"gemm": "wgrad", "tensor": "activation"}
ret, parsed_cfg_fc2_wgrad_grad = TEConfigAPIMapper().parse_config_and_api(
cfg_fc2["PerTensorScaling"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="wgrad",
tensor_name="gradient",
)
assert ret
assert parsed_cfg_fc2_wgrad_grad == {"gemm": "wgrad", "tensor": "gradient"}
ConfigManager.reset()
test_disable_fp8_gemm_1:
enabled: True
layers:
layer_types: [qkv, fc2]
transformer_engine:
DisableFP8GEMM:
enabled: True
gemms: [dgrad, wgrad]
\ No newline at end of file