Commit 544dd14b authored by Przemek Tredak

Update main branch with TE 2.0 code, update version to 2.1.0.dev0


Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
parent e5369541
......@@ -2,7 +2,6 @@
extension-pkg-whitelist=flash_attn_2_cuda,
torch,
transformer_engine_torch,
transformer_engine_paddle,
transformer_engine_jax
extension-pkg-allow-list=transformer_engine.transformer_engine_jax
......
......@@ -8,7 +8,7 @@ pip install "nltk>=3.8.2"
pip install pytest==8.2.1
: ${TE_PATH:=/opt/transformerengine}
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed'
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py
# Test without custom calls
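# Setting NVTE_CUSTOM_CALLS_RE to an empty regex disables TE's custom XLA calls,
# presumably exercising the native-JAX fallback implementations instead.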
NVTE_CUSTOM_CALLS_RE="" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
: "${TE_PATH:=/opt/transformerengine}"
pip install cpplint==1.6.0 pylint==3.3.1
if [ -z "${PYTHON_ONLY}" ]
then
cd $TE_PATH
echo "Checking common API headers"
cpplint --root transformer_engine/common/include --recursive transformer_engine/common/include
echo "Checking C++ files"
cpplint --recursive --exclude=transformer_engine/common/include --exclude=transformer_engine/build_tools/build transformer_engine/common
cpplint --recursive transformer_engine/paddle
fi
if [ -z "${CPP_ONLY}" ]
then
cd $TE_PATH
echo "Checking Python files"
pylint --recursive=y transformer_engine/common transformer_engine/paddle
fi
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -xe
pip install pytest==8.2.1
: ${TE_PATH:=/opt/transformerengine}
pytest -Wignore -v $TE_PATH/tests/paddle
pytest -Wignore -v $TE_PATH/examples/paddle/mnist
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
: "${TE_PATH:=/opt/transformerengine}"
# Install dependencies
# Note: Need to install wheel locally, since the PaddlePaddle container
# already has it installed via APT.
pip install pydantic
pip install --user wheel==0.44.0
cd $TE_PATH
pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-paddle
VERSION=`cat $TE_PATH/build_tools/VERSION.txt`
WHL_BASE="transformer_engine-${VERSION}"
# Core wheel.
NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel
python -m wheel unpack dist/*
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
python -m wheel pack ${WHL_BASE}
rm dist/*.whl
mv *.whl dist/
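# Metapackage wheel: the plain "transformer-engine" package, which (presumably) just
# declares dependencies on the renamed per-CUDA/per-framework wheels built above.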
NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel
pip install dist/*.whl --no-deps
cd transformer_engine/paddle
NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel
pip install dist/*
python $TE_PATH/tests/paddle/test_sanity_import.py
......@@ -11,11 +11,10 @@ pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py
pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py
pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py
pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
......
......@@ -8,8 +8,8 @@ set -e
pip install pytest==8.2.1
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
# pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py ### TODO Debug UB support with te.Sequential
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
: ${TE_PATH:=/opt/transformerengine}
pip install pytest==8.2.1 onnxruntime==1.19.2
# Build custom ONNX Runtime operators
export CUSTOM_ORT_OPS_PATH=$TE_PATH/tests/pytorch/custom_ort_ops
bash $CUSTOM_ORT_OPS_PATH/build.sh
# Run tests
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
......@@ -12,7 +12,14 @@ pip install pytest==8.2.1
export MAX_JOBS=4
# Iterate over Flash Attention versions
FA_versions=(2.1.1 2.3.0 2.4.1 2.5.7 2.6.3 3.0.0b1)
sm_arch=`python -c "import torch; sm = torch.cuda.get_device_capability(0); print(sm[0]*10+sm[1])"`
if [ $sm_arch -gt 90 ]
then
FA_versions=(2.7.3)
else
FA_versions=(2.1.1 2.3.0 2.4.1 2.5.7 2.7.3 3.0.0b1)
fi
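# Note: on sm_arch > 90 (Blackwell and newer) only flash-attn 2.7.3 is tested,
# presumably because older releases predate those architectures.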
for fa_version in "${FA_versions[@]}"
do
......@@ -21,10 +28,10 @@ do
then
pip install flash-attn==${fa_version}
else
pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper"
pip install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper"
python_path=`python -c "import site; print(site.getsitepackages()[0])"`
mkdir -p $python_path/flashattn_hopper
wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py
wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py
fi
# Run tests
......
......@@ -5,6 +5,7 @@
"""Installation script."""
import os
import sys
import time
from pathlib import Path
from typing import List, Tuple
......@@ -35,14 +36,13 @@ os.environ["NVTE_PROJECT_BUILDING"] = "1"
if "pytorch" in frameworks:
from torch.utils.cpp_extension import BuildExtension
elif "paddle" in frameworks:
from paddle.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
install_and_import("pybind11[global]")
from pybind11.setup_helpers import build_ext as BuildExtension
CMakeBuildExtension = get_build_ext(BuildExtension)
archs = cuda_archs()
class TimedBdist(bdist_wheel):
......@@ -57,7 +57,7 @@ class TimedBdist(bdist_wheel):
def setup_common_extension() -> CMakeExtension:
"""Setup CMake extension for common library"""
cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())]
cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)]
if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
assert (
os.getenv("MPI_HOME") is not None
......@@ -104,13 +104,11 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks:
install_reqs.extend(["torch"])
test_reqs.extend(["numpy", "onnxruntime", "torchvision", "prettytable"])
test_reqs.extend(["numpy", "torchvision", "prettytable"])
if "jax" in frameworks:
install_reqs.extend(["jax", "flax>=0.7.1"])
test_reqs.extend(["numpy", "praxis"])
if "paddle" in frameworks:
install_reqs.append("paddlepaddle-gpu")
test_reqs.append("numpy")
# test_reqs.extend(["numpy", "praxis"])
test_reqs.extend(["numpy"])
return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]]
......@@ -135,7 +133,6 @@ if __name__ == "__main__":
extras_require = {
"pytorch": [f"transformer_engine_torch=={__version__}"],
"jax": [f"transformer_engine_jax=={__version__}"],
"paddle": [f"transformer_engine_paddle=={__version__}"],
}
else:
setup_requires, install_requires, test_requires = setup_requirements()
......@@ -169,16 +166,6 @@ if __name__ == "__main__":
current_file_path / "transformer_engine",
)
)
if "paddle" in frameworks:
from build_tools.paddle import setup_paddle_extension
ext_modules.append(
setup_paddle_extension(
"transformer_engine/paddle/csrc",
current_file_path / "transformer_engine" / "paddle" / "csrc",
current_file_path / "transformer_engine",
)
)
# Configure package
setuptools.setup(
......
......@@ -5,7 +5,11 @@
cmake_minimum_required(VERSION 3.18)
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
set(CMAKE_CUDA_ARCHITECTURES 70 80 90)
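# CUDA 12.8 is the first toolkit that can target Blackwell (sm_100, sm_120).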
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
else ()
set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90)
endif()
endif()
......
......@@ -3,23 +3,33 @@
# See LICENSE for license information.
add_executable(test_operator
test_cast.cu
test_cast_dbias.cu
test_cast_dbias_dgelu.cu
test_cast_gated_swiglu.cu
test_cast_mxfp8_gated_swiglu.cu
test_qdq.cu
test_cast_transpose.cu
test_cast_mxfp8.cu
test_dequantize_mxfp8.cu
test_transpose.cu
test_cast_transpose.cu
test_cast_transpose_dbias.cu
test_cast_transpose_dbias_dgelu.cu
test_cast_transpose_dgeglu.cu
test_act.cu
test_normalization.cu
test_normalization_mxfp8.cu
test_multi_cast_transpose.cu
test_multi_padding.cu
test_causal_softmax.cu
test_swizzle.cu
../test_common.cu)
find_package(OpenMP REQUIRED)
list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn)
target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS})
target_compile_options(test_operator PRIVATE -O2)
target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS} OpenMP::OpenMP_CXX)
target_compile_options(test_operator PRIVATE -O2 -fopenmp)
include(GoogleTest)
gtest_discover_tests(test_operator)
gtest_discover_tests(test_operator DISCOVERY_TIMEOUT 600)
......@@ -21,58 +21,6 @@
using namespace transformer_engine;
namespace {
// forward
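// Constants below: 0.79788456 ~= sqrt(2/pi), and 0.044715 is the usual GELU
// tanh-approximation coefficient (0.1070322243 = 3 * 0.044715 * 0.79788456 in dgelu).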
float gelu(const float x) {
return 0.5f * x * (1.0f + tanhf(0.79788456F * x * (1.0f + 0.044715f * x * x)));
}
float silu(const float x) {
return x / (1 + expf(-x));
}
float relu(const float x) {
return x > 0 ? x : 0;
}
float srelu(const float x) {
return x > 0 ? x * x : 0;
}
float qgelu(const float x) {
return x / (1 + expf(-1.702f * x));
}
// backward
float dgelu(const float x) {
const float tanh_out = tanhf(0.79788456f * x * (1.f + 0.044715f * x * x));
return 0.5f * x * ((1.f - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * x * x)) +
0.5f * (1.f + tanh_out);
}
float dsilu(const float x) {
const float sigmoid = 1.f / (1 + expf(-x));
return x * sigmoid * (1.f - sigmoid) + sigmoid;
}
float drelu(const float x) {
return x > 0.f ? 1.f : 0.f;
}
float dsrelu(const float x) {
return fmaxf(2.f * x, 0.f);
}
float dqgelu(const float x) {
const float sigmoid = 1.f / (1 + expf(-1.702f * x));
return 1.702f * x * sigmoid * (1.f - sigmoid) + sigmoid;
}
} // namespace
template <float (*act)(const float), typename IT, typename OT, typename CT>
void compute_ref_act_cast(const IT *input_h,
OT *output_h,
......@@ -82,6 +30,7 @@ void compute_ref_act_cast(const IT *input_h,
const size_t H) {
CT amax = 0.;
#pragma omp parallel for schedule(static) reduction(max: amax) proc_bind(spread)
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
CT elt = static_cast<CT>(input_h[i * H + j]);
......@@ -101,6 +50,7 @@ void compute_ref_dact_cast(const IT *input_h,
const size_t N,
const size_t H) {
using CT = float;
#pragma omp parallel for schedule(static) proc_bind(spread)
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
CT elt = static_cast<CT>(input_h[i * H + j]);
......@@ -118,6 +68,7 @@ void compute_ref_glu_act_cast(const IT *input_h, OT *output_h, const CT scale, C
const int col = H * 2;
#pragma omp parallel for schedule(static) reduction(max: amax) proc_bind(spread)
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
CT gelu_elt = static_cast<CT>(input_h[i * col + j]);
......@@ -139,6 +90,7 @@ void compute_ref_dglu_act_cast(const IT *input_h, const IT *grad_h, OT *output_h
const int col = H * 2;
using CT = float;
#pragma omp parallel for schedule(static) proc_bind(spread)
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
CT grad = static_cast<CT>(grad_h[i * H + j]);
......@@ -164,10 +116,10 @@ void performTest(const size_t N, const size_t H) {
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
Tensor input({ N, H }, itype);
Tensor output({ N, H }, otype);
Tensor igrad({ N, H }, itype);
Tensor ograd({ N, H }, itype);
Tensor input("input", { N, H }, itype);
Tensor output("output", { N, H }, otype);
Tensor igrad("igrad", { N, H }, itype);
Tensor ograd("ograd", { N, H }, itype);
fillUniform(&input);
fillUniform(&ograd);
......@@ -179,7 +131,7 @@ void performTest(const size_t N, const size_t H) {
nvte_act(input.data(), output.data(), 0);
float ref_amax;
compute_ref_act_cast<ref_act>(input.cpu_dptr<IType>(), ref_output.get(),
compute_ref_act_cast<ref_act>(input.rowwise_cpu_dptr<IType>(), ref_output.get(),
output.scale(), &ref_amax, N, H);
cudaDeviceSynchronize();
......@@ -195,7 +147,7 @@ void performTest(const size_t N, const size_t H) {
nvte_dact(ograd.data(), input.data(), igrad.data(), 0);
compute_ref_dact_cast<ref_dact>(input.cpu_dptr<IType>(), ograd.cpu_dptr<IType>(),
compute_ref_dact_cast<ref_dact>(input.rowwise_cpu_dptr<IType>(), ograd.rowwise_cpu_dptr<IType>(),
ref_igrad.get(), N, H);
cudaDeviceSynchronize();
......@@ -219,10 +171,10 @@ void performTestGLU(const size_t N, const size_t H) {
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
Tensor input({N, H * 2}, itype);
Tensor output({N, H}, otype);
Tensor igrad({ N, H * 2 }, itype);
Tensor ograd({ N, H }, itype);
Tensor input("input", {N, H * 2}, itype);
Tensor output("output", {N, H}, otype);
Tensor igrad("igrad", { N, H * 2 }, itype);
Tensor ograd("ograd", { N, H }, itype);
fillUniform(&input);
fillUniform(&ograd);
......@@ -234,7 +186,7 @@ void performTestGLU(const size_t N, const size_t H) {
nvte_act(input.data(), output.data(), 0);
float ref_amax;
compute_ref_glu_act_cast<ref_act>(input.cpu_dptr<IType>(), ref_output.get(),
compute_ref_glu_act_cast<ref_act>(input.rowwise_cpu_dptr<IType>(), ref_output.get(),
output.scale(), &ref_amax, N, H);
cudaDeviceSynchronize();
......@@ -242,15 +194,19 @@ void performTestGLU(const size_t N, const size_t H) {
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
auto [atol, rtol] = getTolerances(DType::kFloat32);
compareResults("amax", output.amax(), ref_amax, atol, rtol);
if (output.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
const float ref_scale = 1.f / output.scale();
compareResults("scale_inv", *output.rowwise_cpu_scale_inv_ptr<float>(), ref_scale, atol, rtol);
}
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_gelu", output, ref_output.get(), atol, rtol);
nvte_dact(ograd.data(), input.data(), igrad.data(), 0);
compute_ref_dglu_act_cast<ref_dact, ref_act>(input.cpu_dptr<IType>(), ograd.cpu_dptr<IType>(),
compute_ref_dglu_act_cast<ref_dact, ref_act>(input.rowwise_cpu_dptr<IType>(), ograd.rowwise_cpu_dptr<IType>(),
ref_igrad.get(), N, H);
cudaDeviceSynchronize();
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
template <typename InputType, typename OutputType>
void compute_ref(const InputType *data, OutputType *output_c,
const size_t size,
float *amax, float scale) {
using compute_t = float;
compute_t current_max = 0.f;  // running amax of absolute values is non-negative; -1e100 is out of float range
for (size_t i = 0; i < size; ++i) {
compute_t current = static_cast<compute_t>(data[i]);
current_max = fmaxf(current_max, fabsf(current));
output_c[i] = OutputType(scale * current);
}
*amax = current_max;
}
template <typename InputType, typename OutputType>
void performTest(const std::vector<size_t>& shape) {
using namespace test;
const size_t full_size = product(shape);
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
Tensor input("input", shape, itype);
Tensor output_c("output_c", shape, otype);
std::unique_ptr<OutputType[]> ref_output_c = std::make_unique<OutputType[]>(full_size);
fillUniform(&input);
setRandomScale(&output_c);
nvte_quantize(input.data(), output_c.data(), 0);
float ref_amax;
compute_ref<InputType, OutputType>(input.rowwise_cpu_dptr<InputType>(), ref_output_c.get(),
full_size, &ref_amax, output_c.scale());
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol);
}
std::vector<std::vector<size_t>> test_cases = {
{16},
{16000},
{128, 128},
{256, 256},
{768, 1024},
{256, 65536},
{2048, 12288},
{65536, 128},
{65536, 160},
{16384, 1616},
{1, 128},
{1, 1296},
{1, 16},
{5, 160},
{5, 4, 3, 160},
{217, 256},
};
} // namespace
class CastTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
transformer_engine::DType,
std::vector<size_t>>> {};
TEST_P(CastTestSuite, TestCast) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>(size);
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
CastTestSuite,
::testing::Combine(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<CastTestSuite::ParamType>& info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param));
const auto& shape = std::get<2>(info.param);
for ( const auto& s: shape) {
name += "X" + std::to_string(s);
}
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstring>
#include <memory>
#include <iomanip>
#include <iostream>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
template <typename IT, typename OT, typename CT>
void compute_ref_cast_dbias(const IT *input_h,
const CT scale,
OT *output_c_h,
CT *amax_h,
IT *dbias_h,
const size_t N,
const size_t H) {
CT amax = 0.;
std::vector<CT> acc_dbias(H, 0.);
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
CT elt = static_cast<CT>(input_h[i * H + j]);
// update amax
amax = std::abs(elt) > amax ? std::abs(elt) : amax;
output_c_h[i * H + j] = static_cast<OT>(scale * elt);
// dbias
acc_dbias[j] += elt;
}
}
*amax_h = amax;
for (size_t i = 0; i < H; i++) {
dbias_h[i] = static_cast<IT>(acc_dbias[i]);
}
}
template <typename IType, typename OType>
void performTest(const std::vector<size_t>& shape) {
using namespace test;
using CType = fp32;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
const size_t N = first_dimension(shape);
const size_t H = last_dimension(shape);
Tensor input("input", shape, itype);
Tensor output_c("output_c", shape, otype);
// dbias has the same data type as the output gradient
Tensor dbias("dbias", {H}, itype);
fillUniform(&input);
setRandomScale(&output_c);
std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N*H);
std::unique_ptr<IType[]> ref_output_dbias = std::make_unique<IType[]>(H);
CType ref_amax;
compute_ref_cast_dbias(input.rowwise_cpu_dptr<IType>(),
output_c.scale(),
ref_output_c.get(),
&ref_amax,
ref_output_dbias.get(),
N, H);
Tensor workspace;
nvte_quantize_dbias(input.data(),
output_c.data(),
dbias.data(),
workspace.data(),
0);
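  // The first call only queried the required workspace size; allocate it, then run for real.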
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_quantize_dbias(input.data(),
output_c.data(),
dbias.data(),
workspace.data(),
0);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol);
auto [atol_dbias, rtol_dbias] = getTolerances(itype);
rtol_dbias *= 4;
compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
}
std::vector<std::vector<size_t>> test_cases = {
{128, 128},
{256, 256},
{768, 1024},
{256, 65536},
{2048, 12288},
{65536, 128},
{65536, 160},
{16384, 1616},
{1, 128},
{1, 1296},
{1, 16},
{5, 160},
{5, 4, 3, 160},
{217, 256},
};
} // namespace
class CastDBiasTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
transformer_engine::DType,
std::vector<size_t>>> {};
TEST_P(CastDBiasTestSuite, TestCastDBias) {
using namespace transformer_engine;
using namespace test;
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>(size);
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
CastDBiasTestSuite,
::testing::Combine(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<CastDBiasTestSuite::ParamType>& info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param));
const auto& shape = std::get<2>(info.param);
for ( const auto& s: shape) {
name += "X" + std::to_string(s);
}
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstring>
#include <memory>
#include <iomanip>
#include <iostream>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include "../test_common.h"
using namespace transformer_engine;
using namespace test;
namespace {
template <typename IT, typename OT, typename CT>
void compute_ref_cast_dbias_dgelu(const IT *input,
const IT *grad,
const CT scale,
OT *output_c,
CT *amax_h,
IT *dbias,
const size_t N,
const size_t H) {
CT amax = 0.;
std::vector<CT> acc_dbias(H, 0.);
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
CT in_elt = static_cast<CT>(input[i * H + j]);
const CT in_grad = static_cast<CT>(grad[i * H + j]);
const CT elt = in_grad * static_cast<float>(dgelu(static_cast<float>(in_elt)));
const CT elt_abs = std::abs(elt);
// update amax
if (elt_abs > amax) {
amax = elt_abs;
}
output_c[i * H + j] = static_cast<OT>(scale * elt);
// dbias
acc_dbias[j] += elt;
}
}
*amax_h = amax;
for (size_t i = 0; i < H; i++) {
dbias[i] = static_cast<IT>(acc_dbias[i]);
}
}
template <typename IType, typename OType>
void performTest(const std::vector<size_t>& shape) {
using namespace test;
using CType = fp32;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
const size_t N = first_dimension(shape);
const size_t H = last_dimension(shape);
Tensor input("input", shape, itype);
Tensor grad("grad", shape, itype);
Tensor output_c("output_c", shape, otype);
// dbias has the same data type as the output gradient
Tensor dbias("dbias", {H}, itype);
fillUniform(&input);
fillUniform(&grad);
setRandomScale(&output_c);
std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N*H);
std::unique_ptr<IType[]> ref_output_dbias = std::make_unique<IType[]>(H);
CType ref_amax;
compute_ref_cast_dbias_dgelu(input.rowwise_cpu_dptr<IType>(),
grad.rowwise_cpu_dptr<IType>(),
output_c.scale(),
ref_output_c.get(),
&ref_amax,
ref_output_dbias.get(),
N, H);
Tensor workspace;
nvte_quantize_dbias_dgelu(grad.data(),
input.data(),
output_c.data(),
dbias.data(),
workspace.data(),
0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_quantize_dbias_dgelu(grad.data(),
input.data(),
output_c.data(),
dbias.data(),
workspace.data(),
0);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol);
auto [atol_dbias, rtol_dbias] = getTolerances(itype);
rtol_dbias *= 4;
compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
}
std::vector<std::vector<size_t>> test_cases = {
{128, 128},
{256, 256},
{768, 1024},
{256, 65536},
{2048, 12288},
{65536, 128},
{65536, 160},
{16384, 1616},
{1, 128},
{1, 1296},
{1, 16},
{5, 160},
{5, 4, 3, 160},
{217, 256},
};
} // namespace
class CastDBiasDGeluTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
transformer_engine::DType,
std::vector<size_t>>> {};
TEST_P(CastDBiasDGeluTestSuite, TestCastDBiasDgelu) {
using namespace transformer_engine;
using namespace test;
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>(size);
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
CastDBiasDGeluTestSuite,
::testing::Combine(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<CastDBiasDGeluTestSuite::ParamType>& info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param));
const auto& shape = std::get<2>(info.param);
for ( const auto& s: shape) {
name += "X" + std::to_string(s);
}
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <omp.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/transpose.h>
#include "../test_common.h"
using namespace transformer_engine;
using namespace test;
namespace {
template <typename IType, typename OType>
void compute_ref_cast_dgated_swiglu(const IType * const grad,
const IType * const input,
const float scale,
OType * const output,
float * const amax_ptr,
const size_t rows,
const size_t cols) {
float amax = 0;
const size_t stride = cols * 2;
#pragma omp parallel for reduction(max: amax) proc_bind(spread)
for (size_t i = 0; i < rows; i++) {
for (size_t j = 0; j < cols; j++) {
float grad_elt = static_cast<float>(grad[i * cols + j]);
float silu_elt = static_cast<float>(input[i * stride + j]);
float gate_elt = static_cast<float>(input[i * stride + cols + j]);
float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt;
float after_dgate = grad_elt * silu(silu_elt);
if (abs(after_dsilu) > amax) { amax = abs(after_dsilu); }
if (abs(after_dgate) > amax) { amax = abs(after_dgate); }
output[i * stride + j] = static_cast<OType>(scale * after_dsilu);
output[i * stride + cols + j] = static_cast<OType>(scale * after_dgate);
}
}
*amax_ptr = amax;
}
template <typename IType, typename OType>
void performTest(const std::vector<size_t>& shape) {
using namespace test;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
std::vector<size_t> input_shape = shape;
input_shape[input_shape.size() - 1] *= 2;
const size_t input_size = product(input_shape);
const size_t rows = first_dimension(shape);
const size_t cols = last_dimension(shape);
Tensor grad("grad", shape, itype);
Tensor input("input", input_shape, itype);
Tensor output_c("output_c", input_shape, otype);
fillUniform(&grad);
fillUniform(&input);
setRandomScale(&output_c);
std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(input_size);
nvte_dswiglu(grad.data(), input.data(), output_c.data(), 0);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
float ref_amax;
compute_ref_cast_dgated_swiglu(grad.rowwise_cpu_dptr<IType>(),
input.rowwise_cpu_dptr<IType>(),
output_c.scale(),
ref_output_c.get(),
&ref_amax,
rows,
cols);
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol);
}
std::vector<std::vector<size_t>> test_cases = {
{128, 128},
{256, 256},
{768, 1024},
{256, 65536},
{2048, 12288},
{65536, 128},
{217, 256},
{1296},
{5, 4, 3, 160},
};
} // namespace
class CastSwiGLUTestSuite
: public ::testing::TestWithParam<std::tuple<
transformer_engine::DType, transformer_engine::DType, std::vector<size_t>>> {};
TEST_P(CastSwiGLUTestSuite, TestCastSwiGLU) {
using namespace transformer_engine;
using namespace test;
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
if (size.back() % 32 != 0) {
GTEST_SKIP();
}
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
output_type, OutputType, performTest<InputType, OutputType>(size);););
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest, CastSwiGLUTestSuite,
::testing::Combine(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<CastSwiGLUTestSuite::ParamType> &info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param));
const auto& shape = std::get<2>(info.param);
for ( const auto& s: shape) {
name += "X" + std::to_string(s);
}
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/activation.h>
#include "../test_common.h"
#include "transformer_engine/transformer_engine.h"
using namespace transformer_engine;
using namespace test;
namespace {
enum ProcessingMethod {
CAST_ONLY,
CAST_DBIAS,
CAST_DBIAS_DACT,
CAST_DACT,
CAST_ACT
};
enum ActivationType {
Identity,
GeLU,
SiLU,
ReLU,
QGeLU,
SReLU
};
template <typename InputType, typename OutputType, float (*OP)(const float)>
void scale_block(const ProcessingMethod processing_method,
const InputType* input,
const InputType* grad,
OutputType* output_c,
float* dbias,
fp8e8m0* output_scales,
const size_t scale_idx,
const size_t i_min,
const size_t i_max,
const size_t j_min,
const size_t j_max,
const size_t cols) {
float amax = 0.0f;
// Find the absolute maximum value in the block
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j;
float elt = static_cast<float>(input[idx]);
if (processing_method == ProcessingMethod::CAST_DBIAS) {
// grad is the input
elt = static_cast<float>(grad[idx]);
}
if (processing_method != ProcessingMethod::CAST_ONLY
&& processing_method != ProcessingMethod::CAST_DBIAS) {
elt = OP(elt);
}
if (processing_method == ProcessingMethod::CAST_DACT ||
processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
elt *= static_cast<float>(grad[idx]);
}
dbias[j] += elt;
if (isinf(elt) || isnan(elt)) {
continue;
}
amax = std::max(amax, std::abs(elt));
}
}
const fp8e8m0 biased_exponent = float_to_e8m0(amax * Quantized_Limits<OutputType>::max_reciprocal());
const float scale_reciprocal = exp2f_rcp(biased_exponent);
output_scales[scale_idx] = biased_exponent;
// Quantize elements in the block
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j;
float elt = static_cast<float>(input[idx]);
if (processing_method == ProcessingMethod::CAST_DBIAS) {
// grad is the input
elt = static_cast<float>(grad[idx]);
}
if (processing_method != ProcessingMethod::CAST_ONLY
&& processing_method != ProcessingMethod::CAST_DBIAS) {
elt = OP(elt);
}
if (processing_method == ProcessingMethod::CAST_DACT ||
processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
elt *= static_cast<float>(grad[idx]);
}
output_c[idx] = static_cast<OutputType>(elt * scale_reciprocal);
}
}
}
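// A minimal sketch of the e8m0 helpers used above, assuming the real
// float_to_e8m0 / exp2f_rcp come from test_common.h. These hypothetical
// stand-ins (relying on <cmath> and <cstdint>) only illustrate the intended
// math: the block scale is a power of two whose biased exponent (bias 127)
// fits in one byte, and quantization multiplies by the scale's reciprocal.
// Zero/NaN inputs are not handled here.
inline uint8_t sketch_float_to_e8m0(float scale) {
  // Round the scale up to the next power of two, then bias the exponent.
  const int unbiased = static_cast<int>(std::ceil(std::log2(scale)));
  return static_cast<uint8_t>(unbiased + 127);
}
inline float sketch_exp2f_rcp(uint8_t biased_exponent) {
  // Reciprocal of the encoded scale: 1 / 2^(biased_exponent - 127).
  return std::exp2(127 - static_cast<int>(biased_exponent));
}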
template <typename InputType, typename OutputType, float (*OP)(const float)>
void compute_ref_x1(const ProcessingMethod processing_method,
const InputType* input,
const InputType* grad,
OutputType* output_c,
fp8e8m0* output_scales,
InputType* output_dbias,
const size_t rows,
const size_t cols,
const size_t block_size_Y,
const size_t block_size_X,
const size_t scales_stride)
{
std::vector<float> output_dbias_fp32(cols, 0);
const size_t blocks_Y = (rows + block_size_Y - 1) / block_size_Y;
const size_t blocks_X = (cols + block_size_X - 1) / block_size_X;
for (size_t ii = 0; ii < blocks_Y; ++ii) {
const size_t i_min = ii * block_size_Y;
const size_t i_max = std::min((ii + 1) * block_size_Y, rows);
for (size_t jj = 0; jj < blocks_X; ++jj) {
const size_t j_min = jj * block_size_X;
const size_t j_max = std::min((jj + 1) * block_size_X, cols);
const size_t scale_idx = ii * scales_stride + jj;
scale_block<InputType, OutputType, OP>(
processing_method, input, grad, output_c, output_dbias_fp32.data(),
output_scales, scale_idx, i_min, i_max, j_min, j_max, cols);
}
}
for (size_t j = 0; j < cols; ++j) {
output_dbias[j] = static_cast<InputType>(output_dbias_fp32[j]);
}
}
template <typename InputType, typename OutputType, float (*OP)(const float)>
void compute_ref_x2(const ProcessingMethod processing_method,
const InputType* input,
const InputType* grad,
OutputType* output_rowwise,
OutputType* output_colwise,
fp8e8m0* scales_rowwise,
fp8e8m0* scales_colwise,
InputType* output_dbias,
const size_t rows,
const size_t cols,
const size_t block_size_Y,
const size_t block_size_X,
const size_t scales_stride_rowwise,
const size_t scales_stride_colwise) {
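  // Run the 1D reference twice: 1 x block_size_X blocks produce the row-wise
  // output and scales, block_size_Y x 1 blocks the column-wise ones.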
compute_ref_x1<InputType, OutputType, OP>(
processing_method, input, grad, output_rowwise, scales_rowwise, output_dbias,
rows, cols, 1, block_size_X, scales_stride_rowwise);
compute_ref_x1<InputType, OutputType, OP>(
processing_method, input, grad, output_colwise, scales_colwise, output_dbias,
rows, cols, block_size_Y, 1, scales_stride_colwise);
}
/**
* Scaling along single dimension (either rows or columns)
* Produces one set of output data and the corresponding data of the fused operation (dbias):
* 1) Scaled rows + row-wise scaling factors
* OR
* 2) Scaled columns + column-wise scaling factors
*/
template <typename InputType, typename OutputType, float (*OP)(const float)>
void performTest_x1(const ProcessingMethod processing_method,
const std::vector<size_t>& shape,
const bool rowwise,
const bool colwise,
InputsFillCase fill_case) {
using namespace test;
using EncodingType = fp32;
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
const size_t rows = first_dimension(shape);
const size_t cols = last_dimension(shape);
if (shape.size() < 2 && colwise) {
GTEST_SKIP();
}
const size_t block_size_rows = rowwise ? 1 : 32;
const size_t block_size_cols = colwise ? 1 : 32;
const std::array<size_t,4> scale_dims = get_scale_tensor_dims(rows, cols, block_size_rows,
block_size_cols);
const size_t unpadded_blocks_Y = scale_dims[0];
const size_t unpadded_blocks_X = scale_dims[1];
const size_t blocks_Y = scale_dims[2];
const size_t blocks_X = scale_dims[3];
const size_t scales_stride = blocks_X;
Tensor input("input", shape, itype);
Tensor grad("grad", shape, itype);
Tensor output_c("output_c", shape, otype, rowwise, colwise, NVTE_MXFP8_1D_SCALING);
Tensor output_dbias("output_dbias", { cols }, itype);
std::unique_ptr<OutputType[]> ref_output_c = std::make_unique<OutputType[]>(rows * cols);
std::unique_ptr<InputType[]> ref_output_dbias = std::make_unique<InputType[]>(cols);
std::unique_ptr<fp8e8m0[]> ref_output_scales = std::make_unique<fp8e8m0[]>(blocks_Y * blocks_X);
fillCase<EncodingType>(&input, fill_case);
fillUniform(&grad);
Tensor workspace;
switch (processing_method) {
case ProcessingMethod::CAST_ONLY: {
nvte_quantize(input.data(), output_c.data(), 0);
break;
}
case ProcessingMethod::CAST_DBIAS: {
nvte_quantize_dbias(grad.data(),
output_c.data(),
output_dbias.data(),
workspace.data(),
0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_quantize_dbias(grad.data(),
output_c.data(),
output_dbias.data(),
workspace.data(),
0);
break;
}
case ProcessingMethod::CAST_DBIAS_DACT: {
nvte_quantize_dbias_dgelu(grad.data(),
input.data(),
output_c.data(),
output_dbias.data(),
workspace.data(),
0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_quantize_dbias_dgelu(grad.data(),
input.data(),
output_c.data(),
output_dbias.data(),
workspace.data(),
0);
break;
}
case ProcessingMethod::CAST_DACT: {
nvte_dgelu(grad.data(), input.data(), output_c.data(), 0);
break;
}
case ProcessingMethod::CAST_ACT: {
nvte_gelu(input.data(), output_c.data(), 0);
break;
}
}
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
compute_ref_x1<InputType, OutputType, OP>(processing_method,
input.rowwise_cpu_dptr<InputType>(),
grad.rowwise_cpu_dptr<InputType>(),
ref_output_c.get(),
ref_output_scales.get(),
ref_output_dbias.get(),
rows,
cols,
block_size_rows,
block_size_cols,
scales_stride);
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol);
const uint8_t * const gpu_scales_ptr = rowwise
? output_c.rowwise_cpu_scale_inv_ptr<fp8e8m0>()
: output_c.columnwise_cpu_scale_inv_ptr<fp8e8m0>();
compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride);
if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
auto [atol_dbias, rtol_dbias] = getTolerances(itype);
if (itype == DType::kFloat32) {
atol_dbias = 1e-4;
rtol_dbias *= sqrt(static_cast<double>(rows));
} else {
rtol_dbias *= 4;
}
compareResults("output_dbias", output_dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
}
}
/**
* Scaling along both dimensions (rows and columns)
* Produces two sets of scaled output data and the corresponding data of the fused operation (dbias):
* 1) Scaled rows + row-wise scaling factors
* AND
* 2) Scaled columns + column-wise scaling factors
*/
template <typename InputType, typename OutputType, float (*OP)(const float)>
void performTest_x2(const ProcessingMethod processing_method,
const std::vector<size_t>& shape,
const size_t block_size_rows,
const size_t block_size_cols,
InputsFillCase fill_case) {
using namespace test;
using EncodingType = fp32;
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
if (shape.size() < 2) {
GTEST_SKIP();
}
const size_t rows = first_dimension(shape);
const size_t cols = last_dimension(shape);
const std::array<size_t,4> scale_dims_rowwise = get_scale_tensor_dims(rows, cols, 1, 32);
const std::array<size_t,4> scale_dims_colwise = get_scale_tensor_dims(rows, cols, 32, 1);
const size_t unpadded_blocks_Y_rowwise = scale_dims_rowwise[0];
const size_t unpadded_blocks_X_rowwise = scale_dims_rowwise[1];
const size_t blocks_Y_rowwise = scale_dims_rowwise[2];
const size_t blocks_X_rowwise = scale_dims_rowwise[3];
const size_t scales_stride_rowwise = blocks_X_rowwise;
const size_t unpadded_blocks_Y_colwise = scale_dims_colwise[0];
const size_t unpadded_blocks_X_colwise = scale_dims_colwise[1];
const size_t blocks_Y_colwise = scale_dims_colwise[2];
const size_t blocks_X_colwise = scale_dims_colwise[3];
const size_t scales_stride_colwise = blocks_X_colwise;
Tensor input("input", shape, itype);
Tensor grad("grad", shape, itype);
Tensor output("output", shape, otype, true, true, NVTE_MXFP8_1D_SCALING);
Tensor output_dbias("output_dbias", { cols }, itype);
std::unique_ptr<OutputType[]> ref_output_c_rowwise = std::make_unique<OutputType[]>(rows * cols);
std::unique_ptr<OutputType[]> ref_output_c_colwise = std::make_unique<OutputType[]>(rows * cols);
std::unique_ptr<fp8e8m0[]> ref_scales_rowwise = std::make_unique<fp8e8m0[]>(blocks_Y_rowwise * blocks_X_rowwise);
std::unique_ptr<fp8e8m0[]> ref_scales_colwise = std::make_unique<fp8e8m0[]>(blocks_Y_colwise * blocks_X_colwise);
std::unique_ptr<InputType[]> ref_output_dbias = std::make_unique<InputType[]>(cols);
fillCase<EncodingType>(&input, fill_case);
fillUniform(&grad);
Tensor workspace;
switch (processing_method) {
case ProcessingMethod::CAST_ONLY: {
nvte_quantize(input.data(), output.data(), 0);
break;
}
case ProcessingMethod::CAST_DBIAS: {
nvte_quantize_dbias(grad.data(),
output.data(),
output_dbias.data(),
workspace.data(),
0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_quantize_dbias(grad.data(),
output.data(),
output_dbias.data(),
workspace.data(),
0);
break;
}
case ProcessingMethod::CAST_DBIAS_DACT: {
nvte_quantize_dbias_dgelu(grad.data(),
input.data(),
output.data(),
output_dbias.data(),
workspace.data(),
0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_quantize_dbias_dgelu(grad.data(),
input.data(),
output.data(),
output_dbias.data(),
workspace.data(),
0);
break;
}
case ProcessingMethod::CAST_DACT: {
nvte_dgelu(grad.data(), input.data(), output.data(), 0);
break;
}
case ProcessingMethod::CAST_ACT: {
nvte_gelu(input.data(), output.data(), 0);
break;
}
}
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
compute_ref_x2<InputType, OutputType, OP>(processing_method,
input.rowwise_cpu_dptr<InputType>(),
grad.rowwise_cpu_dptr<InputType>(),
ref_output_c_rowwise.get(),
ref_output_c_colwise.get(),
ref_scales_rowwise.get(),
ref_scales_colwise.get(),
ref_output_dbias.get(),
rows,
cols,
block_size_rows,
block_size_cols,
scales_stride_rowwise,
scales_stride_colwise);
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol);
compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol);
compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
unpadded_blocks_X_rowwise, scales_stride_rowwise);
compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
unpadded_blocks_X_colwise, scales_stride_colwise);
if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
auto [atol_dbias, rtol_dbias] = getTolerances(itype);
if (itype == DType::kFloat32) {
atol_dbias = 1e-4;
rtol_dbias *= sqrt(static_cast<double>(rows));
} else {
rtol_dbias *= 4;
}
compareResults("output_dbias", output_dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
}
}
std::vector<std::vector<size_t>> matrix_sizes = {
{1, 16},
{16, 48},
{65, 96},
{128, 128},
{256, 256},
{993, 512},
{256, 65536},
{2048, 6144},
{16384, 128},
{32768, 160},
{4096, 1632},
{1024},
{8, 32, 1024},
{16, 8, 4, 512},
};
std::vector<std::pair<size_t, size_t>> block_sizes = {
{1, 32},
{32, 1},
{32, 32},
};
std::vector<InputsFillCase> input_scenarios = {
InputsFillCase::uniform,
// InputsFillCase::zeros,
// InputsFillCase::zero_to_minNorm,
// InputsFillCase::minNorm_to_maxNorm,
// InputsFillCase::maxNorm_to_inf
};
std::vector<ProcessingMethod> processing_methods = {
ProcessingMethod::CAST_ONLY,
ProcessingMethod::CAST_DBIAS,
ProcessingMethod::CAST_DBIAS_DACT,
ProcessingMethod::CAST_DACT,
ProcessingMethod::CAST_ACT,
};
// Only Identity and GeLU activations are currently tested
std::vector<ActivationType> Activation_types = {
ActivationType::Identity,
ActivationType::GeLU,
// ActivationType::SiLU,
// ActivationType::ReLU,
// ActivationType::QGeLU,
// ActivationType::SReLU,
};
} // namespace
class FusedCastMXFP8TestSuite : public ::testing::TestWithParam
<std::tuple<ProcessingMethod,
ActivationType,
std::vector<size_t>,
std::pair<size_t, size_t>,
transformer_engine::DType,
transformer_engine::DType,
InputsFillCase>> {};
#define DACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \
switch (OP_FUNC_TYPE) { \
case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \
case ActivationType::GeLU: { constexpr auto OP = &dgelu; { __VA_ARGS__ } } break; \
case ActivationType::SiLU: { constexpr auto OP = &dsilu; { __VA_ARGS__ } } break; \
case ActivationType::ReLU: { constexpr auto OP = &drelu; { __VA_ARGS__ } } break; \
case ActivationType::QGeLU: { constexpr auto OP = &dqgelu; { __VA_ARGS__ } } break; \
case ActivationType::SReLU: { constexpr auto OP = &dsrelu; { __VA_ARGS__ } } break; \
}
#define ACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \
switch (OP_FUNC_TYPE) { \
case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \
case ActivationType::GeLU: { constexpr auto OP = &gelu; { __VA_ARGS__ } } break; \
case ActivationType::SiLU: { constexpr auto OP = &silu; { __VA_ARGS__ } } break; \
case ActivationType::ReLU: { constexpr auto OP = &relu; { __VA_ARGS__ } } break; \
case ActivationType::QGeLU: { constexpr auto OP = &qgelu; { __VA_ARGS__ } } break; \
case ActivationType::SReLU: { constexpr auto OP = &srelu; { __VA_ARGS__ } } break; \
}
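// These switches turn the runtime ActivationType into a compile-time constexpr
// function pointer, so the reference activation can be passed as the OP template
// argument of performTest_x1/performTest_x2.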
TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) {
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
using namespace transformer_engine;
using namespace test;
const ProcessingMethod processing_method = std::get<0>(GetParam());
const ActivationType Act_type = std::get<1>(GetParam());
const auto matrix_size = std::get<2>(GetParam());
const auto block_size = std::get<3>(GetParam());
const DType input_type = std::get<4>(GetParam());
const DType output_type = std::get<5>(GetParam());
const InputsFillCase fill_case = std::get<6>(GetParam());
// Skip non-activation tests if the activation type is not Identity
if ((processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS)
&& Act_type != ActivationType::Identity) {
GTEST_SKIP();
}
// Skip activation tests if the activation type is Identity
if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT
|| processing_method == ProcessingMethod::CAST_DACT
|| processing_method == ProcessingMethod::CAST_ACT) && (Act_type == ActivationType::Identity)) {
GTEST_SKIP();
}
const bool rowwise = block_size.second != 1;
const bool colwise = block_size.first != 1;
if (processing_method == ProcessingMethod::CAST_ACT) {
// Forward activations
ACT_FUNC_SWITCH(Act_type, OP,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
if (block_size.first == 1 || block_size.second == 1) {
performTest_x1<InputType, OutputType, OP>(
processing_method, matrix_size,
rowwise, colwise, fill_case);
} else {
performTest_x2<InputType, OutputType, OP>(
processing_method, matrix_size,
block_size.first, block_size.second, fill_case);
}
);
);
);
} else {
DACT_FUNC_SWITCH(Act_type, OP,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
if (block_size.first == 1 || block_size.second == 1) {
performTest_x1<InputType, OutputType, OP>(
processing_method, matrix_size,
rowwise, colwise, fill_case);
} else {
performTest_x2<InputType, OutputType, OP>(
processing_method, matrix_size,
block_size.first, block_size.second, fill_case);
}
);
);
);
}
}
std::string to_string(const ProcessingMethod method) {
switch (method) {
case ProcessingMethod::CAST_ONLY: return "CAST_ONLY";
case ProcessingMethod::CAST_DBIAS: return "CAST_DBIAS";
case ProcessingMethod::CAST_DBIAS_DACT: return "CAST_DBIAS_DACT";
case ProcessingMethod::CAST_DACT: return "CAST_DACT";
case ProcessingMethod::CAST_ACT: return "CAST_ACT";
default: return "";
}
}
std::string to_string(const ActivationType Act_type) {
switch (Act_type) {
case ActivationType::Identity: return "Identity";
case ActivationType::GeLU: return "GeLU";
case ActivationType::SiLU: return "SiLU";
case ActivationType::ReLU: return "ReLU";
case ActivationType::QGeLU: return "QGeLU";
case ActivationType::SReLU: return "SReLU";
default: return "";
}
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
FusedCastMXFP8TestSuite,
::testing::Combine(
::testing::ValuesIn(processing_methods),
::testing::ValuesIn(Activation_types),
::testing::ValuesIn(matrix_sizes),
::testing::ValuesIn(block_sizes),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(input_scenarios)),
[](const testing::TestParamInfo<FusedCastMXFP8TestSuite::ParamType>& info) {
std::string name = to_string(std::get<0>(info.param)) + "X" +
to_string(std::get<1>(info.param));
const auto& shape = std::get<2>(info.param);
for ( const auto& s: shape) {
name += "X" + std::to_string(s);
}
name += "X" + std::to_string(std::get<3>(info.param).first) +
"X" + std::to_string(std::get<3>(info.param).second) +
"X" + test::typeName(std::get<4>(info.param)) +
"X" + test::typeName(std::get<5>(info.param)) +
"X" + test::caseName(std::get<6>(info.param));
return name;
});
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/activation.h>
#include "../test_common.h"
#include "transformer_engine/transformer_engine.h"
using namespace transformer_engine;
using namespace test;
namespace {
template <bool IS_DGATED, typename IType, typename OType>
void scale_block(const IType* grad,
const IType* input,
OType* output,
fp8e8m0* output_scales,
const size_t scale_idx,
const size_t scale_idx_gate,
float& thread_amax,
const size_t i_min,
const size_t i_max,
const size_t j_min,
const size_t j_max,
const size_t cols) {
float block_amax = 0.0f;
float block_amax_gate = 0.0f;
const size_t stride = cols * 2;
// Find the absolute maximum value in the block
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
float silu_elt = static_cast<float>(input[i * stride + j]);
float gate_elt = static_cast<float>(input[i * stride + cols + j]);
float gated_amax_act = 0;
float gated_amax_gate = 0;
if constexpr (IS_DGATED) {
const float grad_elt = static_cast<float>(grad[i * cols + j]);
const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt;
const float after_dgate = silu(silu_elt) * grad_elt;
gated_amax_act = abs(after_dsilu);
gated_amax_gate = abs(after_dgate);
} else {
const float after_silu = silu(silu_elt) * gate_elt;
gated_amax_act = abs(after_silu);
}
if (gated_amax_act > block_amax) { block_amax = gated_amax_act; }
if (gated_amax_gate > block_amax_gate) { block_amax_gate = gated_amax_gate; }
}
}
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax *
Quantized_Limits<OType>::max_reciprocal());
const float scale_reciprocal = exp2f_rcp(biased_exponent);
output_scales[scale_idx] = biased_exponent;
float scale_reciprocal_gate = 1;
if constexpr (IS_DGATED) {
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax_gate *
Quantized_Limits<OType>::max_reciprocal());
scale_reciprocal_gate = exp2f_rcp(biased_exponent);
output_scales[scale_idx_gate] = biased_exponent;
}
// Quantize elements in the block
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
float silu_elt = static_cast<float>(input[i * stride + j]);
float gate_elt = static_cast<float>(input[i * stride + cols + j]);
if constexpr (IS_DGATED) {
const float grad_elt = static_cast<float>(grad[i * cols + j]);
const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt;
const float after_dgate = silu(silu_elt) * grad_elt;
output[i * stride + j] = static_cast<OType>(after_dsilu * scale_reciprocal);
output[i * stride + cols + j] = static_cast<OType>(after_dgate *
scale_reciprocal_gate);
} else {
const float after_silu = silu(silu_elt) * gate_elt;
output[i * cols + j] = static_cast<OType>(after_silu * scale_reciprocal);
}
}
}
thread_amax = std::max(thread_amax, block_amax);
thread_amax = std::max(thread_amax, block_amax_gate);
}
template <bool IS_DGATED, typename IType, typename OType>
void compute_ref_x1(const IType* grad,
const IType* input,
OType* output,
fp8e8m0* output_scales,
float& ref_amax,
const size_t rows,
const size_t cols,
const size_t block_size_Y,
const size_t block_size_X,
const size_t scales_stride) {
const size_t tile_size_Y = std::max(32lu, block_size_Y);
const size_t tile_size_X = std::max(64lu, block_size_X);
const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y;
const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X;
const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y;
const size_t blocks_per_tile_X = tile_size_X / block_size_X;
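  // Work is parallelized over tiles of at least 32 x 64 elements so each OpenMP
  // thread gets a contiguous chunk; every tile is then subdivided into scaling blocks.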
float amax = 0;
#pragma omp parallel reduction(max: amax) proc_bind(spread)
{
float thread_amax = 0;
#pragma omp for schedule(static)
for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) {
const size_t tile_Y = t / tiles_num_X;
const size_t tile_X = t % tiles_num_X;
const size_t tile_offset_Y = tile_Y * tile_size_Y;
const size_t tile_offset_X = tile_X * tile_size_X;
for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) {
const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii;
const size_t block_offset_Y = ii * block_size_Y;
const size_t i_min = tile_offset_Y + block_offset_Y;
if (i_min >= rows) continue;
const size_t i_max = std::min(i_min + block_size_Y, rows);
for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) {
const size_t block_idx_X = tile_X * blocks_per_tile_X + jj;
const size_t block_offset_X = jj * block_size_X;
const size_t j_min = tile_offset_X + block_offset_X;
if (j_min >= cols) continue;
const size_t j_max = std::min(j_min + block_size_X, cols);
const size_t mx_scale_idx = block_idx_Y * scales_stride + block_idx_X;
const size_t mx_scale_idx_gate = block_idx_Y * scales_stride + block_idx_X +
cols / block_size_X;
scale_block<IS_DGATED, IType, OType>(
grad, input, output, output_scales, mx_scale_idx, mx_scale_idx_gate,
thread_amax, i_min, i_max, j_min, j_max, cols);
}
}
}
if (thread_amax > amax) {
amax = thread_amax;
}
}
ref_amax = amax;
}
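// compute_ref_x2 runs the row-wise (1 x block_size_X) and column-wise
// (block_size_Y x 1) references back to back. Both passes quantize the same
// values, so letting the second call overwrite ref_amax is harmless: the data
// amax does not depend on the scaling direction.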
template <bool IS_DGATED, typename IType, typename OType>
void compute_ref_x2(const IType* grad,
const IType* input,
OType* output_rowwise,
OType* output_colwise,
fp8e8m0* scales_rowwise,
fp8e8m0* scales_colwise,
float& ref_amax,
const size_t rows,
const size_t cols,
const size_t block_size_Y,
const size_t block_size_X,
const size_t scales_stride_rowwise,
const size_t scales_stride_colwise) {
compute_ref_x1<IS_DGATED, IType, OType>(
grad, input, output_rowwise, scales_rowwise, ref_amax, rows, cols, 1, block_size_X, scales_stride_rowwise);
compute_ref_x1<IS_DGATED, IType, OType>(
grad, input, output_colwise, scales_colwise, ref_amax, rows, cols, block_size_Y, 1, scales_stride_colwise);
}
/**
 * Scaling along a single dimension (either rows or columns)
 * Produces one set of quantized output data:
 * 1) Scaled rows + row-wise scaling factors
 *    OR
 * 2) Scaled columns + column-wise scaling factors
 */
template <bool IS_DGATED, typename IType, typename OType>
void performTest_x1(const size_t rows,
const size_t cols,
const size_t block_size_rows,
const size_t block_size_cols,
InputsFillCase fill_case) {
using namespace test;
using EncodingType = fp32;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
const bool rowwise = (block_size_rows == 1) && (block_size_cols == 32);
const bool colwise = (block_size_rows == 32) && (block_size_cols == 1);
NVTE_CHECK(rowwise || colwise);
Tensor grad("grad", { rows, cols }, itype);
Tensor input("input", { rows, cols * 2 }, itype);
const size_t output_cols = (IS_DGATED ? 2 : 1) * cols;
const std::array<size_t,4> scale_dims = get_scale_tensor_dims(rows, output_cols, block_size_rows,
block_size_cols);
const size_t unpadded_blocks_Y = scale_dims[0];
const size_t unpadded_blocks_X = scale_dims[1];
const size_t blocks_Y = scale_dims[2];
const size_t blocks_X = scale_dims[3];
const size_t scales_stride = blocks_X;
Tensor output("output", std::vector<size_t>{ rows, output_cols }, otype,
rowwise, colwise, NVTE_MXFP8_1D_SCALING);
std::unique_ptr<OType[]> ref_output = std::make_unique<OType[]>(rows * output_cols);
std::unique_ptr<fp8e8m0[]> ref_output_scales = std::make_unique<fp8e8m0[]>(blocks_Y * blocks_X);
for (size_t i = 0; i < blocks_Y * blocks_X; ++i) {
ref_output_scales[i] = 0;
}
// fillCase<EncodingType>(&grad, fill_case);
if constexpr (IS_DGATED) {
fillUniform(&grad);
}
fillUniform(&input);
if constexpr (IS_DGATED) {
nvte_dswiglu(grad.data(), input.data(), output.data(), 0);
} else {
nvte_swiglu(input.data(), output.data(), 0);
}
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
float ref_amax = 0;
compute_ref_x1<IS_DGATED, IType, OType>(grad.rowwise_cpu_dptr<IType>(),
input.rowwise_cpu_dptr<IType>(),
ref_output.get(),
ref_output_scales.get(),
ref_amax,
rows,
cols,
block_size_rows,
block_size_cols,
scales_stride);
auto [atol, rtol] = getTolerances(otype);
compareResults("output", output, ref_output.get(), rowwise, atol, rtol);
  const uint8_t * const gpu_scales_ptr = rowwise
                                         ? output.rowwise_cpu_scale_inv_ptr<fp8e8m0>()
                                         : output.columnwise_cpu_scale_inv_ptr<fp8e8m0>();
  compare_e8m0_scaling_factors(rowwise ? "rowwise scales" : "colwise scales", gpu_scales_ptr,
                               ref_output_scales.get(), unpadded_blocks_Y, unpadded_blocks_X,
                               scales_stride);
}
/**
 * Scaling along both dimensions (rows and columns)
 * Produces two sets of quantized output data:
 * 1) Scaled rows + row-wise scaling factors
 *    AND
 * 2) Scaled columns + column-wise scaling factors
 */
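// Note: a single Tensor constructed with rowwise = true and colwise = true
// carries both quantized copies along with their separate scale-inverse
// tensors, matching what the fused kernel produces in one pass.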
template <bool IS_DGATED, typename IType, typename OType>
void performTest_x2(const size_t rows,
const size_t cols,
const size_t block_size_rows,
const size_t block_size_cols,
InputsFillCase fill_case) {
using namespace test;
using EncodingType = fp32;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
Tensor grad("grad", { rows, cols }, itype);
Tensor input("input", { rows, cols * 2 }, itype);
const size_t output_cols = (IS_DGATED ? 2 : 1) * cols;
const std::array<size_t,4> scale_dims_rowwise = get_scale_tensor_dims(rows, output_cols, 1, 32);
const std::array<size_t,4> scale_dims_colwise = get_scale_tensor_dims(rows, output_cols, 32, 1);
const size_t unpadded_blocks_Y_rowwise = scale_dims_rowwise[0];
const size_t unpadded_blocks_X_rowwise = scale_dims_rowwise[1];
const size_t blocks_Y_rowwise = scale_dims_rowwise[2];
const size_t blocks_X_rowwise = scale_dims_rowwise[3];
const size_t scales_stride_rowwise = blocks_X_rowwise;
const size_t unpadded_blocks_Y_colwise = scale_dims_colwise[0];
const size_t unpadded_blocks_X_colwise = scale_dims_colwise[1];
const size_t blocks_Y_colwise = scale_dims_colwise[2];
const size_t blocks_X_colwise = scale_dims_colwise[3];
const size_t scales_stride_colwise = blocks_X_colwise;
Tensor output("output", std::vector<size_t>{ rows, output_cols }, otype,
true, true, NVTE_MXFP8_1D_SCALING);
std::unique_ptr<OType[]> ref_output_rowwise = std::make_unique<OType[]>(rows * output_cols);
std::unique_ptr<OType[]> ref_output_colwise = std::make_unique<OType[]>(rows * output_cols);
std::unique_ptr<fp8e8m0[]> ref_scales_rowwise = std::make_unique<fp8e8m0[]>(blocks_Y_rowwise * blocks_X_rowwise);
std::unique_ptr<fp8e8m0[]> ref_scales_colwise = std::make_unique<fp8e8m0[]>(blocks_Y_colwise * blocks_X_colwise);
for (size_t i = 0; i < blocks_Y_rowwise * blocks_X_rowwise; ++i) {
ref_scales_rowwise[i] = 0;
}
for (size_t i = 0; i < blocks_Y_colwise * blocks_X_colwise; ++i) {
ref_scales_colwise[i] = 0;
}
// fillCase<EncodingType>(&grad, fill_case);
if constexpr (IS_DGATED) {
fillUniform(&grad);
}
fillUniform(&input);
if constexpr (IS_DGATED) {
nvte_dswiglu(grad.data(), input.data(), output.data(), 0);
} else {
nvte_swiglu(input.data(), output.data(), 0);
}
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
float ref_amax = 0;
compute_ref_x2<IS_DGATED, IType, OType>(grad.rowwise_cpu_dptr<IType>(),
input.rowwise_cpu_dptr<IType>(),
ref_output_rowwise.get(),
ref_output_colwise.get(),
ref_scales_rowwise.get(),
ref_scales_colwise.get(),
ref_amax,
rows,
cols,
block_size_rows,
block_size_cols,
scales_stride_rowwise,
scales_stride_colwise);
  auto [atol, rtol] = getTolerances(otype);
compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol);
compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol);
compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
unpadded_blocks_X_rowwise, scales_stride_rowwise);
compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
unpadded_blocks_X_colwise, scales_stride_colwise);
}
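// Matrix sizes mix exact multiples of the 32-element block (128x128, 256x256,
// 65536x128) with ragged shapes that exercise partial blocks and scale padding
// (1x32, 65x96, 993x512) and widths that split the 32x64 tiles unevenly
// (16384x1632).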
std::vector<std::pair<size_t, size_t>> matrix_sizes = {
{1, 32},
{16, 64},
{65, 96},
{128, 128},
{256, 256},
{993, 512},
{768, 1024},
{65536, 128},
{16384, 1632},
};
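// {1, 32} exercises the row-wise kernel and {32, 1} the column-wise kernel
// (both through performTest_x1); {32, 32} requests both directions at once and
// is routed to performTest_x2 in the test body below.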
std::vector<std::pair<size_t, size_t>> block_sizes = {
{1, 32},
{32, 1},
{32, 32},
};
std::vector<InputsFillCase> input_scenarios = {
InputsFillCase::uniform,
// InputsFillCase::zeros,
// InputsFillCase::zero_to_minNorm,
// InputsFillCase::minNorm_to_maxNorm,
// InputsFillCase::maxNorm_to_inf
};
std::vector<bool> is_dgated_op = {
true,
false
};
} // namespace
class CastMXFP8_GatedActTestSuite : public ::testing::TestWithParam
<std::tuple<std::pair<size_t, size_t>,
std::pair<size_t, size_t>,
transformer_engine::DType,
transformer_engine::DType,
InputsFillCase,
bool>> {};
TEST_P(CastMXFP8_GatedActTestSuite, TestCastMXFP8Swiglu) {
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
using namespace transformer_engine;
using namespace test;
const auto matrix_size = std::get<0>(GetParam());
const auto block_size = std::get<1>(GetParam());
const DType input_type = std::get<2>(GetParam());
const DType output_type = std::get<3>(GetParam());
const InputsFillCase fill_case = std::get<4>(GetParam());
const bool IS_DGATED = std::get<5>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OType,
if (block_size.first == 1 || block_size.second == 1) {
if (IS_DGATED) {
performTest_x1<true, IType, OType>(matrix_size.first, matrix_size.second,
block_size.first, block_size.second, fill_case);
} else {
performTest_x1<false, IType, OType>(matrix_size.first, matrix_size.second,
block_size.first, block_size.second, fill_case);
}
} else {
if (IS_DGATED) {
performTest_x2<true, IType, OType>(matrix_size.first, matrix_size.second,
block_size.first, block_size.second, fill_case);
} else {
performTest_x2<false, IType, OType>(matrix_size.first, matrix_size.second,
block_size.first, block_size.second, fill_case);
}
}
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
CastMXFP8_GatedActTestSuite,
::testing::Combine(
::testing::ValuesIn(matrix_sizes),
::testing::ValuesIn(block_sizes),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(input_scenarios),
::testing::ValuesIn(is_dgated_op)),
[](const testing::TestParamInfo<CastMXFP8_GatedActTestSuite::ParamType>& info) {
std::string name = std::to_string(std::get<0>(info.param).first) + "X" +
std::to_string(std::get<0>(info.param).second) + "X" +
std::to_string(std::get<1>(info.param).first) + "X" +
std::to_string(std::get<1>(info.param).second) + "X" +
test::typeName(std::get<2>(info.param)) + "X" +
test::typeName(std::get<3>(info.param)) + "X" +
test::caseName(std::get<4>(info.param)) + "X" +
(std::get<5>(info.param) ? "DGATED" : "GATED");
return name;
});
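// To run only this suite (the binary name is illustrative and build-dependent):
//   ./test_operator --gtest_filter='OperatorTest/CastMXFP8_GatedActTestSuite.*'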
......@@ -14,7 +14,7 @@
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/transpose.h>
#include <transformer_engine/cast.h>
#include "../test_common.h"
using namespace transformer_engine;
......@@ -45,36 +45,34 @@ void performTest(const size_t N, const size_t H) {
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
Tensor input({ N, H }, itype);
Tensor output_c({ N, H }, otype);
Tensor output_t({ H, N }, otype);
Tensor input("input", { N, H }, itype);
Tensor output("output", { N, H }, otype, true, true);
std::unique_ptr<OutputType[]> ref_output_c = std::make_unique<OutputType[]>(N * H);
std::unique_ptr<OutputType[]> ref_output_t = std::make_unique<OutputType[]>(N * H);
fillUniform(&input);
setRandomScale(&output_c);
output_t.shareFP8Meta(output_c);
setRandomScale(&output);
nvte_cast_transpose(input.data(), output_c.data(), output_t.data(), 0);
nvte_quantize(input.data(), output.data(), 0);
float ref_amax;
compute_ref<InputType, OutputType>(input.cpu_dptr<InputType>(), ref_output_c.get(),
compute_ref<InputType, OutputType>(input.rowwise_cpu_dptr<InputType>(), ref_output_c.get(),
ref_output_t.get(), N, H, &ref_amax,
output_c.scale());
output.scale());
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output.scale();
compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
compareResults("output_t", output_t, ref_output_t.get(), atol, rtol);
compareResults("output_c", output, ref_output_c.get(), true, atol, rtol);
compareResults("output_t", output, ref_output_t.get(), false, atol, rtol);
}
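// With the TE 2.0 API, the separate cast (output_c) and transpose (output_t)
// tensors above are folded into one Tensor exposing row-wise and column-wise
// views, and a single nvte_quantize call fills both, so the CPU references are
// compared against the two views of the same output.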
std::vector<std::pair<size_t, size_t>> test_cases = {{2048, 12288},
......