Commit 544dd14b authored by Przemek Tredak's avatar Przemek Tredak
Browse files

Update main branch with TE 2.0 code, update version to 2.1.0.dev0


Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>
parent e5369541
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
extension-pkg-whitelist=flash_attn_2_cuda, extension-pkg-whitelist=flash_attn_2_cuda,
torch, torch,
transformer_engine_torch, transformer_engine_torch,
transformer_engine_paddle,
transformer_engine_jax transformer_engine_jax
extension-pkg-allow-list=transformer_engine.transformer_engine_jax extension-pkg-allow-list=transformer_engine.transformer_engine_jax
......
...@@ -8,7 +8,7 @@ pip install "nltk>=3.8.2" ...@@ -8,7 +8,7 @@ pip install "nltk>=3.8.2"
pip install pytest==8.2.1 pip install pytest==8.2.1
: ${TE_PATH:=/opt/transformerengine} : ${TE_PATH:=/opt/transformerengine}
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py
# Test without custom calls # Test without custom calls
NVTE_CUSTOM_CALLS_RE="" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py NVTE_CUSTOM_CALLS_RE="" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

# Lint the C++ sources with cpplint and the Python sources with pylint.
# Set PYTHON_ONLY to skip the C++ checks, or CPP_ONLY to skip the Python ones.
set -e

: "${TE_PATH:=/opt/transformerengine}"

pip install cpplint==1.6.0 pylint==3.3.1

if [ -z "${PYTHON_ONLY}" ]
then
  cd $TE_PATH
  echo "Checking common API headers"
  cpplint --root transformer_engine/common/include --recursive transformer_engine/common/include
  echo "Checking C++ files"
  # Exclude the public headers (checked above with a different --root) and
  # generated build artifacts.
  cpplint --recursive --exclude=transformer_engine/common/include --exclude=transformer_engine/build_tools/build transformer_engine/common
  # NOTE: the paddle framework was removed from the project (see pylintrc and
  # setup.py in this change), so transformer_engine/paddle is no longer linted.
fi

if [ -z "${CPP_ONLY}" ]
then
  cd $TE_PATH
  echo "Checking Python files"
  pylint --recursive=y transformer_engine/common
fi
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

# Run the PaddlePaddle unit tests and the Paddle MNIST example.
# -Wignore: suppress Python warnings; -x from set -xe echoes each command.
# NOTE(review): paddle support is removed elsewhere in this change — confirm
# this script is no longer referenced by CI before relying on it.
set -xe

pip install pytest==8.2.1

: ${TE_PATH:=/opt/transformerengine}
pytest -Wignore -v $TE_PATH/tests/paddle
pytest -Wignore -v $TE_PATH/examples/paddle/mnist
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

# Build and install the transformer-engine wheels (core renamed to -cu12,
# metapackage, and paddle extension), then sanity-check the import.
set -e

: "${TE_PATH:=/opt/transformerengine}"

# Install dependencies
# Note: Need to install wheel locally since PaddlePaddle container
# already contains APT install.
pip install pydantic
pip install --user wheel==0.44.0

cd $TE_PATH
# Remove any previously installed variants so the freshly built wheels win.
pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-paddle

VERSION=`cat $TE_PATH/build_tools/VERSION.txt`
WHL_BASE="transformer_engine-${VERSION}"

# Core wheel.
NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel
# Unpack the wheel, rewrite its METADATA so the distribution is named
# transformer-engine-cu12 (both dash and underscore spellings), rename the
# dist-info directory to match, and repack.
python -m wheel unpack dist/*
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
python -m wheel pack ${WHL_BASE}

# Replace the original wheel with the repacked (renamed) one.
rm dist/*.whl
mv *.whl dist/
# Build the metapackage wheel and install it without pulling dependencies.
NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel
pip install dist/*.whl --no-deps

# Build and install the paddle extension wheel from its own setup.py.
cd transformer_engine/paddle
NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel
pip install dist/*

# Verify the installed packages import cleanly.
python $TE_PATH/tests/paddle/test_sanity_import.py
...@@ -11,11 +11,10 @@ pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py ...@@ -11,11 +11,10 @@ pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py
pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py
pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py
pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
......
...@@ -8,8 +8,8 @@ set -e ...@@ -8,8 +8,8 @@ set -e
pip install pytest==8.2.1 pip install pytest==8.2.1
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
# pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py ### TODO Debug UB support with te.Sequential
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

# Run the PyTorch ONNX-export tests against locally built custom ONNX Runtime
# operators.
set -e

: ${TE_PATH:=/opt/transformerengine}

pip install pytest==8.2.1 onnxruntime==1.19.2

# Build custom ONNX Runtime operators
export CUSTOM_ORT_OPS_PATH=$TE_PATH/tests/pytorch/custom_ort_ops
bash $CUSTOM_ORT_OPS_PATH/build.sh

# Run tests
# NVTE_TORCH_COMPILE=0 — presumably required so export tracing sees eager
# ops; confirm against test_onnx_export.py.
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
...@@ -12,7 +12,14 @@ pip install pytest==8.2.1 ...@@ -12,7 +12,14 @@ pip install pytest==8.2.1
export MAX_JOBS=4 export MAX_JOBS=4
# Iterate over Flash Attention versions # Iterate over Flash Attention versions
FA_versions=(2.1.1 2.3.0 2.4.1 2.5.7 2.6.3 3.0.0b1) sm_arch=`python -c "import torch; sm = torch.cuda.get_device_capability(0); print(sm[0]*10+sm[1])"`
if [ $sm_arch -gt 90 ]
then
FA_versions=(2.7.3)
else
FA_versions=(2.1.1 2.3.0 2.4.1 2.5.7 2.7.3 3.0.0b1)
fi
for fa_version in "${FA_versions[@]}" for fa_version in "${FA_versions[@]}"
do do
...@@ -21,10 +28,10 @@ do ...@@ -21,10 +28,10 @@ do
then then
pip install flash-attn==${fa_version} pip install flash-attn==${fa_version}
else else
pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" pip install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper"
python_path=`python -c "import site; print(site.getsitepackages()[0])"` python_path=`python -c "import site; print(site.getsitepackages()[0])"`
mkdir -p $python_path/flashattn_hopper mkdir -p $python_path/flashattn_hopper
wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py
fi fi
# Run tests # Run tests
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
"""Installation script.""" """Installation script."""
import os import os
import sys
import time import time
from pathlib import Path from pathlib import Path
from typing import List, Tuple from typing import List, Tuple
...@@ -35,14 +36,13 @@ os.environ["NVTE_PROJECT_BUILDING"] = "1" ...@@ -35,14 +36,13 @@ os.environ["NVTE_PROJECT_BUILDING"] = "1"
if "pytorch" in frameworks: if "pytorch" in frameworks:
from torch.utils.cpp_extension import BuildExtension from torch.utils.cpp_extension import BuildExtension
elif "paddle" in frameworks:
from paddle.utils.cpp_extension import BuildExtension
elif "jax" in frameworks: elif "jax" in frameworks:
install_and_import("pybind11[global]") install_and_import("pybind11[global]")
from pybind11.setup_helpers import build_ext as BuildExtension from pybind11.setup_helpers import build_ext as BuildExtension
CMakeBuildExtension = get_build_ext(BuildExtension) CMakeBuildExtension = get_build_ext(BuildExtension)
archs = cuda_archs()
class TimedBdist(bdist_wheel): class TimedBdist(bdist_wheel):
...@@ -57,7 +57,7 @@ class TimedBdist(bdist_wheel): ...@@ -57,7 +57,7 @@ class TimedBdist(bdist_wheel):
def setup_common_extension() -> CMakeExtension: def setup_common_extension() -> CMakeExtension:
"""Setup CMake extension for common library""" """Setup CMake extension for common library"""
cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())] cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)]
if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
assert ( assert (
os.getenv("MPI_HOME") is not None os.getenv("MPI_HOME") is not None
...@@ -104,13 +104,11 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ...@@ -104,13 +104,11 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks: if "pytorch" in frameworks:
install_reqs.extend(["torch"]) install_reqs.extend(["torch"])
test_reqs.extend(["numpy", "onnxruntime", "torchvision", "prettytable"]) test_reqs.extend(["numpy", "torchvision", "prettytable"])
if "jax" in frameworks: if "jax" in frameworks:
install_reqs.extend(["jax", "flax>=0.7.1"]) install_reqs.extend(["jax", "flax>=0.7.1"])
test_reqs.extend(["numpy", "praxis"]) # test_reqs.extend(["numpy", "praxis"])
if "paddle" in frameworks: test_reqs.extend(["numpy"])
install_reqs.append("paddlepaddle-gpu")
test_reqs.append("numpy")
return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]] return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]]
...@@ -135,7 +133,6 @@ if __name__ == "__main__": ...@@ -135,7 +133,6 @@ if __name__ == "__main__":
extras_require = { extras_require = {
"pytorch": [f"transformer_engine_torch=={__version__}"], "pytorch": [f"transformer_engine_torch=={__version__}"],
"jax": [f"transformer_engine_jax=={__version__}"], "jax": [f"transformer_engine_jax=={__version__}"],
"paddle": [f"transformer_engine_paddle=={__version__}"],
} }
else: else:
setup_requires, install_requires, test_requires = setup_requirements() setup_requires, install_requires, test_requires = setup_requirements()
...@@ -169,16 +166,6 @@ if __name__ == "__main__": ...@@ -169,16 +166,6 @@ if __name__ == "__main__":
current_file_path / "transformer_engine", current_file_path / "transformer_engine",
) )
) )
if "paddle" in frameworks:
from build_tools.paddle import setup_paddle_extension
ext_modules.append(
setup_paddle_extension(
"transformer_engine/paddle/csrc",
current_file_path / "transformer_engine" / "paddle" / "csrc",
current_file_path / "transformer_engine",
)
)
# Configure package # Configure package
setuptools.setup( setuptools.setup(
......
...@@ -5,7 +5,11 @@ ...@@ -5,7 +5,11 @@
cmake_minimum_required(VERSION 3.18) cmake_minimum_required(VERSION 3.18)
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
set(CMAKE_CUDA_ARCHITECTURES 70 80 90) if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
else ()
set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90)
endif()
endif() endif()
......
...@@ -3,23 +3,33 @@ ...@@ -3,23 +3,33 @@
# See LICENSE for license information. # See LICENSE for license information.
add_executable(test_operator add_executable(test_operator
test_cast.cu
test_cast_dbias.cu
test_cast_dbias_dgelu.cu
test_cast_gated_swiglu.cu
test_cast_mxfp8_gated_swiglu.cu
test_qdq.cu test_qdq.cu
test_cast_transpose.cu test_cast_mxfp8.cu
test_dequantize_mxfp8.cu
test_transpose.cu test_transpose.cu
test_cast_transpose.cu
test_cast_transpose_dbias.cu test_cast_transpose_dbias.cu
test_cast_transpose_dbias_dgelu.cu test_cast_transpose_dbias_dgelu.cu
test_cast_transpose_dgeglu.cu test_cast_transpose_dgeglu.cu
test_act.cu test_act.cu
test_normalization.cu test_normalization.cu
test_normalization_mxfp8.cu
test_multi_cast_transpose.cu test_multi_cast_transpose.cu
test_multi_padding.cu test_multi_padding.cu
test_causal_softmax.cu test_causal_softmax.cu
test_swizzle.cu
../test_common.cu) ../test_common.cu)
find_package(OpenMP REQUIRED)
list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn) list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn)
target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS}) target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS} OpenMP::OpenMP_CXX)
target_compile_options(test_operator PRIVATE -O2) target_compile_options(test_operator PRIVATE -O2 -fopenmp)
include(GoogleTest) include(GoogleTest)
gtest_discover_tests(test_operator) gtest_discover_tests(test_operator DISCOVERY_TIMEOUT 600)
...@@ -21,58 +21,6 @@ ...@@ -21,58 +21,6 @@
using namespace transformer_engine; using namespace transformer_engine;
namespace {

// Scalar CPU references for the activation kernels under test.
// Forward functions first, then their derivatives (d-prefixed).

// GELU, tanh approximation: 0.5*v*(1 + tanh(sqrt(2/pi)*(v + 0.044715*v^3))).
float gelu(const float v) {
  return 0.5f * v * (1.0f + tanhf(0.79788456F * v * (1.0f + 0.044715f * v * v)));
}

// SiLU (swish): v * sigmoid(v).
float silu(const float v) {
  return v / (1 + expf(-v));
}

float relu(const float v) {
  return v > 0 ? v : 0;
}

// Squared ReLU.
float srelu(const float v) {
  return v > 0 ? v * v : 0;
}

// Quick GELU: sigmoid approximation with slope 1.702.
float qgelu(const float v) {
  return v / (1 + expf(-1.702f * v));
}

// Derivative of the tanh-approximated GELU.
float dgelu(const float v) {
  const float tanh_out = tanhf(0.79788456f * v * (1.f + 0.044715f * v * v));
  return 0.5f * v * ((1.f - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * v * v)) +
         0.5f * (1.f + tanh_out);
}

// Derivative of SiLU.
float dsilu(const float v) {
  const float sigmoid = 1.f / (1 + expf(-v));
  return v * sigmoid * (1.f - sigmoid) + sigmoid;
}

float drelu(const float v) {
  return v > 0.f ? 1.f : 0.f;
}

// Derivative of squared ReLU.
float dsrelu(const float v) {
  return fmaxf(2.f * v, 0.f);
}

// Derivative of quick GELU.
float dqgelu(const float v) {
  const float sigmoid = 1.f / (1 + expf(-1.702f * v));
  return 1.702f * v * sigmoid * (1.f - sigmoid) + sigmoid;
}

}  // namespace
template <float (*act)(const float), typename IT, typename OT, typename CT> template <float (*act)(const float), typename IT, typename OT, typename CT>
void compute_ref_act_cast(const IT *input_h, void compute_ref_act_cast(const IT *input_h,
OT *output_h, OT *output_h,
...@@ -82,6 +30,7 @@ void compute_ref_act_cast(const IT *input_h, ...@@ -82,6 +30,7 @@ void compute_ref_act_cast(const IT *input_h,
const size_t H) { const size_t H) {
CT amax = 0.; CT amax = 0.;
#pragma omp parallel for schedule(static) reduction(max: amax) proc_bind(spread)
for (size_t i = 0; i < N; i++) { for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) { for (size_t j = 0; j < H; j++) {
CT elt = static_cast<CT>(input_h[i * H + j]); CT elt = static_cast<CT>(input_h[i * H + j]);
...@@ -101,6 +50,7 @@ void compute_ref_dact_cast(const IT *input_h, ...@@ -101,6 +50,7 @@ void compute_ref_dact_cast(const IT *input_h,
const size_t N, const size_t N,
const size_t H) { const size_t H) {
using CT = float; using CT = float;
#pragma omp parallel for schedule(static) proc_bind(spread)
for (size_t i = 0; i < N; i++) { for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) { for (size_t j = 0; j < H; j++) {
CT elt = static_cast<CT>(input_h[i * H + j]); CT elt = static_cast<CT>(input_h[i * H + j]);
...@@ -118,6 +68,7 @@ void compute_ref_glu_act_cast(const IT *input_h, OT *output_h, const CT scale, C ...@@ -118,6 +68,7 @@ void compute_ref_glu_act_cast(const IT *input_h, OT *output_h, const CT scale, C
const int col = H * 2; const int col = H * 2;
#pragma omp parallel for schedule(static) reduction(max: amax) proc_bind(spread)
for (size_t i = 0; i < N; i++) { for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) { for (size_t j = 0; j < H; j++) {
CT gelu_elt = static_cast<CT>(input_h[i * col + j]); CT gelu_elt = static_cast<CT>(input_h[i * col + j]);
...@@ -139,6 +90,7 @@ void compute_ref_dglu_act_cast(const IT *input_h, const IT *grad_h, OT *output_h ...@@ -139,6 +90,7 @@ void compute_ref_dglu_act_cast(const IT *input_h, const IT *grad_h, OT *output_h
const int col = H * 2; const int col = H * 2;
using CT = float; using CT = float;
#pragma omp parallel for schedule(static) proc_bind(spread)
for (size_t i = 0; i < N; i++) { for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) { for (size_t j = 0; j < H; j++) {
CT grad = static_cast<CT>(grad_h[i * H + j]); CT grad = static_cast<CT>(grad_h[i * H + j]);
...@@ -164,10 +116,10 @@ void performTest(const size_t N, const size_t H) { ...@@ -164,10 +116,10 @@ void performTest(const size_t N, const size_t H) {
DType itype = TypeInfo<IType>::dtype; DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype; DType otype = TypeInfo<OType>::dtype;
Tensor input({ N, H }, itype); Tensor input("input", { N, H }, itype);
Tensor output({ N, H }, otype); Tensor output("output", { N, H }, otype);
Tensor igrad({ N, H }, itype); Tensor igrad("igrad", { N, H }, itype);
Tensor ograd({ N, H }, itype); Tensor ograd("ograd", { N, H }, itype);
fillUniform(&input); fillUniform(&input);
fillUniform(&ograd); fillUniform(&ograd);
...@@ -179,7 +131,7 @@ void performTest(const size_t N, const size_t H) { ...@@ -179,7 +131,7 @@ void performTest(const size_t N, const size_t H) {
nvte_act(input.data(), output.data(), 0); nvte_act(input.data(), output.data(), 0);
float ref_amax; float ref_amax;
compute_ref_act_cast<ref_act>(input.cpu_dptr<IType>(), ref_output.get(), compute_ref_act_cast<ref_act>(input.rowwise_cpu_dptr<IType>(), ref_output.get(),
output.scale(), &ref_amax, N, H); output.scale(), &ref_amax, N, H);
cudaDeviceSynchronize(); cudaDeviceSynchronize();
...@@ -195,7 +147,7 @@ void performTest(const size_t N, const size_t H) { ...@@ -195,7 +147,7 @@ void performTest(const size_t N, const size_t H) {
nvte_dact(ograd.data(), input.data(), igrad.data(), 0); nvte_dact(ograd.data(), input.data(), igrad.data(), 0);
compute_ref_dact_cast<ref_dact>(input.cpu_dptr<IType>(), ograd.cpu_dptr<IType>(), compute_ref_dact_cast<ref_dact>(input.rowwise_cpu_dptr<IType>(), ograd.rowwise_cpu_dptr<IType>(),
ref_igrad.get(), N, H); ref_igrad.get(), N, H);
cudaDeviceSynchronize(); cudaDeviceSynchronize();
...@@ -219,10 +171,10 @@ void performTestGLU(const size_t N, const size_t H) { ...@@ -219,10 +171,10 @@ void performTestGLU(const size_t N, const size_t H) {
DType itype = TypeInfo<IType>::dtype; DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype; DType otype = TypeInfo<OType>::dtype;
Tensor input({N, H * 2}, itype); Tensor input("input", {N, H * 2}, itype);
Tensor output({N, H}, otype); Tensor output("output", {N, H}, otype);
Tensor igrad({ N, H * 2 }, itype); Tensor igrad("igrad", { N, H * 2 }, itype);
Tensor ograd({ N, H }, itype); Tensor ograd("ograd", { N, H }, itype);
fillUniform(&input); fillUniform(&input);
fillUniform(&ograd); fillUniform(&ograd);
...@@ -234,7 +186,7 @@ void performTestGLU(const size_t N, const size_t H) { ...@@ -234,7 +186,7 @@ void performTestGLU(const size_t N, const size_t H) {
nvte_act(input.data(), output.data(), 0); nvte_act(input.data(), output.data(), 0);
float ref_amax; float ref_amax;
compute_ref_glu_act_cast<ref_act>(input.cpu_dptr<IType>(), ref_output.get(), compute_ref_glu_act_cast<ref_act>(input.rowwise_cpu_dptr<IType>(), ref_output.get(),
output.scale(), &ref_amax, N, H); output.scale(), &ref_amax, N, H);
cudaDeviceSynchronize(); cudaDeviceSynchronize();
...@@ -242,15 +194,19 @@ void performTestGLU(const size_t N, const size_t H) { ...@@ -242,15 +194,19 @@ void performTestGLU(const size_t N, const size_t H) {
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) { if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); auto [atol, rtol] = getTolerances(DType::kFloat32);
compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax); compareResults("amax", output.amax(), ref_amax, atol, rtol);
if (output.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
const float ref_scale = 1.f / output.scale();
compareResults("scale_inv", *output.rowwise_cpu_scale_inv_ptr<float>(), ref_scale, atol, rtol);
}
} }
auto [atol, rtol] = getTolerances(otype); auto [atol, rtol] = getTolerances(otype);
compareResults("output_gelu", output, ref_output.get(), atol, rtol); compareResults("output_gelu", output, ref_output.get(), atol, rtol);
nvte_dact(ograd.data(), input.data(), igrad.data(), 0); nvte_dact(ograd.data(), input.data(), igrad.data(), 0);
compute_ref_dglu_act_cast<ref_dact, ref_act>(input.cpu_dptr<IType>(), ograd.cpu_dptr<IType>(), compute_ref_dglu_act_cast<ref_dact, ref_act>(input.rowwise_cpu_dptr<IType>(), ograd.rowwise_cpu_dptr<IType>(),
ref_igrad.get(), N, H); ref_igrad.get(), N, H);
cudaDeviceSynchronize(); cudaDeviceSynchronize();
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
// Reference CPU quantization: multiplies each element by `scale`, casts to
// OutputType, and records the maximum absolute input value in *amax.
//
// amax starts at 0 rather than the original -1e100: that double literal is
// out of float range (converts to -inf), so an empty input produced
// amax = -inf. Since amax is a max of absolute values, 0 is the correct
// identity, and it matches the `CT amax = 0.;` convention used by the other
// compute_ref_* helpers in these tests.
template <typename InputType, typename OutputType>
void compute_ref(const InputType *data, OutputType *output_c,
                 const size_t size,
                 float *amax, float scale) {
  using compute_t = float;
  compute_t current_max = 0.f;
  for (size_t i = 0; i < size; ++i) {
    compute_t current = static_cast<compute_t>(data[i]);
    current_max = fmaxf(current_max, fabsf(current));
    output_c[i] = OutputType(scale * current);
  }
  *amax = current_max;
}
// End-to-end check of nvte_quantize for one (input type, output type, shape)
// combination: quantizes on the GPU, recomputes the expected result on the
// CPU with compute_ref, and compares the outputs (plus amax and scale_inv
// when the output type is FP8).
template <typename InputType, typename OutputType>
void performTest(const std::vector<size_t>& shape) {
  using namespace test;

  const size_t full_size = product(shape);

  DType itype = TypeInfo<InputType>::dtype;
  DType otype = TypeInfo<OutputType>::dtype;

  Tensor input("input", shape, itype);
  Tensor output_c("output_c", shape, otype);

  std::unique_ptr<OutputType[]> ref_output_c = std::make_unique<OutputType[]>(full_size);

  fillUniform(&input);
  setRandomScale(&output_c);

  // GPU quantization on the default stream.
  nvte_quantize(input.data(), output_c.data(), 0);

  // CPU reference using the same (random) scale as the GPU tensor.
  float ref_amax;
  compute_ref<InputType, OutputType>(input.rowwise_cpu_dptr<InputType>(), ref_output_c.get(),
                                     full_size, &ref_amax, output_c.scale());

  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  if (isFp8Type(otype)) {
    // amax/scale_inv are only meaningful for FP8 outputs; compare them with
    // FP32 tolerances.
    auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
    compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
    float ref_scale_inv = 1.f / output_c.scale();
    compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
  }
  auto [atol, rtol] = getTolerances(otype);
  compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol);
}
// Shapes to test: 1-D and N-D, power-of-two and odd sizes, single-row and
// single-element edges, and large tensors to exercise multi-block kernels.
std::vector<std::vector<size_t>> test_cases = {
  {16},
  {16000},
  {128, 128},
  {256, 256},
  {768, 1024},
  {256, 65536},
  {2048, 12288},
  {65536, 128},
  {65536, 160},
  {16384, 1616},
  {1, 128},
  {1, 1296},
  {1, 16},
  {5, 160},
  {5, 4, 3, 160},
  {217, 256},
};
} // namespace
// Parameterized over (input dtype, output dtype, shape).
class CastTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
                                                                 transformer_engine::DType,
                                                                 std::vector<size_t>>> {};

// Dispatches the runtime dtype pair to the compile-time templated
// performTest via the TYPE_SWITCH macros.
TEST_P(CastTestSuite, TestCast) {
  using namespace transformer_engine;
  using namespace test;

  const DType input_type = std::get<0>(GetParam());
  const DType output_type = std::get<1>(GetParam());
  const auto size = std::get<2>(GetParam());

  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
      performTest<InputType, OutputType>(size);
    );
  );
}

// Instantiates the full cross-product of input dtypes, FP8 output dtypes, and
// shapes; the lambda builds a readable test name of the form
// <itype>X<otype>X<dim0>X<dim1>...
INSTANTIATE_TEST_SUITE_P(
  OperatorTest,
  CastTestSuite,
  ::testing::Combine(
      ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
      ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
      ::testing::ValuesIn(test_cases)),
  [](const testing::TestParamInfo<CastTestSuite::ParamType>& info) {
    std::string name = test::typeName(std::get<0>(info.param)) + "X" +
                       test::typeName(std::get<1>(info.param));
    const auto& shape = std::get<2>(info.param);
    for ( const auto& s: shape) {
      name += "X" + std::to_string(s);
    }
    return name;
  });
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstring>
#include <memory>
#include <iomanip>
#include <iostream>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
// CPU reference for the fused cast + dbias kernel.
//
// For every element: output_c_h = OT(scale * input), while tracking the
// running max-|input| (written to *amax_h) and accumulating per-column sums
// of the inputs in CT precision (written to dbias_h).
template <typename IT, typename OT, typename CT>
void compute_ref_cast_dbias(const IT *input_h,
                            const CT scale,
                            OT *output_c_h,
                            CT *amax_h,
                            IT *dbias_h,
                            const size_t N,
                            const size_t H) {
  CT running_amax = 0.;
  std::vector<CT> col_sums(H, 0.);

  for (size_t row = 0; row < N; row++) {
    for (size_t col = 0; col < H; col++) {
      const size_t idx = row * H + col;
      const CT val = static_cast<CT>(input_h[idx]);

      // Track the largest absolute input value.
      const CT mag = std::abs(val);
      if (mag > running_amax) {
        running_amax = mag;
      }

      output_c_h[idx] = static_cast<OT>(scale * val);
      col_sums[col] += val;
    }
  }

  *amax_h = running_amax;
  for (size_t col = 0; col < H; col++) {
    dbias_h[col] = static_cast<IT>(col_sums[col]);
  }
}
// End-to-end check of nvte_quantize_dbias for one (itype, otype, shape):
// runs the fused GPU kernel, recomputes the cast output, amax, and dbias on
// the CPU, and compares.
template <typename IType, typename OType>
void performTest(const std::vector<size_t>& shape) {
  using namespace test;
  using CType = fp32;

  DType itype = TypeInfo<IType>::dtype;
  DType otype = TypeInfo<OType>::dtype;

  // dbias reduces over the leading dimensions; only the last dim remains.
  const size_t N = first_dimension(shape);
  const size_t H = last_dimension(shape);

  Tensor input("input", shape, itype);
  Tensor output_c("output_c", shape, otype);
  // dbias has the same data type with "output grad"
  Tensor dbias("dbias", {H}, itype);

  fillUniform(&input);
  setRandomScale(&output_c);

  std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N*H);
  std::unique_ptr<IType[]> ref_output_dbias = std::make_unique<IType[]>(H);

  CType ref_amax;
  compute_ref_cast_dbias(input.rowwise_cpu_dptr<IType>(),
                         output_c.scale(),
                         ref_output_c.get(),
                         &ref_amax,
                         ref_output_dbias.get(),
                         N, H);

  Tensor workspace;
  // First call with an empty workspace tensor: it comes back carrying the
  // shape/dtype the kernel needs (used to allocate below).
  nvte_quantize_dbias(input.data(),
                      output_c.data(),
                      dbias.data(),
                      workspace.data(),
                      0);

  // Allocate the requested workspace, then run the real computation.
  workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());

  nvte_quantize_dbias(input.data(),
                      output_c.data(),
                      dbias.data(),
                      workspace.data(),
                      0);

  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  if (isFp8Type(otype)) {
    // amax/scale_inv only apply to FP8 outputs; compare in FP32 tolerances.
    auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
    compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
    float ref_scale_inv = 1.f / output_c.scale();
    compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
  }
  auto [atol, rtol] = getTolerances(otype);
  compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol);

  // dbias sums many terms, so loosen the relative tolerance.
  auto [atol_dbias, rtol_dbias] = getTolerances(itype);
  rtol_dbias *= 4;
  compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
}
// Shapes to test: 2-D and N-D, power-of-two and odd sizes, single-row edges,
// and large tensors to exercise multi-block kernels.
std::vector<std::vector<size_t>> test_cases = {
  {128, 128},
  {256, 256},
  {768, 1024},
  {256, 65536},
  {2048, 12288},
  {65536, 128},
  {65536, 160},
  {16384, 1616},
  {1, 128},
  {1, 1296},
  {1, 16},
  {5, 160},
  {5, 4, 3, 160},
  {217, 256},
};
} // namespace;
// Parameterized over (input dtype, output dtype, shape).
class CastDBiasTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
                                                                      transformer_engine::DType,
                                                                      std::vector<size_t>>> {};

// Dispatches the runtime dtype pair to the compile-time templated
// performTest via the TYPE_SWITCH macros.
TEST_P(CastDBiasTestSuite, TestCastDBias) {
  using namespace transformer_engine;
  using namespace test;
  // Skip tests for pre-Blackwell architectures
  if (getDeviceComputeCapability() < blackwellComputeCapability) {
    GTEST_SKIP();
  }

  const DType input_type = std::get<0>(GetParam());
  const DType output_type = std::get<1>(GetParam());
  const auto size = std::get<2>(GetParam());

  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
      performTest<InputType, OutputType>(size);
    );
  );
}

// Instantiates the full cross-product of input dtypes, FP8 output dtypes, and
// shapes; the lambda builds a readable test name of the form
// <itype>X<otype>X<dim0>X<dim1>...
INSTANTIATE_TEST_SUITE_P(
  OperatorTest,
  CastDBiasTestSuite,
  ::testing::Combine(
      ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
      ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
      ::testing::ValuesIn(test_cases)),
  [](const testing::TestParamInfo<CastDBiasTestSuite::ParamType>& info) {
    std::string name = test::typeName(std::get<0>(info.param)) + "X" +
                       test::typeName(std::get<1>(info.param));
    const auto& shape = std::get<2>(info.param);
    for ( const auto& s: shape) {
      name += "X" + std::to_string(s);
    }
    return name;
  });
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstring>
#include <memory>
#include <iomanip>
#include <iostream>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include "../test_common.h"
using namespace transformer_engine;
using namespace test;
namespace {
// CPU reference for the fused dGELU + cast + dbias kernel:
//   output_c[i,j] = OT(scale * grad[i,j] * dgelu(input[i,j]))
//   dbias[j]      = sum_i grad[i,j] * dgelu(input[i,j])
//   *amax_h       = max |grad * dgelu(input)|
// NOTE(review): dgelu() is provided by ../test_common.h — assumed to be the
// derivative of the tanh-approximated GELU; confirm there.
template <typename IT, typename OT, typename CT>
void compute_ref_cast_dbias_dgelu(const IT *input,
                                  const IT *grad,
                                  const CT scale,
                                  OT *output_c,
                                  CT *amax_h,
                                  IT *dbias,
                                  const size_t N,
                                  const size_t H) {
  CT amax = 0.;
  // Per-column accumulators in compute precision.
  std::vector<CT> acc_dbias(H, 0.);

  for (size_t i = 0; i < N; i++) {
    for (size_t j = 0; j < H; j++) {
      CT in_elt = static_cast<CT>(input[i * H + j]);
      const CT in_grad = static_cast<CT>(grad[i * H + j]);

      const CT elt = in_grad * static_cast<float>(dgelu(static_cast<float>(in_elt)));
      const CT elt_abs = std::abs(elt);

      // update amax
      if (elt_abs > amax) {
        amax = elt_abs;
      }

      output_c[i * H + j] = static_cast<OT>(scale * elt);

      // dbias
      acc_dbias[j] += elt;
    }
  }

  *amax_h = amax;

  for (size_t i = 0; i < H; i++) {
    dbias[i] = static_cast<IT>(acc_dbias[i]);
  }
}
// End-to-end check of nvte_quantize_dbias_dgelu for one (itype, otype,
// shape): runs the fused GPU kernel, recomputes the cast output, amax, and
// dbias on the CPU, and compares.
template <typename IType, typename OType>
void performTest(const std::vector<size_t>& shape) {
  using namespace test;
  using CType = fp32;

  DType itype = TypeInfo<IType>::dtype;
  DType otype = TypeInfo<OType>::dtype;

  // dbias reduces over the leading dimensions; only the last dim remains.
  const size_t N = first_dimension(shape);
  const size_t H = last_dimension(shape);

  Tensor input("input", shape, itype);
  Tensor grad("grad", shape, itype);
  Tensor output_c("output_c", shape, otype);
  // dbias has the same data type with "output grad"
  Tensor dbias("dbias", {H}, itype);

  fillUniform(&input);
  fillUniform(&grad);
  setRandomScale(&output_c);

  std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N*H);
  std::unique_ptr<IType[]> ref_output_dbias = std::make_unique<IType[]>(H);

  CType ref_amax;
  compute_ref_cast_dbias_dgelu(input.rowwise_cpu_dptr<IType>(),
                               grad.rowwise_cpu_dptr<IType>(),
                               output_c.scale(),
                               ref_output_c.get(),
                               &ref_amax,
                               ref_output_dbias.get(),
                               N, H);

  Tensor workspace;
  // First call with an empty workspace tensor: it comes back carrying the
  // shape/dtype the kernel needs (used to allocate below).
  nvte_quantize_dbias_dgelu(grad.data(),
                            input.data(),
                            output_c.data(),
                            dbias.data(),
                            workspace.data(),
                            0);

  // Allocate the requested workspace, then run the real computation.
  workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());

  nvte_quantize_dbias_dgelu(grad.data(),
                            input.data(),
                            output_c.data(),
                            dbias.data(),
                            workspace.data(),
                            0);

  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  if (isFp8Type(otype)) {
    // amax/scale_inv only apply to FP8 outputs; compare in FP32 tolerances.
    auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
    compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
    float ref_scale_inv = 1.f / output_c.scale();
    compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
  }
  auto [atol, rtol] = getTolerances(otype);
  compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol);

  // dbias sums many terms, so loosen the relative tolerance.
  auto [atol_dbias, rtol_dbias] = getTolerances(itype);
  rtol_dbias *= 4;
  compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
}
// Tensor shapes under test: square, rectangular, very wide, very tall,
// single-row, small, non-power-of-two, and >2D cases.
std::vector<std::vector<size_t>> test_cases = {
  {128, 128},
  {256, 256},
  {768, 1024},
  {256, 65536},
  {2048, 12288},
  {65536, 128},
  {65536, 160},
  {16384, 1616},
  {1, 128},
  {1, 1296},
  {1, 16},
  {5, 160},
  {5, 4, 3, 160},
  {217, 256},
};
} // namespace;
// gtest fixture parameterized over (input dtype, output dtype, tensor shape).
class CastDBiasDGeluTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
                                                                           transformer_engine::DType,
                                                                           std::vector<size_t>>> {};
// Compares the fused cast + dbias + dgelu CUDA kernel against the CPU
// reference for one (input dtype, output dtype, shape) parameter combination.
TEST_P(CastDBiasDGeluTestSuite, TestCastDBiasDgelu) {
  using namespace transformer_engine;
  using namespace test;
  // Skip tests for pre-Blackwell architectures
  if (getDeviceComputeCapability() < blackwellComputeCapability) {
    GTEST_SKIP();
  }

  const DType input_type = std::get<0>(GetParam());
  const DType output_type = std::get<1>(GetParam());
  const auto size = std::get<2>(GetParam());

  // Dispatch from runtime dtypes to the concrete template instantiation.
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
      performTest<InputType, OutputType>(size);
    );
  );
}
// Instantiate over all input dtypes x FP8 output dtypes x shapes; the name
// generator encodes the parameter tuple, e.g. "fp32Xfp8e4m3X128X128".
INSTANTIATE_TEST_SUITE_P(
    OperatorTest,
    CastDBiasDGeluTestSuite,
    ::testing::Combine(
        ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
        ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
        ::testing::ValuesIn(test_cases)),
    [](const testing::TestParamInfo<CastDBiasDGeluTestSuite::ParamType>& info) {
      // Join both dtype names and every shape dimension with "X".
      std::string test_name = test::typeName(std::get<0>(info.param));
      test_name += "X" + test::typeName(std::get<1>(info.param));
      for (const auto& dim : std::get<2>(info.param)) {
        test_name += "X" + std::to_string(dim);
      }
      return test_name;
    });
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <omp.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/transpose.h>
#include "../test_common.h"
using namespace transformer_engine;
using namespace test;
namespace {
// CPU reference for the fused cast + gated-SwiGLU backward operation.
//
// Each input/output row is laid out as [silu_half | gate_half], i.e. twice as
// wide as a gradient row. For every element pair the dSiLU path and the gate
// path are computed, scaled, cast to the output type, and the running absolute
// maximum across both halves is recorded.
//
// Parameters:
//   grad     - incoming gradient, rows x cols
//   input    - activation input, rows x (2*cols)
//   scale    - quantization scale applied before the output cast
//   output   - quantized result, rows x (2*cols)
//   amax_ptr - receives the absolute maximum over all output elements
template <typename IType, typename OType>
void compute_ref_cast_dgated_swiglu(const IType * const grad,
                                    const IType * const input,
                                    const float scale,
                                    OType * const output,
                                    float * const amax_ptr,
                                    const size_t rows,
                                    const size_t cols) {
  float amax = 0;
  const size_t stride = cols * 2;

  #pragma omp parallel for reduction(max: amax) proc_bind(spread)
  for (size_t i = 0; i < rows; i++) {
    for (size_t j = 0; j < cols; j++) {
      float grad_elt = static_cast<float>(grad[i * cols + j]);
      float silu_elt = static_cast<float>(input[i * stride + j]);
      float gate_elt = static_cast<float>(input[i * stride + cols + j]);

      float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt;
      float after_dgate = grad_elt * silu(silu_elt);

      // Use std::abs (not bare abs): with only <cmath> included, an
      // unqualified abs may bind to the integer ::abs and silently truncate
      // the amax. This also matches the sibling reference implementations.
      const float abs_dsilu = std::abs(after_dsilu);
      const float abs_dgate = std::abs(after_dgate);
      if (abs_dsilu > amax) { amax = abs_dsilu; }
      if (abs_dgate > amax) { amax = abs_dgate; }

      output[i * stride + j] = static_cast<OType>(scale * after_dsilu);
      output[i * stride + cols + j] = static_cast<OType>(scale * after_dgate);
    }
  }
  *amax_ptr = amax;
}
template <typename IType, typename OType>
void performTest(const std::vector<size_t>& shape) {
using namespace test;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
std::vector<size_t> input_shape = shape;
input_shape[input_shape.size() - 1] *= 2;
const size_t input_size = product(input_shape);
const size_t rows = first_dimension(shape);
const size_t cols = last_dimension(shape);
Tensor grad("grad", shape, itype);
Tensor input("input", input_shape, itype);
Tensor output_c("output_c", input_shape, otype);
fillUniform(&grad);
fillUniform(&input);
setRandomScale(&output_c);
std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(input_size);
nvte_dswiglu(grad.data(), input.data(), output_c.data(), 0);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
float ref_amax;
compute_ref_cast_dgated_swiglu(grad.rowwise_cpu_dptr<IType>(),
input.rowwise_cpu_dptr<IType>(),
output_c.scale(),
ref_output_c.get(),
&ref_amax,
rows,
cols);
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol);
}
// Gradient shapes under test (the activation input is twice as wide in the
// last dimension); includes 1D and >2D cases.
std::vector<std::vector<size_t>> test_cases = {
  {128, 128},
  {256, 256},
  {768, 1024},
  {256, 65536},
  {2048, 12288},
  {65536, 128},
  {217, 256},
  {1296},
  {5, 4, 3, 160},
};
} // namespace
// gtest fixture parameterized over (input dtype, output dtype, tensor shape).
class CastSwiGLUTestSuite
    : public ::testing::TestWithParam<std::tuple<
          transformer_engine::DType, transformer_engine::DType, std::vector<size_t>>> {};
// Compares the fused cast + gated-SwiGLU backward CUDA kernel against the
// CPU reference for one (input dtype, output dtype, shape) combination.
TEST_P(CastSwiGLUTestSuite, TestCastSwiGLU) {
  using namespace transformer_engine;
  using namespace test;
  // Skip tests for pre-Blackwell architectures
  if (getDeviceComputeCapability() < blackwellComputeCapability) {
    GTEST_SKIP();
  }
  const DType input_type = std::get<0>(GetParam());
  const DType output_type = std::get<1>(GetParam());
  const auto size = std::get<2>(GetParam());
  // NOTE(review): shapes whose last dimension is not a multiple of 32 are
  // skipped — presumably a kernel alignment/vectorization requirement;
  // confirm against the nvte_dswiglu documentation.
  if (size.back() % 32 != 0) {
    GTEST_SKIP();
  }
  // Dispatch from runtime dtypes to the concrete template instantiation.
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
      input_type, InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
          output_type, OutputType, performTest<InputType, OutputType>(size);););
}
// Instantiate over all input dtypes x FP8 output dtypes x shapes; the name
// generator encodes the parameter tuple, e.g. "fp32Xfp8e4m3X128X128".
INSTANTIATE_TEST_SUITE_P(
    OperatorTest, CastSwiGLUTestSuite,
    ::testing::Combine(
        ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
        ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
        ::testing::ValuesIn(test_cases)),
    [](const testing::TestParamInfo<CastSwiGLUTestSuite::ParamType> &info) {
      // Join both dtype names and every shape dimension with "X".
      std::string name = test::typeName(std::get<0>(info.param)) + "X" +
                         test::typeName(std::get<1>(info.param));
      const auto& shape = std::get<2>(info.param);
      for ( const auto& s: shape) {
        name += "X" + std::to_string(s);
      }
      return name;
    });
This diff is collapsed.
This diff is collapsed.
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <transformer_engine/transpose.h> #include <transformer_engine/cast.h>
#include "../test_common.h" #include "../test_common.h"
using namespace transformer_engine; using namespace transformer_engine;
...@@ -45,36 +45,34 @@ void performTest(const size_t N, const size_t H) { ...@@ -45,36 +45,34 @@ void performTest(const size_t N, const size_t H) {
DType itype = TypeInfo<InputType>::dtype; DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype; DType otype = TypeInfo<OutputType>::dtype;
Tensor input({ N, H }, itype); Tensor input("input", { N, H }, itype);
Tensor output_c({ N, H }, otype); Tensor output("output", { N, H }, otype, true, true);
Tensor output_t({ H, N }, otype);
std::unique_ptr<OutputType[]> ref_output_c = std::make_unique<OutputType[]>(N * H); std::unique_ptr<OutputType[]> ref_output_c = std::make_unique<OutputType[]>(N * H);
std::unique_ptr<OutputType[]> ref_output_t = std::make_unique<OutputType[]>(N * H); std::unique_ptr<OutputType[]> ref_output_t = std::make_unique<OutputType[]>(N * H);
fillUniform(&input); fillUniform(&input);
setRandomScale(&output_c); setRandomScale(&output);
output_t.shareFP8Meta(output_c);
nvte_cast_transpose(input.data(), output_c.data(), output_t.data(), 0); nvte_quantize(input.data(), output.data(), 0);
float ref_amax; float ref_amax;
compute_ref<InputType, OutputType>(input.cpu_dptr<InputType>(), ref_output_c.get(), compute_ref<InputType, OutputType>(input.rowwise_cpu_dptr<InputType>(), ref_output_c.get(),
ref_output_t.get(), N, H, &ref_amax, ref_output_t.get(), N, H, &ref_amax,
output_c.scale()); output.scale());
cudaDeviceSynchronize(); cudaDeviceSynchronize();
auto err = cudaGetLastError(); auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (isFp8Type(otype)) { if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale(); float ref_scale_inv = 1.f / output.scale();
compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
} }
auto [atol, rtol] = getTolerances(otype); auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), atol, rtol); compareResults("output_c", output, ref_output_c.get(), true, atol, rtol);
compareResults("output_t", output_t, ref_output_t.get(), atol, rtol); compareResults("output_t", output, ref_output_t.get(), false, atol, rtol);
} }
std::vector<std::pair<size_t, size_t>> test_cases = {{2048, 12288}, std::vector<std::pair<size_t, size_t>> test_cases = {{2048, 12288},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment