Commit 544dd14b authored by Przemek Tredak

Update main branch with TE 2.0 code, update version to 2.1.0.dev0


Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
parent e5369541
@@ -2,7 +2,6 @@
 extension-pkg-whitelist=flash_attn_2_cuda,
                         torch,
                         transformer_engine_torch,
-                        transformer_engine_paddle,
                         transformer_engine_jax
 extension-pkg-allow-list=transformer_engine.transformer_engine_jax
...
@@ -8,7 +8,7 @@ pip install "nltk>=3.8.2"
 pip install pytest==8.2.1

 : ${TE_PATH:=/opt/transformerengine}
-pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed'
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py

 # Test without custom calls
 NVTE_CUSTOM_CALLS_RE="" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py
...
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
: "${TE_PATH:=/opt/transformerengine}"
pip install cpplint==1.6.0 pylint==3.3.1
if [ -z "${PYTHON_ONLY}" ]
then
cd $TE_PATH
echo "Checking common API headers"
cpplint --root transformer_engine/common/include --recursive transformer_engine/common/include
echo "Checking C++ files"
cpplint --recursive --exclude=transformer_engine/common/include --exclude=transformer_engine/build_tools/build transformer_engine/common
cpplint --recursive transformer_engine/paddle
fi
if [ -z "${CPP_ONLY}" ]
then
cd $TE_PATH
echo "Checking Python files"
pylint --recursive=y transformer_engine/common transformer_engine/paddle
fi
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -xe
pip install pytest==8.2.1
: ${TE_PATH:=/opt/transformerengine}
pytest -Wignore -v $TE_PATH/tests/paddle
pytest -Wignore -v $TE_PATH/examples/paddle/mnist
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
: "${TE_PATH:=/opt/transformerengine}"
# Install dependencies
# Note: The wheel package must be installed locally, since the PaddlePaddle
# container already ships an APT-installed copy.
pip install pydantic
pip install --user wheel==0.44.0
cd $TE_PATH
pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-paddle
VERSION=`cat $TE_PATH/build_tools/VERSION.txt`
WHL_BASE="transformer_engine-${VERSION}"
# Core wheel.
NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel
python -m wheel unpack dist/*
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
python -m wheel pack ${WHL_BASE}
rm dist/*.whl
mv *.whl dist/
NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel
pip install dist/*.whl --no-deps
cd transformer_engine/paddle
NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel
pip install dist/*
python $TE_PATH/tests/paddle/test_sanity_import.py
@@ -11,11 +11,10 @@ pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
 pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py
 pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
-PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
+NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
 pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
 pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
 pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
-pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py
 pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py
 pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py
 pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
...
@@ -8,8 +8,8 @@ set -e
 pip install pytest==8.2.1

 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py
-pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py
-pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py
+pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
+# pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py ### TODO Debug UB support with te.Sequential

 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
: ${TE_PATH:=/opt/transformerengine}
pip install pytest==8.2.1 onnxruntime==1.19.2
# Build custom ONNX Runtime operators
export CUSTOM_ORT_OPS_PATH=$TE_PATH/tests/pytorch/custom_ort_ops
bash $CUSTOM_ORT_OPS_PATH/build.sh
# Run tests
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
@@ -12,7 +12,14 @@ pip install pytest==8.2.1
 export MAX_JOBS=4

 # Iterate over Flash Attention versions
-FA_versions=(2.1.1 2.3.0 2.4.1 2.5.7 2.6.3 3.0.0b1)
+sm_arch=`python -c "import torch; sm = torch.cuda.get_device_capability(0); print(sm[0]*10+sm[1])"`
+if [ $sm_arch -gt 90 ]
+then
+    FA_versions=(2.7.3)
+else
+    FA_versions=(2.1.1 2.3.0 2.4.1 2.5.7 2.7.3 3.0.0b1)
+fi
 for fa_version in "${FA_versions[@]}"
 do
@@ -21,10 +28,10 @@ do
     then
         pip install flash-attn==${fa_version}
     else
-        pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper"
+        pip install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper"
         python_path=`python -c "import site; print(site.getsitepackages()[0])"`
         mkdir -p $python_path/flashattn_hopper
-        wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py
+        wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py
     fi

     # Run tests
...
@@ -5,6 +5,7 @@
 """Installation script."""

 import os
+import sys
 import time
 from pathlib import Path
 from typing import List, Tuple
@@ -35,14 +36,13 @@ os.environ["NVTE_PROJECT_BUILDING"] = "1"

 if "pytorch" in frameworks:
     from torch.utils.cpp_extension import BuildExtension
-elif "paddle" in frameworks:
-    from paddle.utils.cpp_extension import BuildExtension
 elif "jax" in frameworks:
     install_and_import("pybind11[global]")
     from pybind11.setup_helpers import build_ext as BuildExtension

 CMakeBuildExtension = get_build_ext(BuildExtension)
+archs = cuda_archs()


 class TimedBdist(bdist_wheel):
@@ -57,7 +57,7 @@ class TimedBdist(bdist_wheel):

 def setup_common_extension() -> CMakeExtension:
     """Setup CMake extension for common library"""
-    cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())]
+    cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)]
     if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
         assert (
             os.getenv("MPI_HOME") is not None
@@ -104,13 +104,11 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
     if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
         if "pytorch" in frameworks:
             install_reqs.extend(["torch"])
-            test_reqs.extend(["numpy", "onnxruntime", "torchvision", "prettytable"])
+            test_reqs.extend(["numpy", "torchvision", "prettytable"])
         if "jax" in frameworks:
             install_reqs.extend(["jax", "flax>=0.7.1"])
-            test_reqs.extend(["numpy", "praxis"])
-        if "paddle" in frameworks:
-            install_reqs.append("paddlepaddle-gpu")
-            test_reqs.append("numpy")
+            # test_reqs.extend(["numpy", "praxis"])
+            test_reqs.extend(["numpy"])

     return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]]
@@ -135,7 +133,6 @@ if __name__ == "__main__":
         extras_require = {
             "pytorch": [f"transformer_engine_torch=={__version__}"],
             "jax": [f"transformer_engine_jax=={__version__}"],
-            "paddle": [f"transformer_engine_paddle=={__version__}"],
         }
     else:
         setup_requires, install_requires, test_requires = setup_requirements()
@@ -169,16 +166,6 @@ if __name__ == "__main__":
                 current_file_path / "transformer_engine",
             )
         )
-    if "paddle" in frameworks:
-        from build_tools.paddle import setup_paddle_extension
-
-        ext_modules.append(
-            setup_paddle_extension(
-                "transformer_engine/paddle/csrc",
-                current_file_path / "transformer_engine" / "paddle" / "csrc",
-                current_file_path / "transformer_engine",
-            )
-        )

     # Configure package
     setuptools.setup(
...
@@ -5,7 +5,11 @@
 cmake_minimum_required(VERSION 3.18)

 if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
-  set(CMAKE_CUDA_ARCHITECTURES 70 80 90)
+  if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
+    set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
+  else ()
+    set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90)
+  endif()
 endif()
...
@@ -3,23 +3,33 @@
 # See LICENSE for license information.

 add_executable(test_operator
+               test_cast.cu
+               test_cast_dbias.cu
+               test_cast_dbias_dgelu.cu
+               test_cast_gated_swiglu.cu
+               test_cast_mxfp8_gated_swiglu.cu
                test_qdq.cu
-               test_cast_transpose.cu
+               test_cast_mxfp8.cu
+               test_dequantize_mxfp8.cu
                test_transpose.cu
+               test_cast_transpose.cu
                test_cast_transpose_dbias.cu
                test_cast_transpose_dbias_dgelu.cu
                test_cast_transpose_dgeglu.cu
                test_act.cu
                test_normalization.cu
+               test_normalization_mxfp8.cu
                test_multi_cast_transpose.cu
                test_multi_padding.cu
                test_causal_softmax.cu
+               test_swizzle.cu
                ../test_common.cu)

+find_package(OpenMP REQUIRED)
 list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn)
-target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS})
-target_compile_options(test_operator PRIVATE -O2)
+target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS} OpenMP::OpenMP_CXX)
+target_compile_options(test_operator PRIVATE -O2 -fopenmp)

 include(GoogleTest)
-gtest_discover_tests(test_operator)
+gtest_discover_tests(test_operator DISCOVERY_TIMEOUT 600)
@@ -21,58 +21,6 @@
 using namespace transformer_engine;

-namespace {
-
-// forward
-
-float gelu(const float x) {
-  return 0.5f * x * (1.0f + tanhf(0.79788456F * x * (1.0f + 0.044715f * x * x)));
-}
-
-float silu(const float x) {
-  return x / (1 + expf(-x));
-}
-
-float relu(const float x) {
-  return x > 0 ? x : 0;
-}
-
-float srelu(const float x) {
-  return x > 0 ? x * x : 0;
-}
-
-float qgelu(const float x) {
-  return x / (1 + expf(-1.702f * x));
-}
-
-// backward
-
-float dgelu(const float x) {
-  const float tanh_out = tanhf(0.79788456f * x * (1.f + 0.044715f * x * x));
-  return 0.5f * x * ((1.f - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * x * x)) +
-         0.5f * (1.f + tanh_out);
-}
-
-float dsilu(const float x) {
-  const float sigmoid = 1.f / (1 + expf(-x));
-  return x * sigmoid * (1.f - sigmoid) + sigmoid;
-}
-
-float drelu(const float x) {
-  return x > 0.f ? 1.f : 0.f;
-}
-
-float dsrelu(const float x) {
-  return fmaxf(2.f * x, 0.f);
-}
-
-float dqgelu(const float x) {
-  const float sigmoid = 1.f / (1 + expf(-1.702f * x));
-  return 1.702f * x * sigmoid * (1.f - sigmoid) + sigmoid;
-}
-
-}  // namespace
-
 template <float (*act)(const float), typename IT, typename OT, typename CT>
 void compute_ref_act_cast(const IT *input_h,
                           OT *output_h,
@@ -82,6 +30,7 @@ void compute_ref_act_cast(const IT *input_h,
                           const size_t H) {
   CT amax = 0.;

+#pragma omp parallel for schedule(static) reduction(max: amax) proc_bind(spread)
   for (size_t i = 0; i < N; i++) {
     for (size_t j = 0; j < H; j++) {
       CT elt = static_cast<CT>(input_h[i * H + j]);
@@ -101,6 +50,7 @@ void compute_ref_dact_cast(const IT *input_h,
                            const size_t N,
                            const size_t H) {
   using CT = float;
+#pragma omp parallel for schedule(static) proc_bind(spread)
   for (size_t i = 0; i < N; i++) {
     for (size_t j = 0; j < H; j++) {
       CT elt = static_cast<CT>(input_h[i * H + j]);
@@ -118,6 +68,7 @@ void compute_ref_glu_act_cast(const IT *input_h, OT *output_h, const CT scale, C
   const int col = H * 2;

+#pragma omp parallel for schedule(static) reduction(max: amax) proc_bind(spread)
   for (size_t i = 0; i < N; i++) {
     for (size_t j = 0; j < H; j++) {
       CT gelu_elt = static_cast<CT>(input_h[i * col + j]);
@@ -139,6 +90,7 @@ void compute_ref_dglu_act_cast(const IT *input_h, const IT *grad_h, OT *output_h
   const int col = H * 2;
   using CT = float;

+#pragma omp parallel for schedule(static) proc_bind(spread)
   for (size_t i = 0; i < N; i++) {
     for (size_t j = 0; j < H; j++) {
       CT grad = static_cast<CT>(grad_h[i * H + j]);
@@ -164,10 +116,10 @@ void performTest(const size_t N, const size_t H) {
   DType itype = TypeInfo<IType>::dtype;
   DType otype = TypeInfo<OType>::dtype;

-  Tensor input({ N, H }, itype);
-  Tensor output({ N, H }, otype);
-  Tensor igrad({ N, H }, itype);
-  Tensor ograd({ N, H }, itype);
+  Tensor input("input", { N, H }, itype);
+  Tensor output("output", { N, H }, otype);
+  Tensor igrad("igrad", { N, H }, itype);
+  Tensor ograd("ograd", { N, H }, itype);

   fillUniform(&input);
   fillUniform(&ograd);
@@ -179,7 +131,7 @@ void performTest(const size_t N, const size_t H) {
   nvte_act(input.data(), output.data(), 0);

   float ref_amax;
-  compute_ref_act_cast<ref_act>(input.cpu_dptr<IType>(), ref_output.get(),
+  compute_ref_act_cast<ref_act>(input.rowwise_cpu_dptr<IType>(), ref_output.get(),
                                 output.scale(), &ref_amax, N, H);

   cudaDeviceSynchronize();
@@ -195,7 +147,7 @@ void performTest(const size_t N, const size_t H) {
   nvte_dact(ograd.data(), input.data(), igrad.data(), 0);

-  compute_ref_dact_cast<ref_dact>(input.cpu_dptr<IType>(), ograd.cpu_dptr<IType>(),
+  compute_ref_dact_cast<ref_dact>(input.rowwise_cpu_dptr<IType>(), ograd.rowwise_cpu_dptr<IType>(),
                                   ref_igrad.get(), N, H);

   cudaDeviceSynchronize();
@@ -219,10 +171,10 @@ void performTestGLU(const size_t N, const size_t H) {
   DType itype = TypeInfo<IType>::dtype;
   DType otype = TypeInfo<OType>::dtype;

-  Tensor input({N, H * 2}, itype);
-  Tensor output({N, H}, otype);
-  Tensor igrad({ N, H * 2 }, itype);
-  Tensor ograd({ N, H }, itype);
+  Tensor input("input", {N, H * 2}, itype);
+  Tensor output("output", {N, H}, otype);
+  Tensor igrad("igrad", { N, H * 2 }, itype);
+  Tensor ograd("ograd", { N, H }, itype);

   fillUniform(&input);
   fillUniform(&ograd);
@@ -234,7 +186,7 @@ void performTestGLU(const size_t N, const size_t H) {
   nvte_act(input.data(), output.data(), 0);

   float ref_amax;
-  compute_ref_glu_act_cast<ref_act>(input.cpu_dptr<IType>(), ref_output.get(),
+  compute_ref_glu_act_cast<ref_act>(input.rowwise_cpu_dptr<IType>(), ref_output.get(),
                                     output.scale(), &ref_amax, N, H);

   cudaDeviceSynchronize();
@@ -242,15 +194,19 @@ void performTestGLU(const size_t N, const size_t H) {
   ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

   if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) {
-    auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
-    compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
+    auto [atol, rtol] = getTolerances(DType::kFloat32);
+    compareResults("amax", output.amax(), ref_amax, atol, rtol);
+    if (output.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
+      const float ref_scale = 1.f / output.scale();
+      compareResults("scale_inv", *output.rowwise_cpu_scale_inv_ptr<float>(), ref_scale, atol, rtol);
+    }
   }

   auto [atol, rtol] = getTolerances(otype);
   compareResults("output_gelu", output, ref_output.get(), atol, rtol);

   nvte_dact(ograd.data(), input.data(), igrad.data(), 0);

-  compute_ref_dglu_act_cast<ref_dact, ref_act>(input.cpu_dptr<IType>(), ograd.cpu_dptr<IType>(),
+  compute_ref_dglu_act_cast<ref_dact, ref_act>(input.rowwise_cpu_dptr<IType>(), ograd.rowwise_cpu_dptr<IType>(),
                                                ref_igrad.get(), N, H);

   cudaDeviceSynchronize();
...
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
template <typename InputType, typename OutputType>
void compute_ref(const InputType *data, OutputType *output_c,
const size_t size,
float *amax, float scale) {
using compute_t = float;
compute_t current_max = -1e100;
for (size_t i = 0; i < size; ++i) {
compute_t current = static_cast<compute_t>(data[i]);
current_max = fmaxf(current_max, fabsf(current));
output_c[i] = OutputType(scale * current);
}
*amax = current_max;
}
template <typename InputType, typename OutputType>
void performTest(const std::vector<size_t>& shape) {
using namespace test;
const size_t full_size = product(shape);
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
Tensor input("input", shape, itype);
Tensor output_c("output_c", shape, otype);
std::unique_ptr<OutputType[]> ref_output_c = std::make_unique<OutputType[]>(full_size);
fillUniform(&input);
setRandomScale(&output_c);
nvte_quantize(input.data(), output_c.data(), 0);
float ref_amax;
compute_ref<InputType, OutputType>(input.rowwise_cpu_dptr<InputType>(), ref_output_c.get(),
full_size, &ref_amax, output_c.scale());
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol);
}
std::vector<std::vector<size_t>> test_cases = {
{16},
{16000},
{128, 128},
{256, 256},
{768, 1024},
{256, 65536},
{2048, 12288},
{65536, 128},
{65536, 160},
{16384, 1616},
{1, 128},
{1, 1296},
{1, 16},
{5, 160},
{5, 4, 3, 160},
{217, 256},
};
} // namespace
class CastTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
transformer_engine::DType,
std::vector<size_t>>> {};
TEST_P(CastTestSuite, TestCast) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>(size);
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
CastTestSuite,
::testing::Combine(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<CastTestSuite::ParamType>& info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param));
const auto& shape = std::get<2>(info.param);
for ( const auto& s: shape) {
name += "X" + std::to_string(s);
}
return name;
});
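A minimal standalone sketch of the delayed-scaling quantization this test exercises, using a plain float stand-in for the FP8 type (the clamp value 448 is the E4M3 max-normal magnitude; rounding to the FP8 value grid is omitted, and the helper name and values are illustrative, not TE API):

#include <algorithm>
#include <cmath>
#include <cstdio>

// Emulate FP8 E4M3 quantization: scale, then clamp to the representable range.
float quantize_e4m3_like(float x, float scale) {
  const float kMax = 448.0f;  // largest finite E4M3 magnitude
  return std::clamp(scale * x, -kMax, kMax);
}

int main() {
  const float data[] = {0.5f, -3.5f, 2.0f};
  const float scale = 2.0f;   // chosen by the delayed-scaling recipe
  float amax = 0.0f;          // running absolute maximum of the raw inputs
  for (float x : data) {
    amax = std::max(amax, std::fabs(x));
    const float q = quantize_e4m3_like(x, scale);
    const float dq = q * (1.0f / scale);  // dequantize with scale_inv = 1/scale
    std::printf("x=% .2f  q=% .2f  dq=% .2f\n", x, q, dq);
  }
  std::printf("amax=%.2f  scale_inv=%.2f\n", amax, 1.0f / scale);
}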
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstring>
#include <memory>
#include <iomanip>
#include <iostream>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
template <typename IT, typename OT, typename CT>
void compute_ref_cast_dbias(const IT *input_h,
const CT scale,
OT *output_c_h,
CT *amax_h,
IT *dbias_h,
const size_t N,
const size_t H) {
CT amax = 0.;
std::vector<CT> acc_dbias(H, 0.);
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
CT elt = static_cast<CT>(input_h[i * H + j]);
// update amax
amax = std::abs(elt) > amax ? std::abs(elt) : amax;
output_c_h[i * H + j] = static_cast<OT>(scale * elt);
// dbias
acc_dbias[j] += elt;
}
}
*amax_h = amax;
for (size_t i = 0; i < H; i++) {
dbias_h[i] = static_cast<IT>(acc_dbias[i]);
}
}
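// Note (added for clarity): dbias is the column-wise sum of the unquantized
// input, accumulated in fp32 (CT) and cast back to the input dtype only at
// the end; the accumulated rounding is why the test below multiplies
// rtol_dbias by 4.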
template <typename IType, typename OType>
void performTest(const std::vector<size_t>& shape) {
using namespace test;
using CType = fp32;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
const size_t N = first_dimension(shape);
const size_t H = last_dimension(shape);
Tensor input("input", shape, itype);
Tensor output_c("output_c", shape, otype);
  // dbias has the same data type as the output gradient
Tensor dbias("dbias", {H}, itype);
fillUniform(&input);
setRandomScale(&output_c);
std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N*H);
std::unique_ptr<IType[]> ref_output_dbias = std::make_unique<IType[]>(H);
CType ref_amax;
compute_ref_cast_dbias(input.rowwise_cpu_dptr<IType>(),
output_c.scale(),
ref_output_c.get(),
&ref_amax,
ref_output_dbias.get(),
N, H);
Tensor workspace;
nvte_quantize_dbias(input.data(),
output_c.data(),
dbias.data(),
workspace.data(),
0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_quantize_dbias(input.data(),
output_c.data(),
dbias.data(),
workspace.data(),
0);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol);
auto [atol_dbias, rtol_dbias] = getTolerances(itype);
rtol_dbias *= 4;
compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
}
std::vector<std::vector<size_t>> test_cases = {
{128, 128},
{256, 256},
{768, 1024},
{256, 65536},
{2048, 12288},
{65536, 128},
{65536, 160},
{16384, 1616},
{1, 128},
{1, 1296},
{1, 16},
{5, 160},
{5, 4, 3, 160},
{217, 256},
};
}  // namespace
class CastDBiasTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
transformer_engine::DType,
std::vector<size_t>>> {};
TEST_P(CastDBiasTestSuite, TestCastDBias) {
using namespace transformer_engine;
using namespace test;
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>(size);
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
CastDBiasTestSuite,
::testing::Combine(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<CastDBiasTestSuite::ParamType>& info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param));
const auto& shape = std::get<2>(info.param);
for ( const auto& s: shape) {
name += "X" + std::to_string(s);
}
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstring>
#include <memory>
#include <iomanip>
#include <iostream>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include "../test_common.h"
using namespace transformer_engine;
using namespace test;
namespace {
template <typename IT, typename OT, typename CT>
void compute_ref_cast_dbias_dgelu(const IT *input,
const IT *grad,
const CT scale,
OT *output_c,
CT *amax_h,
IT *dbias,
const size_t N,
const size_t H) {
CT amax = 0.;
std::vector<CT> acc_dbias(H, 0.);
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
CT in_elt = static_cast<CT>(input[i * H + j]);
const CT in_grad = static_cast<CT>(grad[i * H + j]);
const CT elt = in_grad * static_cast<float>(dgelu(static_cast<float>(in_elt)));
const CT elt_abs = std::abs(elt);
// update amax
if (elt_abs > amax) {
amax = elt_abs;
}
output_c[i * H + j] = static_cast<OT>(scale * elt);
// dbias
acc_dbias[j] += elt;
}
}
*amax_h = amax;
for (size_t i = 0; i < H; i++) {
dbias[i] = static_cast<IT>(acc_dbias[i]);
}
}
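// Note (added for clarity): each element is first pushed through the GeLU
// backward, elt = grad * dgelu(input), and both the quantized output and the
// fp32 dbias column sums are computed from that post-dGeLU value.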
template <typename IType, typename OType>
void performTest(const std::vector<size_t>& shape) {
using namespace test;
using CType = fp32;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
const size_t N = first_dimension(shape);
const size_t H = last_dimension(shape);
Tensor input("input", shape, itype);
Tensor grad("grad", shape, itype);
Tensor output_c("output_c", shape, otype);
  // dbias has the same data type as the output gradient
Tensor dbias("dbias", {H}, itype);
fillUniform(&input);
fillUniform(&grad);
setRandomScale(&output_c);
std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N*H);
std::unique_ptr<IType[]> ref_output_dbias = std::make_unique<IType[]>(H);
CType ref_amax;
compute_ref_cast_dbias_dgelu(input.rowwise_cpu_dptr<IType>(),
grad.rowwise_cpu_dptr<IType>(),
output_c.scale(),
ref_output_c.get(),
&ref_amax,
ref_output_dbias.get(),
N, H);
Tensor workspace;
nvte_quantize_dbias_dgelu(grad.data(),
input.data(),
output_c.data(),
dbias.data(),
workspace.data(),
0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_quantize_dbias_dgelu(grad.data(),
input.data(),
output_c.data(),
dbias.data(),
workspace.data(),
0);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol);
auto [atol_dbias, rtol_dbias] = getTolerances(itype);
rtol_dbias *= 4;
compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
}
std::vector<std::vector<size_t>> test_cases = {
{128, 128},
{256, 256},
{768, 1024},
{256, 65536},
{2048, 12288},
{65536, 128},
{65536, 160},
{16384, 1616},
{1, 128},
{1, 1296},
{1, 16},
{5, 160},
{5, 4, 3, 160},
{217, 256},
};
}  // namespace
class CastDBiasDGeluTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
transformer_engine::DType,
std::vector<size_t>>> {};
TEST_P(CastDBiasDGeluTestSuite, TestCastDBiasDgelu) {
using namespace transformer_engine;
using namespace test;
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>(size);
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
CastDBiasDGeluTestSuite,
::testing::Combine(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<CastDBiasDGeluTestSuite::ParamType>& info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param));
const auto& shape = std::get<2>(info.param);
for ( const auto& s: shape) {
name += "X" + std::to_string(s);
}
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <omp.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/transpose.h>
#include "../test_common.h"
using namespace transformer_engine;
using namespace test;
namespace {
template <typename IType, typename OType>
void compute_ref_cast_dgated_swiglu(const IType * const grad,
const IType * const input,
const float scale,
OType * const output,
float * const amax_ptr,
const size_t rows,
const size_t cols) {
float amax = 0;
const size_t stride = cols * 2;
#pragma omp parallel for reduction(max: amax) proc_bind(spread)
for (size_t i = 0; i < rows; i++) {
for (size_t j = 0; j < cols; j++) {
float grad_elt = static_cast<float>(grad[i * cols + j]);
float silu_elt = static_cast<float>(input[i * stride + j]);
float gate_elt = static_cast<float>(input[i * stride + cols + j]);
float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt;
float after_dgate = grad_elt * silu(silu_elt);
if (abs(after_dsilu) > amax) { amax = abs(after_dsilu); }
if (abs(after_dgate) > amax) { amax = abs(after_dgate); }
output[i * stride + j] = static_cast<OType>(scale * after_dsilu);
output[i * stride + cols + j] = static_cast<OType>(scale * after_dgate);
}
}
*amax_ptr = amax;
}
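// Added summary of the gated-SiLU (SwiGLU) backward implemented above: for
// incoming gradient g, activation input x and gate input b,
//   d_x = g * b * dsilu(x)   and   d_b = g * silu(x),
// with both halves written back into the interleaved [x | b] layout and
// contributing to a single shared amax.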
template <typename IType, typename OType>
void performTest(const std::vector<size_t>& shape) {
using namespace test;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
std::vector<size_t> input_shape = shape;
input_shape[input_shape.size() - 1] *= 2;
const size_t input_size = product(input_shape);
const size_t rows = first_dimension(shape);
const size_t cols = last_dimension(shape);
Tensor grad("grad", shape, itype);
Tensor input("input", input_shape, itype);
Tensor output_c("output_c", input_shape, otype);
fillUniform(&grad);
fillUniform(&input);
setRandomScale(&output_c);
std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(input_size);
nvte_dswiglu(grad.data(), input.data(), output_c.data(), 0);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
float ref_amax;
compute_ref_cast_dgated_swiglu(grad.rowwise_cpu_dptr<IType>(),
input.rowwise_cpu_dptr<IType>(),
output_c.scale(),
ref_output_c.get(),
&ref_amax,
rows,
cols);
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol);
}
std::vector<std::vector<size_t>> test_cases = {
{128, 128},
{256, 256},
{768, 1024},
{256, 65536},
{2048, 12288},
{65536, 128},
{217, 256},
{1296},
{5, 4, 3, 160},
};
} // namespace
class CastSwiGLUTestSuite
: public ::testing::TestWithParam<std::tuple<
transformer_engine::DType, transformer_engine::DType, std::vector<size_t>>> {};
TEST_P(CastSwiGLUTestSuite, TestCastSwiGLU) {
using namespace transformer_engine;
using namespace test;
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
if (size.back() % 32 != 0) {
GTEST_SKIP();
}
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
output_type, OutputType, performTest<InputType, OutputType>(size);););
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest, CastSwiGLUTestSuite,
::testing::Combine(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<CastSwiGLUTestSuite::ParamType> &info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param));
const auto& shape = std::get<2>(info.param);
for ( const auto& s: shape) {
name += "X" + std::to_string(s);
}
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/activation.h>
#include "../test_common.h"
#include "transformer_engine/transformer_engine.h"
using namespace transformer_engine;
using namespace test;
namespace {
enum ProcessingMethod {
CAST_ONLY,
CAST_DBIAS,
CAST_DBIAS_DACT,
CAST_DACT,
CAST_ACT
};
enum ActivationType {
Identity,
GeLU,
SiLU,
ReLU,
QGeLU,
SReLU
};
template <typename InputType, typename OutputType, float (*OP)(const float)>
void scale_block(const ProcessingMethod processing_method,
const InputType* input,
const InputType* grad,
OutputType* output_c,
float* dbias,
fp8e8m0* output_scales,
const size_t scale_idx,
const size_t i_min,
const size_t i_max,
const size_t j_min,
const size_t j_max,
const size_t cols) {
float amax = 0.0f;
// Find the absolute maximum value in the block
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j;
float elt = static_cast<float>(input[idx]);
if (processing_method == ProcessingMethod::CAST_DBIAS) {
// grad is the input
elt = static_cast<float>(grad[idx]);
}
if (processing_method != ProcessingMethod::CAST_ONLY
&& processing_method != ProcessingMethod::CAST_DBIAS) {
elt = OP(elt);
}
if (processing_method == ProcessingMethod::CAST_DACT ||
processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
elt *= static_cast<float>(grad[idx]);
}
dbias[j] += elt;
if (isinf(elt) || isnan(elt)) {
continue;
}
amax = std::max(amax, std::abs(elt));
}
}
const fp8e8m0 biased_exponent = float_to_e8m0(amax * Quantized_Limits<OutputType>::max_reciprocal());
const float scale_reciprocal = exp2f_rcp(biased_exponent);
output_scales[scale_idx] = biased_exponent;
// Quantize elements in the block
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j;
float elt = static_cast<float>(input[idx]);
if (processing_method == ProcessingMethod::CAST_DBIAS) {
// grad is the input
elt = static_cast<float>(grad[idx]);
}
if (processing_method != ProcessingMethod::CAST_ONLY
&& processing_method != ProcessingMethod::CAST_DBIAS) {
elt = OP(elt);
}
if (processing_method == ProcessingMethod::CAST_DACT ||
processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
elt *= static_cast<float>(grad[idx]);
}
output_c[idx] = static_cast<OutputType>(elt * scale_reciprocal);
}
}
}
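// Worked E8M0 example (illustrative; assumes the bias-127 power-of-two
// encoding of the OCP MX spec): for an E4M3 output block with amax = 112,
// amax * max_reciprocal() = 112 / 448 = 0.25 = 2^-2, which is exactly
// representable (so the rounding mode does not matter here). The biased
// exponent stored in output_scales is -2 + 127 = 125, the quantization
// multiplier is scale_reciprocal = 2^2 = 4, and the largest quantized
// magnitude becomes 112 * 4 = 448, the E4M3 max-normal value. At
// dequantization the stored factor decodes back to 2^(125-127) = 0.25.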
template <typename InputType, typename OutputType, float (*OP)(const float)>
void compute_ref_x1(const ProcessingMethod processing_method,
const InputType* input,
const InputType* grad,
OutputType* output_c,
fp8e8m0* output_scales,
InputType* output_dbias,
const size_t rows,
const size_t cols,
const size_t block_size_Y,
const size_t block_size_X,
const size_t scales_stride)
{
std::vector<float> output_dbias_fp32(cols, 0);
const size_t blocks_Y = (rows + block_size_Y - 1) / block_size_Y;
const size_t blocks_X = (cols + block_size_X - 1) / block_size_X;
for (size_t ii = 0; ii < blocks_Y; ++ii) {
const size_t i_min = ii * block_size_Y;
const size_t i_max = std::min((ii + 1) * block_size_Y, rows);
for (size_t jj = 0; jj < blocks_X; ++jj) {
const size_t j_min = jj * block_size_X;
const size_t j_max = std::min((jj + 1) * block_size_X, cols);
const size_t scale_idx = ii * scales_stride + jj;
scale_block<InputType, OutputType, OP>(
processing_method, input, grad, output_c, output_dbias_fp32.data(),
output_scales, scale_idx, i_min, i_max, j_min, j_max, cols);
}
}
for (size_t j = 0; j < cols; ++j) {
output_dbias[j] = static_cast<InputType>(output_dbias_fp32[j]);
}
}
template <typename InputType, typename OutputType, float (*OP)(const float)>
void compute_ref_x2(const ProcessingMethod processing_method,
const InputType* input,
const InputType* grad,
OutputType* output_rowwise,
OutputType* output_colwise,
fp8e8m0* scales_rowwise,
fp8e8m0* scales_colwise,
InputType* output_dbias,
const size_t rows,
const size_t cols,
const size_t block_size_Y,
const size_t block_size_X,
const size_t scales_stride_rowwise,
const size_t scales_stride_colwise) {
compute_ref_x1<InputType, OutputType, OP>(
processing_method, input, grad, output_rowwise, scales_rowwise, output_dbias,
rows, cols, 1, block_size_X, scales_stride_rowwise);
compute_ref_x1<InputType, OutputType, OP>(
processing_method, input, grad, output_colwise, scales_colwise, output_dbias,
rows, cols, block_size_Y, 1, scales_stride_colwise);
}
/**
* Scaling along single dimension (either rows or columns)
* Produces one set of output data and the corresponding data of the fused operation (dbias):
* 1) Scaled rows + row-wise scaling factors
* OR
* 2) Scaled columns + column-wise scaling factors
*/
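// Example (added for clarity): for shape {128, 128} with rowwise scaling the
// blocks are 1x32, so the reference produces a 128x4 grid of E8M0 scaling
// factors (possibly padded to the GPU scale-tensor layout, hence the separate
// scales_stride), one factor per 32 consecutive elements of a row.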
template <typename InputType, typename OutputType, float (*OP)(const float)>
void performTest_x1(const ProcessingMethod processing_method,
const std::vector<size_t>& shape,
const bool rowwise,
const bool colwise,
InputsFillCase fill_case) {
using namespace test;
using EncodingType = fp32;
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
const size_t rows = first_dimension(shape);
const size_t cols = last_dimension(shape);
if (shape.size() < 2 && colwise) {
GTEST_SKIP();
}
const size_t block_size_rows = rowwise ? 1 : 32;
const size_t block_size_cols = colwise ? 1 : 32;
const std::array<size_t,4> scale_dims = get_scale_tensor_dims(rows, cols, block_size_rows,
block_size_cols);
const size_t unpadded_blocks_Y = scale_dims[0];
const size_t unpadded_blocks_X = scale_dims[1];
const size_t blocks_Y = scale_dims[2];
const size_t blocks_X = scale_dims[3];
const size_t scales_stride = blocks_X;
Tensor input("input", shape, itype);
Tensor grad("grad", shape, itype);
Tensor output_c("output_c", shape, otype, rowwise, colwise, NVTE_MXFP8_1D_SCALING);
Tensor output_dbias("output_dbias", { cols }, itype);
std::unique_ptr<OutputType[]> ref_output_c = std::make_unique<OutputType[]>(rows * cols);
std::unique_ptr<InputType[]> ref_output_dbias = std::make_unique<InputType[]>(cols);
std::unique_ptr<fp8e8m0[]> ref_output_scales = std::make_unique<fp8e8m0[]>(blocks_Y * blocks_X);
fillCase<EncodingType>(&input, fill_case);
fillUniform(&grad);
Tensor workspace;
switch (processing_method) {
case ProcessingMethod::CAST_ONLY: {
nvte_quantize(input.data(), output_c.data(), 0);
break;
}
case ProcessingMethod::CAST_DBIAS: {
nvte_quantize_dbias(grad.data(),
output_c.data(),
output_dbias.data(),
workspace.data(),
0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_quantize_dbias(grad.data(),
output_c.data(),
output_dbias.data(),
workspace.data(),
0);
break;
}
case ProcessingMethod::CAST_DBIAS_DACT: {
nvte_quantize_dbias_dgelu(grad.data(),
input.data(),
output_c.data(),
output_dbias.data(),
workspace.data(),
0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_quantize_dbias_dgelu(grad.data(),
input.data(),
output_c.data(),
output_dbias.data(),
workspace.data(),
0);
break;
}
case ProcessingMethod::CAST_DACT: {
nvte_dgelu(grad.data(), input.data(), output_c.data(), 0);
break;
}
case ProcessingMethod::CAST_ACT: {
nvte_gelu(input.data(), output_c.data(), 0);
break;
}
}
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
compute_ref_x1<InputType, OutputType, OP>(processing_method,
input.rowwise_cpu_dptr<InputType>(),
grad.rowwise_cpu_dptr<InputType>(),
ref_output_c.get(),
ref_output_scales.get(),
ref_output_dbias.get(),
rows,
cols,
block_size_rows,
block_size_cols,
scales_stride);
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol);
const uint8_t * const gpu_scales_ptr = rowwise
? output_c.rowwise_cpu_scale_inv_ptr<fp8e8m0>()
: output_c.columnwise_cpu_scale_inv_ptr<fp8e8m0>();
compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride);
if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
auto [atol_dbias, rtol_dbias] = getTolerances(itype);
if (itype == DType::kFloat32) {
atol_dbias = 1e-4;
            rtol_dbias *= sqrt(static_cast<double>(rows));
} else {
rtol_dbias *= 4;
}
compareResults("output_dbias", output_dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
}
}
/**
* Scaling along both dimensions (rows and columns)
* Produces two sets of scaled output data and the corresponding data of the fused operation (dbias):
* 1) Scaled rows + row-wise scaling factors
* AND
* 2) Scaled columns + column-wise scaling factors
*/
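// Added note: the x2 path quantizes the same input twice, with 1x32 blocks
// for the rowwise copy and 32x1 blocks for the colwise copy; each copy is
// then checked against its own reference output and E8M0 scaling factors.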
template <typename InputType, typename OutputType, float (*OP)(const float)>
void performTest_x2(const ProcessingMethod processing_method,
const std::vector<size_t>& shape,
const size_t block_size_rows,
const size_t block_size_cols,
InputsFillCase fill_case) {
using namespace test;
using EncodingType = fp32;
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
if (shape.size() < 2) {
GTEST_SKIP();
}
const size_t rows = first_dimension(shape);
const size_t cols = last_dimension(shape);
const std::array<size_t,4> scale_dims_rowwise = get_scale_tensor_dims(rows, cols, 1, 32);
const std::array<size_t,4> scale_dims_colwise = get_scale_tensor_dims(rows, cols, 32, 1);
const size_t unpadded_blocks_Y_rowwise = scale_dims_rowwise[0];
const size_t unpadded_blocks_X_rowwise = scale_dims_rowwise[1];
const size_t blocks_Y_rowwise = scale_dims_rowwise[2];
const size_t blocks_X_rowwise = scale_dims_rowwise[3];
const size_t scales_stride_rowwise = blocks_X_rowwise;
const size_t unpadded_blocks_Y_colwise = scale_dims_colwise[0];
const size_t unpadded_blocks_X_colwise = scale_dims_colwise[1];
const size_t blocks_Y_colwise = scale_dims_colwise[2];
const size_t blocks_X_colwise = scale_dims_colwise[3];
const size_t scales_stride_colwise = blocks_X_colwise;
Tensor input("input", shape, itype);
Tensor grad("grad", shape, itype);
Tensor output("output", shape, otype, true, true, NVTE_MXFP8_1D_SCALING);
Tensor output_dbias("output_dbias", { cols }, itype);
std::unique_ptr<OutputType[]> ref_output_c_rowwise = std::make_unique<OutputType[]>(rows * cols);
std::unique_ptr<OutputType[]> ref_output_c_colwise = std::make_unique<OutputType[]>(rows * cols);
std::unique_ptr<fp8e8m0[]> ref_scales_rowwise = std::make_unique<fp8e8m0[]>(blocks_Y_rowwise * blocks_X_rowwise);
std::unique_ptr<fp8e8m0[]> ref_scales_colwise = std::make_unique<fp8e8m0[]>(blocks_Y_colwise * blocks_X_colwise);
std::unique_ptr<InputType[]> ref_output_dbias = std::make_unique<InputType[]>(cols);
fillCase<EncodingType>(&input, fill_case);
fillUniform(&grad);
Tensor workspace;
switch (processing_method) {
case ProcessingMethod::CAST_ONLY: {
nvte_quantize(input.data(), output.data(), 0);
break;
}
case ProcessingMethod::CAST_DBIAS: {
nvte_quantize_dbias(grad.data(),
output.data(),
output_dbias.data(),
workspace.data(),
0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_quantize_dbias(grad.data(),
output.data(),
output_dbias.data(),
workspace.data(),
0);
break;
}
case ProcessingMethod::CAST_DBIAS_DACT: {
nvte_quantize_dbias_dgelu(grad.data(),
input.data(),
output.data(),
output_dbias.data(),
workspace.data(),
0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_quantize_dbias_dgelu(grad.data(),
input.data(),
output.data(),
output_dbias.data(),
workspace.data(),
0);
break;
}
case ProcessingMethod::CAST_DACT: {
nvte_dgelu(grad.data(), input.data(), output.data(), 0);
break;
}
case ProcessingMethod::CAST_ACT: {
nvte_gelu(input.data(), output.data(), 0);
break;
}
}
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
compute_ref_x2<InputType, OutputType, OP>(processing_method,
input.rowwise_cpu_dptr<InputType>(),
grad.rowwise_cpu_dptr<InputType>(),
ref_output_c_rowwise.get(),
ref_output_c_colwise.get(),
ref_scales_rowwise.get(),
ref_scales_colwise.get(),
ref_output_dbias.get(),
rows,
cols,
block_size_rows,
block_size_cols,
scales_stride_rowwise,
scales_stride_colwise);
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol);
compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol);
compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
unpadded_blocks_X_rowwise, scales_stride_rowwise);
compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
unpadded_blocks_X_colwise, scales_stride_colwise);
if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
auto [atol_dbias, rtol_dbias] = getTolerances(itype);
if (itype == DType::kFloat32) {
atol_dbias = 1e-4;
            rtol_dbias *= sqrt(static_cast<double>(rows));
} else {
rtol_dbias *= 4;
}
compareResults("output_dbias", output_dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
}
}
std::vector<std::vector<size_t>> matrix_sizes = {
{1, 16},
{16, 48},
{65, 96},
{128, 128},
{256, 256},
{993, 512},
{256, 65536},
{2048, 6144},
{16384, 128},
{32768, 160},
{4096, 1632},
{1024},
{8, 32, 1024},
{16, 8, 4, 512},
};
std::vector<std::pair<size_t, size_t>> block_sizes = {
{1, 32},
{32, 1},
{32, 32},
};
std::vector<InputsFillCase> input_scenarios = {
InputsFillCase::uniform,
// InputsFillCase::zeros,
// InputsFillCase::zero_to_minNorm,
// InputsFillCase::minNorm_to_maxNorm,
// InputsFillCase::maxNorm_to_inf
};
std::vector<ProcessingMethod> processing_methods = {
ProcessingMethod::CAST_ONLY,
ProcessingMethod::CAST_DBIAS,
ProcessingMethod::CAST_DBIAS_DACT,
ProcessingMethod::CAST_DACT,
ProcessingMethod::CAST_ACT,
};
// Only GeLU activation tests are supported
std::vector<ActivationType> Activation_types = {
ActivationType::Identity,
ActivationType::GeLU,
// ActivationType::SiLU,
// ActivationType::ReLU,
// ActivationType::QGeLU,
// ActivationType::SReLU,
};
} // namespace
class FusedCastMXFP8TestSuite : public ::testing::TestWithParam
<std::tuple<ProcessingMethod,
ActivationType,
std::vector<size_t>,
std::pair<size_t, size_t>,
transformer_engine::DType,
transformer_engine::DType,
InputsFillCase>> {};
#define DACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \
switch (OP_FUNC_TYPE) { \
case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \
case ActivationType::GeLU: { constexpr auto OP = &dgelu; { __VA_ARGS__ } } break; \
case ActivationType::SiLU: { constexpr auto OP = &dsilu; { __VA_ARGS__ } } break; \
case ActivationType::ReLU: { constexpr auto OP = &drelu; { __VA_ARGS__ } } break; \
case ActivationType::QGeLU: { constexpr auto OP = &dqgelu; { __VA_ARGS__ } } break; \
case ActivationType::SReLU: { constexpr auto OP = &dsrelu; { __VA_ARGS__ } } break; \
}
#define ACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \
switch (OP_FUNC_TYPE) { \
case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \
case ActivationType::GeLU: { constexpr auto OP = &gelu; { __VA_ARGS__ } } break; \
case ActivationType::SiLU: { constexpr auto OP = &silu; { __VA_ARGS__ } } break; \
case ActivationType::ReLU: { constexpr auto OP = &relu; { __VA_ARGS__ } } break; \
case ActivationType::QGeLU: { constexpr auto OP = &qgelu; { __VA_ARGS__ } } break; \
case ActivationType::SReLU: { constexpr auto OP = &srelu; { __VA_ARGS__ } } break; \
}
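// Added note: these switches bind the runtime ActivationType to a
// compile-time function pointer (OP), so each activation instantiates its own
// reference-computation templates.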
TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) {
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
using namespace transformer_engine;
using namespace test;
const ProcessingMethod processing_method = std::get<0>(GetParam());
const ActivationType Act_type = std::get<1>(GetParam());
const auto matrix_size = std::get<2>(GetParam());
const auto block_size = std::get<3>(GetParam());
const DType input_type = std::get<4>(GetParam());
const DType output_type = std::get<5>(GetParam());
const InputsFillCase fill_case = std::get<6>(GetParam());
    // Skip the non-activation tests (CAST_ONLY, CAST_DBIAS) when the activation type is not Identity
if ((processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS)
&& Act_type != ActivationType::Identity) {
GTEST_SKIP();
}
    // Skip the activation tests when the activation type is Identity
if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT
|| processing_method == ProcessingMethod::CAST_DACT
|| processing_method == ProcessingMethod::CAST_ACT) && (Act_type == ActivationType::Identity)) {
GTEST_SKIP();
}
const bool rowwise = block_size.second != 1;
const bool colwise = block_size.first != 1;
if (processing_method == ProcessingMethod::CAST_ACT) {
// Forward activations
ACT_FUNC_SWITCH(Act_type, OP,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
if (block_size.first == 1 || block_size.second == 1) {
performTest_x1<InputType, OutputType, OP>(
processing_method, matrix_size,
rowwise, colwise, fill_case);
} else {
performTest_x2<InputType, OutputType, OP>(
processing_method, matrix_size,
block_size.first, block_size.second, fill_case);
}
);
);
);
} else {
DACT_FUNC_SWITCH(Act_type, OP,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
if (block_size.first == 1 || block_size.second == 1) {
performTest_x1<InputType, OutputType, OP>(
processing_method, matrix_size,
rowwise, colwise, fill_case);
} else {
performTest_x2<InputType, OutputType, OP>(
processing_method, matrix_size,
block_size.first, block_size.second, fill_case);
}
);
);
);
}
}
std::string to_string(const ProcessingMethod method) {
switch (method) {
case ProcessingMethod::CAST_ONLY: return "CAST_ONLY";
case ProcessingMethod::CAST_DBIAS: return "CAST_DBIAS";
case ProcessingMethod::CAST_DBIAS_DACT: return "CAST_DBIAS_DACT";
case ProcessingMethod::CAST_DACT: return "CAST_DACT";
case ProcessingMethod::CAST_ACT: return "CAST_ACT";
default: return "";
}
}
std::string to_string(const ActivationType Act_type) {
switch (Act_type) {
case ActivationType::Identity: return "Identity";
case ActivationType::GeLU: return "GeLU";
case ActivationType::SiLU: return "SiLU";
case ActivationType::ReLU: return "ReLU";
case ActivationType::QGeLU: return "QGeLU";
case ActivationType::SReLU: return "SReLU";
default: return "";
}
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
FusedCastMXFP8TestSuite,
::testing::Combine(
::testing::ValuesIn(processing_methods),
::testing::ValuesIn(Activation_types),
::testing::ValuesIn(matrix_sizes),
::testing::ValuesIn(block_sizes),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(input_scenarios)),
[](const testing::TestParamInfo<FusedCastMXFP8TestSuite::ParamType>& info) {
std::string name = to_string(std::get<0>(info.param)) + "X" +
to_string(std::get<1>(info.param));
const auto& shape = std::get<2>(info.param);
    for (const auto& s : shape) {
name += "X" + std::to_string(s);
}
name += "X" + std::to_string(std::get<3>(info.param).first) +
"X" + std::to_string(std::get<3>(info.param).second) +
"X" + test::typeName(std::get<4>(info.param)) +
"X" + test::typeName(std::get<5>(info.param)) +
"X" + test::caseName(std::get<6>(info.param));
return name;
});
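// The name generator above joins the parameters with "X"; e.g. the tuple
// (CAST_ACT, GeLU, {128, 128}, {1, 32}, FP32 in, E4M3 out, uniform fill) yields
// a name of the form "CAST_ACTXGeLUX128X128X1X32X..." with the dtype and fill
// suffixes supplied by test::typeName() and test::caseName().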
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/activation.h>
#include "../test_common.h"
#include "transformer_engine/transformer_engine.h"
using namespace transformer_engine;
using namespace test;
namespace {
template <bool IS_DGATED, typename IType, typename OType>
void scale_block(const IType* grad,
const IType* input,
OType* output,
fp8e8m0* output_scales,
const size_t scale_idx,
const size_t scale_idx_gate,
float& thread_amax,
const size_t i_min,
const size_t i_max,
const size_t j_min,
const size_t j_max,
const size_t cols) {
float block_amax = 0.0f;
float block_amax_gate = 0.0f;
const size_t stride = cols * 2;
// Find the absolute maximum value in the block
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
float silu_elt = static_cast<float>(input[i * stride + j]);
float gate_elt = static_cast<float>(input[i * stride + cols + j]);
float gated_amax_act = 0;
float gated_amax_gate = 0;
if constexpr (IS_DGATED) {
const float grad_elt = static_cast<float>(grad[i * cols + j]);
const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt;
const float after_dgate = silu(silu_elt) * grad_elt;
        gated_amax_act = std::abs(after_dsilu);
        gated_amax_gate = std::abs(after_dgate);
} else {
const float after_silu = silu(silu_elt) * gate_elt;
        gated_amax_act = std::abs(after_silu);
}
if (gated_amax_act > block_amax) { block_amax = gated_amax_act; }
if (gated_amax_gate > block_amax_gate) { block_amax_gate = gated_amax_gate; }
}
}
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax *
Quantized_Limits<OType>::max_reciprocal());
const float scale_reciprocal = exp2f_rcp(biased_exponent);
output_scales[scale_idx] = biased_exponent;
float scale_reciprocal_gate = 1;
    if constexpr (IS_DGATED) {
        // Renamed to avoid shadowing the activation block's biased_exponent above
        const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate *
                                                           Quantized_Limits<OType>::max_reciprocal());
        scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate);
        output_scales[scale_idx_gate] = biased_exponent_gate;
    }
// Quantize elements in the block
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
float silu_elt = static_cast<float>(input[i * stride + j]);
float gate_elt = static_cast<float>(input[i * stride + cols + j]);
if constexpr (IS_DGATED) {
const float grad_elt = static_cast<float>(grad[i * cols + j]);
const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt;
const float after_dgate = silu(silu_elt) * grad_elt;
output[i * stride + j] = static_cast<OType>(after_dsilu * scale_reciprocal);
output[i * stride + cols + j] = static_cast<OType>(after_dgate *
scale_reciprocal_gate);
} else {
const float after_silu = silu(silu_elt) * gate_elt;
output[i * cols + j] = static_cast<OType>(after_silu * scale_reciprocal);
}
}
}
thread_amax = std::max(thread_amax, block_amax);
thread_amax = std::max(thread_amax, block_amax_gate);
}
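// Worked example (illustrative numbers): with an FP8E4M3 output (max magnitude
// 448) and block_amax == 3.5, block_amax * max_reciprocal() == 3.5 / 448 == 2^-7,
// so float_to_e8m0 stores the biased exponent encoding 2^-7 and exp2f_rcp
// returns the quantization multiplier 2^7 == 128, mapping the block amax onto
// the representable FP8 range.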
template <bool IS_DGATED, typename IType, typename OType>
void compute_ref_x1(const IType* grad,
const IType* input,
OType* output,
fp8e8m0* output_scales,
float& ref_amax,
const size_t rows,
const size_t cols,
const size_t block_size_Y,
const size_t block_size_X,
const size_t scales_stride) {
  const size_t tile_size_Y = std::max<size_t>(32, block_size_Y);
  const size_t tile_size_X = std::max<size_t>(64, block_size_X);
const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y;
const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X;
const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y;
const size_t blocks_per_tile_X = tile_size_X / block_size_X;
float amax = 0;
#pragma omp parallel reduction(max: amax) proc_bind(spread)
{
float thread_amax = 0;
#pragma omp for schedule(static)
for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) {
const size_t tile_Y = t / tiles_num_X;
const size_t tile_X = t % tiles_num_X;
const size_t tile_offset_Y = tile_Y * tile_size_Y;
const size_t tile_offset_X = tile_X * tile_size_X;
for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) {
const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii;
const size_t block_offset_Y = ii * block_size_Y;
const size_t i_min = tile_offset_Y + block_offset_Y;
if (i_min >= rows) continue;
const size_t i_max = std::min(i_min + block_size_Y, rows);
for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) {
const size_t block_idx_X = tile_X * blocks_per_tile_X + jj;
const size_t block_offset_X = jj * block_size_X;
const size_t j_min = tile_offset_X + block_offset_X;
if (j_min >= cols) continue;
const size_t j_max = std::min(j_min + block_size_X, cols);
const size_t mx_scale_idx = block_idx_Y * scales_stride + block_idx_X;
const size_t mx_scale_idx_gate = block_idx_Y * scales_stride + block_idx_X +
cols / block_size_X;
scale_block<IS_DGATED, IType, OType>(
grad, input, output, output_scales, mx_scale_idx, mx_scale_idx_gate,
thread_amax, i_min, i_max, j_min, j_max, cols);
}
}
}
if (thread_amax > amax) {
amax = thread_amax;
}
}
ref_amax = amax;
}
template <bool IS_DGATED, typename IType, typename OType>
void compute_ref_x2(const IType* grad,
const IType* input,
OType* output_rowwise,
OType* output_colwise,
fp8e8m0* scales_rowwise,
fp8e8m0* scales_colwise,
float& ref_amax,
const size_t rows,
const size_t cols,
const size_t block_size_Y,
const size_t block_size_X,
const size_t scales_stride_rowwise,
const size_t scales_stride_colwise) {
compute_ref_x1<IS_DGATED, IType, OType>(
grad, input, output_rowwise, scales_rowwise, ref_amax, rows, cols, 1, block_size_X, scales_stride_rowwise);
compute_ref_x1<IS_DGATED, IType, OType>(
grad, input, output_colwise, scales_colwise, ref_amax, rows, cols, block_size_Y, 1, scales_stride_colwise);
}
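// Note: compute_ref_x2 reuses compute_ref_x1 with 1xX blocks for the row-wise
// pass and Yx1 blocks for the column-wise pass; both passes scan the same
// elements, so the second call overwrites ref_amax with an identical value.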
/**
* Scaling along single dimension (either rows or columns)
* Produces one set of output data and the corresponding data of the fused operation (dbias):
* 1) Scaled rows + row-wise scaling factors
* OR
* 2) Scaled columns + column-wise scaling factors
*/
template <bool IS_DGATED, typename IType, typename OType>
void performTest_x1(const size_t rows,
const size_t cols,
const size_t block_size_rows,
const size_t block_size_cols,
InputsFillCase fill_case) {
using namespace test;
using EncodingType = fp32;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
const bool rowwise = (block_size_rows == 1) && (block_size_cols == 32);
const bool colwise = (block_size_rows == 32) && (block_size_cols == 1);
NVTE_CHECK(rowwise || colwise);
Tensor grad("grad", { rows, cols }, itype);
Tensor input("input", { rows, cols * 2 }, itype);
const size_t output_cols = (IS_DGATED ? 2 : 1) * cols;
const std::array<size_t,4> scale_dims = get_scale_tensor_dims(rows, output_cols, block_size_rows,
block_size_cols);
const size_t unpadded_blocks_Y = scale_dims[0];
const size_t unpadded_blocks_X = scale_dims[1];
const size_t blocks_Y = scale_dims[2];
const size_t blocks_X = scale_dims[3];
const size_t scales_stride = blocks_X;
Tensor output("output", std::vector<size_t>{ rows, output_cols }, otype,
rowwise, colwise, NVTE_MXFP8_1D_SCALING);
std::unique_ptr<OType[]> ref_output = std::make_unique<OType[]>(rows * output_cols);
std::unique_ptr<fp8e8m0[]> ref_output_scales = std::make_unique<fp8e8m0[]>(blocks_Y * blocks_X);
for (size_t i = 0; i < blocks_Y * blocks_X; ++i) {
ref_output_scales[i] = 0;
}
// fillCase<EncodingType>(&grad, fill_case);
if constexpr (IS_DGATED) {
fillUniform(&grad);
}
fillUniform(&input);
if constexpr (IS_DGATED) {
nvte_dswiglu(grad.data(), input.data(), output.data(), 0);
} else {
nvte_swiglu(input.data(), output.data(), 0);
}
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
float ref_amax = 0;
compute_ref_x1<IS_DGATED, IType, OType>(grad.rowwise_cpu_dptr<IType>(),
input.rowwise_cpu_dptr<IType>(),
ref_output.get(),
ref_output_scales.get(),
ref_amax,
rows,
cols,
block_size_rows,
block_size_cols,
scales_stride);
auto [atol, rtol] = getTolerances(otype);
compareResults("output", output, ref_output.get(), rowwise, atol, rtol);
const uint8_t * const gpu_scales_ptr = rowwise
? output.rowwise_cpu_scale_inv_ptr<fp8e8m0>()
: output.columnwise_cpu_scale_inv_ptr<fp8e8m0>();
if (rowwise) {
compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride);
} else {
compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride);
}
}
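// Illustrative call (type aliases as used elsewhere in these tests): row-wise
// 1x32 MXFP8 scaling of the dSwiGLU backward path would be exercised by
//   performTest_x1<true, fp32, fp8e4m3>(128, 128, 1, 32, InputsFillCase::uniform);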
/**
* Scaling along both dimensions (rows and columns)
* Produces two sets of scaled output data and the corresponding data of the fused operation (dbias):
* 1) Scaled rows + row-wise scaling factors
* AND
* 2) Scaled columns + column-wise scaling factors
*/
template <bool IS_DGATED, typename IType, typename OType>
void performTest_x2(const size_t rows,
const size_t cols,
const size_t block_size_rows,
const size_t block_size_cols,
InputsFillCase fill_case) {
using namespace test;
using EncodingType = fp32;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
Tensor grad("grad", { rows, cols }, itype);
Tensor input("input", { rows, cols * 2 }, itype);
const size_t output_cols = (IS_DGATED ? 2 : 1) * cols;
const std::array<size_t,4> scale_dims_rowwise = get_scale_tensor_dims(rows, output_cols, 1, 32);
const std::array<size_t,4> scale_dims_colwise = get_scale_tensor_dims(rows, output_cols, 32, 1);
const size_t unpadded_blocks_Y_rowwise = scale_dims_rowwise[0];
const size_t unpadded_blocks_X_rowwise = scale_dims_rowwise[1];
const size_t blocks_Y_rowwise = scale_dims_rowwise[2];
const size_t blocks_X_rowwise = scale_dims_rowwise[3];
const size_t scales_stride_rowwise = blocks_X_rowwise;
const size_t unpadded_blocks_Y_colwise = scale_dims_colwise[0];
const size_t unpadded_blocks_X_colwise = scale_dims_colwise[1];
const size_t blocks_Y_colwise = scale_dims_colwise[2];
const size_t blocks_X_colwise = scale_dims_colwise[3];
const size_t scales_stride_colwise = blocks_X_colwise;
Tensor output("output", std::vector<size_t>{ rows, output_cols }, otype,
true, true, NVTE_MXFP8_1D_SCALING);
std::unique_ptr<OType[]> ref_output_rowwise = std::make_unique<OType[]>(rows * output_cols);
std::unique_ptr<OType[]> ref_output_colwise = std::make_unique<OType[]>(rows * output_cols);
std::unique_ptr<fp8e8m0[]> ref_scales_rowwise = std::make_unique<fp8e8m0[]>(blocks_Y_rowwise * blocks_X_rowwise);
std::unique_ptr<fp8e8m0[]> ref_scales_colwise = std::make_unique<fp8e8m0[]>(blocks_Y_colwise * blocks_X_colwise);
for (size_t i = 0; i < blocks_Y_rowwise * blocks_X_rowwise; ++i) {
ref_scales_rowwise[i] = 0;
}
for (size_t i = 0; i < blocks_Y_colwise * blocks_X_colwise; ++i) {
ref_scales_colwise[i] = 0;
}
// fillCase<EncodingType>(&grad, fill_case);
if constexpr (IS_DGATED) {
fillUniform(&grad);
}
fillUniform(&input);
if constexpr (IS_DGATED) {
nvte_dswiglu(grad.data(), input.data(), output.data(), 0);
} else {
nvte_swiglu(input.data(), output.data(), 0);
}
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
float ref_amax = 0;
compute_ref_x2<IS_DGATED, IType, OType>(grad.rowwise_cpu_dptr<IType>(),
input.rowwise_cpu_dptr<IType>(),
ref_output_rowwise.get(),
ref_output_colwise.get(),
ref_scales_rowwise.get(),
ref_scales_colwise.get(),
ref_amax,
rows,
cols,
block_size_rows,
block_size_cols,
scales_stride_rowwise,
scales_stride_colwise);
auto [atol, rtol] = getTolerances(otype);
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol);
compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol);
compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
unpadded_blocks_X_rowwise, scales_stride_rowwise);
compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
unpadded_blocks_X_colwise, scales_stride_colwise);
}
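// Illustrative call: the combined path producing both outputs at once, e.g.
//   performTest_x2<false, fp32, fp8e4m3>(128, 128, 32, 32, InputsFillCase::uniform);
// exercises the forward SwiGLU with 1x32 row-wise and 32x1 column-wise scaling.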
std::vector<std::pair<size_t, size_t>> matrix_sizes = {
{1, 32},
{16, 64},
{65, 96},
{128, 128},
{256, 256},
{993, 512},
{768, 1024},
{65536, 128},
{16384, 1632},
};
std::vector<std::pair<size_t, size_t>> block_sizes = {
{1, 32},
{32, 1},
{32, 32},
};
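// {1, 32} exercises row-wise scaling, {32, 1} column-wise scaling, and
// {32, 32} the x2 path that produces both outputs from a single call.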
std::vector<InputsFillCase> input_scenarios = {
InputsFillCase::uniform,
// InputsFillCase::zeros,
// InputsFillCase::zero_to_minNorm,
// InputsFillCase::minNorm_to_maxNorm,
// InputsFillCase::maxNorm_to_inf
};
std::vector<bool> is_dgated_op = {
true,
false
};
} // namespace
class CastMXFP8_GatedActTestSuite : public ::testing::TestWithParam
<std::tuple<std::pair<size_t, size_t>,
std::pair<size_t, size_t>,
transformer_engine::DType,
transformer_engine::DType,
InputsFillCase,
bool>> {};
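// Parameter tuple layout (in order): matrix size, scaling block size,
// input dtype, output dtype, input fill scenario, and IS_DGATED
// (true = dgated backward op, false = gated forward op).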
TEST_P(CastMXFP8_GatedActTestSuite, TestCastMXFP8Swiglu) {
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
using namespace transformer_engine;
using namespace test;
const auto matrix_size = std::get<0>(GetParam());
const auto block_size = std::get<1>(GetParam());
const DType input_type = std::get<2>(GetParam());
const DType output_type = std::get<3>(GetParam());
const InputsFillCase fill_case = std::get<4>(GetParam());
const bool IS_DGATED = std::get<5>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OType,
if (block_size.first == 1 || block_size.second == 1) {
if (IS_DGATED) {
performTest_x1<true, IType, OType>(matrix_size.first, matrix_size.second,
block_size.first, block_size.second, fill_case);
} else {
performTest_x1<false, IType, OType>(matrix_size.first, matrix_size.second,
block_size.first, block_size.second, fill_case);
}
} else {
if (IS_DGATED) {
performTest_x2<true, IType, OType>(matrix_size.first, matrix_size.second,
block_size.first, block_size.second, fill_case);
} else {
performTest_x2<false, IType, OType>(matrix_size.first, matrix_size.second,
block_size.first, block_size.second, fill_case);
}
}
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
CastMXFP8_GatedActTestSuite,
::testing::Combine(
::testing::ValuesIn(matrix_sizes),
::testing::ValuesIn(block_sizes),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(input_scenarios),
::testing::ValuesIn(is_dgated_op)),
[](const testing::TestParamInfo<CastMXFP8_GatedActTestSuite::ParamType>& info) {
std::string name = std::to_string(std::get<0>(info.param).first) + "X" +
std::to_string(std::get<0>(info.param).second) + "X" +
std::to_string(std::get<1>(info.param).first) + "X" +
std::to_string(std::get<1>(info.param).second) + "X" +
test::typeName(std::get<2>(info.param)) + "X" +
test::typeName(std::get<3>(info.param)) + "X" +
test::caseName(std::get<4>(info.param)) + "X" +
(std::get<5>(info.param) ? "DGATED" : "GATED");
return name;
});
...@@ -14,7 +14,7 @@
 #include <cuda_runtime.h>
 #include <gtest/gtest.h>
-#include <transformer_engine/transpose.h>
+#include <transformer_engine/cast.h>
 #include "../test_common.h"
 using namespace transformer_engine;
...@@ -45,36 +45,34 @@ void performTest(const size_t N, const size_t H) {
 DType itype = TypeInfo<InputType>::dtype;
 DType otype = TypeInfo<OutputType>::dtype;
-Tensor input({ N, H }, itype);
+Tensor input("input", { N, H }, itype);
-Tensor output_c({ N, H }, otype);
+Tensor output("output", { N, H }, otype, true, true);
-Tensor output_t({ H, N }, otype);
 std::unique_ptr<OutputType[]> ref_output_c = std::make_unique<OutputType[]>(N * H);
 std::unique_ptr<OutputType[]> ref_output_t = std::make_unique<OutputType[]>(N * H);
 fillUniform(&input);
-setRandomScale(&output_c);
+setRandomScale(&output);
-output_t.shareFP8Meta(output_c);
-nvte_cast_transpose(input.data(), output_c.data(), output_t.data(), 0);
+nvte_quantize(input.data(), output.data(), 0);
 float ref_amax;
-compute_ref<InputType, OutputType>(input.cpu_dptr<InputType>(), ref_output_c.get(),
+compute_ref<InputType, OutputType>(input.rowwise_cpu_dptr<InputType>(), ref_output_c.get(),
                                    ref_output_t.get(), N, H, &ref_amax,
-                                   output_c.scale());
+                                   output.scale());
 cudaDeviceSynchronize();
 auto err = cudaGetLastError();
 ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
 if (isFp8Type(otype)) {
   auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
-  compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
+  compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
-  float ref_scale_inv = 1.f / output_c.scale();
+  float ref_scale_inv = 1.f / output.scale();
-  compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
+  compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
 }
 auto [atol, rtol] = getTolerances(otype);
-compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
+compareResults("output_c", output, ref_output_c.get(), true, atol, rtol);
-compareResults("output_t", output_t, ref_output_t.get(), atol, rtol);
+compareResults("output_t", output, ref_output_t.get(), false, atol, rtol);
 }
 std::vector<std::pair<size_t, size_t>> test_cases = {{2048, 12288},
......