Commit 996ea169 authored by Przemek Tredak's avatar Przemek Tredak
Browse files

Initial code drop


Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>
parents
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Run the Transformer Engine test suite; abort on the first failing command.
set -e

# Location of the Transformer Engine checkout (overridable by the caller).
: ${TE_PATH:=/opt/transformerengine}

pip install pytest==6.2.5
# Quote the path so a TE_PATH containing spaces or glob characters does not
# word-split into multiple pytest arguments.
pytest -v -s "$TE_PATH/tests/test_transformerengine.py"
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import atexit
import os
import sys
import subprocess
import io
import re
import copy
import tempfile
from setuptools import setup, find_packages, Extension
from setuptools.command.build_ext import build_ext
from distutils.version import LooseVersion
from distutils.file_util import copy_file
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
# Root of the source tree (the directory containing this setup.py).
path = os.path.dirname(os.path.realpath(__file__))

# Read the package version from the VERSION file.  Strip the trailing
# newline so setuptools does not receive "x.y.z\n" as the version string.
with open(path + "/VERSION", "r") as f:
    te_version = f.readline().strip()
def get_cuda_bare_metal_version(cuda_dir):
    """Query nvcc under *cuda_dir* for the installed CUDA toolkit version.

    Returns a tuple ``(raw nvcc output, major, minor)`` where major/minor
    are strings (callers convert with ``int``).
    """
    raw_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
    )
    # nvcc prints e.g. "Cuda compilation tools, release 11.8, V11.8.89".
    # Use a regex instead of indexing split() output: the original
    # `release[1][0]` kept only the first character of the minor version,
    # turning e.g. 11.10 into minor "1".
    match = re.search(r"release (\d+)\.(\d+)", raw_output)
    if match is None:
        raise RuntimeError(
            "Could not parse CUDA version from nvcc output:\n" + raw_output
        )
    bare_metal_major = match.group(1)
    bare_metal_minor = match.group(2)
    return raw_output, bare_metal_major, bare_metal_minor
def append_nvcc_threads(nvcc_extra_args):
    """Append nvcc multi-threaded compilation flags when supported.

    ``--threads`` was introduced in CUDA 11.2.  The original check required
    ``minor >= 2`` for *every* major version, which wrongly excluded CUDA
    12.0/12.1; comparing the full (major, minor) tuple fixes that.
    """
    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
    if (int(bare_metal_major), int(bare_metal_minor)) >= (11, 2):
        return nvcc_extra_args + ["--threads", "4"]
    return nvcc_extra_args
def extra_gencodes(cc_flag):
    """Append -gencode flags for GPU architectures the CUDA toolkit supports.

    sm_80 (Ampere) requires CUDA >= 11.0 and sm_90 (Hopper) requires
    CUDA >= 11.8.  The original inner check ``minor >= 8`` wrongly skipped
    sm_90 on CUDA 12.0; compare full (major, minor) tuples instead.
    """
    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
    version = (int(bare_metal_major), int(bare_metal_minor))
    if version >= (11, 0):
        cc_flag.append("-gencode")
        cc_flag.append("arch=compute_80,code=sm_80")
        if version >= (11, 8):
            cc_flag.append("-gencode")
            cc_flag.append("arch=compute_90,code=sm_90")
def extra_compiler_flags():
    """Return the common nvcc flags shared by every CUDA extension below."""
    # Re-enable half/bfloat16 operators that PyTorch's build disables.
    undef_macros = [
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
    ]
    return (
        ["-O3", "-gencode", "arch=compute_70,code=sm_70"]
        + undef_macros
        + [
            "-I./transformer_engine/common/layer_norm/",
            "--expt-relaxed-constexpr",
            "--expt-extended-lambda",
            "--use_fast_math",
        ]
    )
# Architecture-specific "-gencode" flags, filled in based on the detected
# CUDA toolkit version.
cc_flag = []
extra_gencodes(cc_flag)
def make_abs_path(l):
    """Prefix each relative path in *l* with the repository root ``path``."""
    absolute = []
    for rel in l:
        absolute.append(os.path.join(path, rel))
    return absolute
# Header search paths shared by the extensions below (made absolute so the
# build works regardless of the current working directory).
include_dirs = [
    "transformer_engine/common/include",
    "transformer_engine/pytorch/csrc",
]
include_dirs = make_abs_path(include_dirs)

# CUDA sources of the PyTorch-specific extension module.
pytorch_sources = [
    "transformer_engine/pytorch/csrc/extensions.cu",
    "transformer_engine/pytorch/csrc/common.cu",
]
pytorch_sources = make_abs_path(pytorch_sources)

all_sources = pytorch_sources

# Maps each accepted --framework= value to the sources it requires.
supported_frameworks = {
    "all": all_sources,
    "pytorch": pytorch_sources,
}

# Parse and remove the custom --framework=<name> flag from sys.argv so that
# setuptools does not reject an option it does not understand.
framework = "all"
args = sys.argv.copy()
for s in args:
    if s.startswith("--framework="):
        framework = s.replace("--framework=", "")
        sys.argv.remove(s)
if framework not in supported_frameworks.keys():
    raise ValueError("Unsupported framework " + framework)
class CMakeExtension(Extension):
    """A setuptools Extension that is built by CMake rather than by the
    default compiler machinery (see CMakeBuildExtension)."""

    def __init__(self, name, cmake_path, sources, **kwargs):
        super(CMakeExtension, self).__init__(name, sources=sources, **kwargs)
        # Directory containing the CMakeLists.txt for this extension.
        self.cmake_path = cmake_path
# All extension modules to build.
ext_modules = []

# The core transformer_engine library is configured/compiled via CMake
# (handled by CMakeBuildExtension below); it carries no setuptools sources.
ext_modules.append(
    CMakeExtension(
        name="transformer_engine",
        cmake_path=os.path.join(path, "transformer_engine/common"),
        sources=[],
        include_dirs=include_dirs,
    )
)

# Framework-specific bindings, compiled by torch's CUDAExtension machinery.
if framework in ("all", "pytorch"):
    ext_modules.append(
        CUDAExtension(
            name="transformer_engine_extensions",
            sources=supported_frameworks[framework],
            extra_compile_args={
                "cxx": ["-O3"],
                "nvcc": append_nvcc_threads(extra_compiler_flags() + cc_flag),
            },
            include_dirs=include_dirs,
        )
    )
# The three fused-softmax CUDA extensions are identical except for their
# base name (<name>.cpp + <name>_cuda.cu built into <name>_cuda), so build
# them in a loop instead of triplicating the configuration.
for _softmax_name in (
    "scaled_upper_triang_masked_softmax",
    "scaled_masked_softmax",
    "scaled_softmax",
):
    ext_modules.append(
        CUDAExtension(
            name=_softmax_name + "_cuda",
            sources=[
                os.path.join(
                    path,
                    "transformer_engine/pytorch/csrc/fused_softmax/"
                    + _softmax_name + ".cpp",
                ),
                os.path.join(
                    path,
                    "transformer_engine/pytorch/csrc/fused_softmax/"
                    + _softmax_name + "_cuda.cu",
                ),
            ],
            extra_compile_args={
                "cxx": ["-O3"],
                "nvcc": append_nvcc_threads(extra_compiler_flags() + cc_flag),
            },
            include_dirs=[
                os.path.join(path, "transformer_engine/pytorch/csrc/fused_softmax")
            ],
        )
    )
def get_cmake_bin():
    """Locate a CMake >= 3.18 executable, installing a temporary one if needed.

    Returns the path (or bare name) of a cmake executable.  If the system
    cmake is missing, broken, or too old, pip-installs ``cmake~=3.18`` into
    a temporary directory and returns a small wrapper script that runs it
    with the correct PYTHONPATH.

    Raises RuntimeError if the temporary install fails.
    """
    cmake_bin = "cmake"
    try:
        out = subprocess.check_output([cmake_bin, "--version"])
    except (OSError, subprocess.CalledProcessError):
        # Treat a cmake that exists but fails to run like a missing one.
        # (The original only caught OSError, so a failing cmake crashed
        # this function instead of triggering the fallback install.)
        cmake_installed_version = LooseVersion("0.0")
    else:
        cmake_installed_version = LooseVersion(
            re.search(r"version\s*([\d.]+)", out.decode()).group(1)
        )

    if cmake_installed_version < LooseVersion("3.18.0"):
        print(
            "Could not find a recent CMake to build Transformer Engine. "
            "Attempting to install CMake 3.18 to a temporary location via pip.",
            flush=True,
        )
        cmake_temp_dir = tempfile.TemporaryDirectory(prefix="nvte-cmake-tmp")
        atexit.register(cmake_temp_dir.cleanup)
        try:
            # Use the running interpreter's pip so the install matches the
            # Python environment executing this setup script.
            _ = subprocess.check_output(
                [sys.executable, "-m", "pip", "install",
                 "--target", cmake_temp_dir.name, "cmake~=3.18.0"]
            )
        except Exception:
            raise RuntimeError(
                "Failed to install temporary CMake. "
                "Please update your CMake to 3.18+."
            )
        # Wrapper script so the pip-installed cmake finds its Python package.
        cmake_bin = os.path.join(cmake_temp_dir.name, "bin", "run_cmake")
        with io.open(cmake_bin, "w") as f_run_cmake:
            f_run_cmake.write(
                f"#!/bin/sh\nPYTHONPATH={cmake_temp_dir.name} {os.path.join(cmake_temp_dir.name, 'bin', 'cmake')} \"$@\""
            )
        os.chmod(cmake_bin, 0o755)
    return cmake_bin
class CMakeBuildExtension(build_ext, object):
    """build_ext command that configures and builds CMakeExtension targets."""

    def __init__(self, *args, **kwargs) -> None:
        super(CMakeBuildExtension, self).__init__(*args, **kwargs)

    def build_extensions(self) -> None:
        """Configure and build the first extension with CMake/Ninja."""
        print("Building CMake extensions!")

        cmake_bin = get_cmake_bin()
        config = "Debug" if self.debug else "Release"

        # Output directory for the built library: the extension's full
        # output path with the filename component stripped off.
        ext_name = self.extensions[0].name
        build_dir = self.get_ext_fullpath(ext_name).replace(
            self.get_ext_filename(ext_name), ""
        )
        build_dir = os.path.abspath(build_dir)

        cmake_args = [
            "-GNinja",
            "-DCMAKE_BUILD_TYPE=" + config,
            "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(config.upper(), build_dir),
        ]
        cmake_build_args = ["--config", config]

        cmake_build_dir = os.path.join(self.build_temp, config)
        if not os.path.exists(cmake_build_dir):
            os.makedirs(cmake_build_dir)

        config_and_build_commands = [
            [cmake_bin, self.extensions[0].cmake_path] + cmake_args,
            [cmake_bin, "--build", "."] + cmake_build_args,
        ]

        # Log the exact commands to ease build debugging.  (The original
        # wrapped this in a constant `if True:`, removed here.)
        print(f"Running CMake in {cmake_build_dir}:")
        for command in config_and_build_commands:
            print(" ".join(command))
        sys.stdout.flush()

        # Configure and build the extension.
        try:
            for command in config_and_build_commands:
                subprocess.check_call(command, cwd=cmake_build_dir)
        except (OSError, subprocess.CalledProcessError) as e:
            # Catch CalledProcessError too: the original only caught OSError,
            # so a cmake exiting non-zero bypassed this error message.
            raise RuntimeError("CMake failed: {}".format(str(e)))
class TEBuildExtension(build_ext, object):
    """build_ext that dispatches CMake extensions to CMakeBuildExtension and
    all remaining (CUDA) extensions to torch's BuildExtension."""

    def __init__(self, *args, **kwargs) -> None:
        # Give each delegate builder its own deep copy of the arguments so
        # the two builders cannot mutate shared state.
        cmake_args = copy.deepcopy(args)
        cmake_kwargs = copy.deepcopy(kwargs)
        pytorch_args = copy.deepcopy(args)
        pytorch_kwargs = copy.deepcopy(kwargs)
        self.cmake_build_extensions = CMakeBuildExtension(*cmake_args, **cmake_kwargs)
        self.pytorch_build_extensions = BuildExtension(*pytorch_args, **pytorch_kwargs)
        # Populated by run(); returned from get_outputs().
        self.all_outputs = None
        super(TEBuildExtension, self).__init__(*args, **kwargs)

    def initialize_options(self):
        # Keep both delegate builders in sync with this command's options.
        self.cmake_build_extensions.initialize_options()
        self.pytorch_build_extensions.initialize_options()
        super(TEBuildExtension, self).initialize_options()

    def finalize_options(self):
        self.cmake_build_extensions.finalize_options()
        self.pytorch_build_extensions.finalize_options()
        super(TEBuildExtension, self).finalize_options()

    def run(self) -> None:
        """Build all extensions out-of-place, then copy in-place if requested."""
        # Force an out-of-place build; in-place copying is handled explicitly
        # at the end via copy_extensions_to_source().
        old_inplace, self.inplace = self.inplace, 0
        cmake_ext = [ext for ext in self.extensions if isinstance(ext, CMakeExtension)]
        self.cmake_build_extensions.extensions = cmake_ext
        self.cmake_build_extensions.run()
        other_ext = [
            ext for ext in self.extensions if not isinstance(ext, CMakeExtension)
        ]
        self.pytorch_build_extensions.extensions = other_ext
        print("Building pyTorch extensions!")
        self.pytorch_build_extensions.run()
        # Record every file produced into build_lib for get_outputs().
        self.all_outputs = []
        for f in os.scandir(self.build_lib):
            if f.is_file():
                self.all_outputs.append(f.path)
        self.inplace = old_inplace
        if old_inplace:
            self.copy_extensions_to_source()

    def copy_extensions_to_source(self):
        """Copy every built file from build_lib into the source package dir."""
        ext = self.extensions[0]
        build_py = self.get_finalized_command("build_py")
        fullname = self.get_ext_fullname(ext.name)
        modpath = fullname.split(".")
        package = ".".join(modpath[:-1])
        package_dir = build_py.get_package_dir(package)
        for f in os.scandir(self.build_lib):
            if f.is_file():
                src_filename = f.path
                dest_filename = os.path.join(
                    package_dir, os.path.basename(src_filename)
                )
                # Always copy, even if source is older than destination, to ensure
                # that the right extensions for the current Python/platform are
                # used.
                copy_file(
                    src_filename,
                    dest_filename,
                    verbose=self.verbose,
                    dry_run=self.dry_run,
                )

    def get_outputs(self):
        # Outputs recorded during run(); None if run() has not executed yet.
        return self.all_outputs
# Package metadata and build hooks.  ("tests" appeared twice in the original
# exclude tuple; it is listed once here.)
setup(
    name="transformer_engine",
    version=te_version,
    packages=find_packages(
        exclude=(
            "build",
            "csrc",
            "include",
            "tests",
            "dist",
            "docs",
            "examples",
            "transformer_engine.egg-info",
        )
    ),
    description="Transformer acceleration library",
    ext_modules=ext_modules,
    setup_requires=["pytest-runner"],
    # Route all extension builds through the CMake/PyTorch dispatcher above.
    cmdclass={"build_ext": TEBuildExtension},
    license_files=("LICENSE",),
)
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
cmake_minimum_required(VERSION 3.18)

# Default CUDA architectures (Volta, Ampere, Hopper) unless the caller
# provides an explicit list.
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
  set(CMAKE_CUDA_ARCHITECTURES 70 80 90)
endif()

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)

project(transformer_engine_tests LANGUAGES CUDA CXX)

# Vendored GoogleTest; built into this project's binary dir.
add_subdirectory(../../3rdparty/googletest ${PROJECT_BINARY_DIR}/googletest)

enable_testing()

include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})

# Locate the installed transformer_engine shared library.  If the caller did
# not pass -DTE_LIB_PATH, ask pip where the package is installed.
if(NOT DEFINED TE_LIB_PATH)
  execute_process(COMMAND bash -c "pip show transformer-engine | grep Location | cut -d ' ' -f 2 | tr -d '\n'"
                  OUTPUT_VARIABLE TE_LIB_PATH)
endif()

find_library(TE_LIB NAMES transformer_engine PATHS ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED)
message(STATUS "Found transformer_engine library: ${TE_LIB}")

include_directories(../../transformer_engine/common/include)
include_directories(${CMAKE_SOURCE_DIR})

find_package(CUDAToolkit REQUIRED)

add_subdirectory(operator)
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Single test binary covering all operator-level unit tests.
add_executable(test_operator test_qdq.cu
                             test_cast_transpose.cu
                             test_transpose.cu
                             test_cast_transpose_dbias.cu
                             test_cast_transpose_dbias_dgelu.cu
                             test_gelu.cu
                             test_layernorm.cu
                             ../test_common.cu)

target_link_libraries(test_operator PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB})

# Register each GoogleTest case with CTest individually.
include(GoogleTest)
gtest_discover_tests(test_operator)
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/transpose.h>
#include <transformer_engine/logging.h>
#include <gtest/gtest.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <memory>
#include <iostream>
#include <iomanip>
#include <random>
#include <cstring>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
template <typename InputType, typename OutputType>
void compute_ref(const InputType *data, OutputType *output_c, OutputType *output_t,
                 const size_t N, const size_t H,
                 float *amax, float scale) {
  // CPU reference for cast+transpose: writes the scaled cast of the N x H
  // input into output_c (same layout) and output_t (transposed), and returns
  // the absolute maximum of the *unscaled* inputs through amax.
  using compute_t = float;
  compute_t max_abs = -1e100;
  for (size_t row = 0; row < N; ++row) {
    for (size_t col = 0; col < H; ++col) {
      const compute_t value = static_cast<compute_t>(data[row * H + col]);
      max_abs = fmaxf(max_abs, fabsf(value));
      const OutputType casted = OutputType(scale * value);
      output_c[row * H + col] = casted;
      output_t[col * N + row] = casted;
    }
  }
  *amax = max_abs;
}
template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H) {
  // End-to-end check of nvte_cast_transpose: run the CUDA kernel on an
  // N x H tensor and compare against the CPU reference in compute_ref.
  using namespace test;

  DType itype = TypeInfo<InputType>::dtype;
  DType otype = TypeInfo<OutputType>::dtype;
  Tensor input({ N, H }, itype);
  Tensor output_c({ N, H }, otype);   // casted copy, same layout
  Tensor output_t({ H, N }, otype);   // casted transpose
  Tensor scale({ 1 }, DType::kFloat32);
  Tensor amax({ 1 }, DType::kFloat32);
  Tensor scale_inv({ 1 }, DType::kFloat32);

  std::unique_ptr<OutputType[]> ref_output_c = std::make_unique<OutputType[]>(N * H);
  std::unique_ptr<OutputType[]> ref_output_t = std::make_unique<OutputType[]>(N * H);

  fillUniform(input);
  fillUniform(scale);

  // Launch the kernel on the default stream (last argument 0).
  nvte_cast_transpose(input.data(), scale.data(), output_c.data(), output_t.data(),
                      amax.data(), scale_inv.data(), 0);

  // CPU reference runs concurrently with the device kernel; it reads the
  // host-side copies only.
  float ref_amax;
  compute_ref<InputType, OutputType>(input.cpu_dptr<InputType>(), ref_output_c.get(),
                                     ref_output_t.get(), N, H, &ref_amax,
                                     *(scale.cpu_dptr<float>()));

  // Wait for the kernel and surface any launch/runtime error.
  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
  compareResults("amax", amax, &ref_amax, atol_amax, rtol_amax);
  auto [atol, rtol] = getTolerances(otype);
  compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
  compareResults("output_t", output_t, ref_output_t.get(), atol, rtol);
}
// Tested (rows, cols) shapes, covering large, rectangular and odd sizes.
std::vector<std::pair<size_t, size_t>> test_cases = {{2048, 12288},
                                                     {768, 1024},
                                                     {256, 65536},
                                                     {65536, 128},
                                                     {256, 256},
                                                     {120, 2080},
                                                     {8, 8}};

}  // namespace

// Parameterized over (input type, output type, matrix shape).
class CTTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
                                                               transformer_engine::DType,
                                                               std::pair<size_t, size_t>>> {};

TEST_P(CTTestSuite, TestCastTranspose) {
  using namespace transformer_engine;
  using namespace test;

  const DType input_type = std::get<0>(GetParam());
  const DType output_type = std::get<1>(GetParam());
  const auto size = std::get<2>(GetParam());

  // Dispatch runtime DTypes to the concrete template instantiation.
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
      performTest<InputType, OutputType>(size.first, size.second);
    );
  );
}

// Instantiate the full cross product; the lambda builds readable test names
// such as "Float32XFloat16X768X1024".
INSTANTIATE_TEST_SUITE_P(
    OperatorTest,
    CTTestSuite,
    ::testing::Combine(
        ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
        ::testing::ValuesIn(test::all_fp_types),
        ::testing::ValuesIn(test_cases)),
    [](const testing::TestParamInfo<CTTestSuite::ParamType>& info) {
      std::string name = test::typeName(std::get<0>(info.param)) + "X" +
                         test::typeName(std::get<1>(info.param)) + "X" +
                         std::to_string(std::get<2>(info.param).first) + "X" +
                         std::to_string(std::get<2>(info.param).second);
      return name;
    });
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/transpose.h>
#include <transformer_engine/logging.h>
#include <gtest/gtest.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <memory>
#include <iostream>
#include <iomanip>
#include <random>
#include <cstring>
#include <cmath>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
template <typename IT, typename OT, typename CT>
void compute_ref_cast_transpose_dbias(const IT *input_h,
                                      const CT *scale_h,
                                      OT *output_c_h,
                                      OT *output_t_h,
                                      CT *amax_h,
                                      IT *dbias_h,
                                      const size_t N,
                                      const size_t H) {
  // CPU reference: scale-and-cast the N x H input into output_c_h, write its
  // transpose into output_t_h, track the absolute maximum of the unscaled
  // values (amax_h), and reduce each column into dbias_h.
  CT max_abs = 0.;
  const CT scale_factor = *scale_h;
  std::vector<CT> bias_acc(H, 0.);
  for (size_t row = 0; row < N; row++) {
    for (size_t col = 0; col < H; col++) {
      const CT value = static_cast<CT>(input_h[row * H + col]);
      if (std::abs(value) > max_abs) {
        max_abs = std::abs(value);
      }
      output_c_h[row * H + col] = static_cast<OT>(scale_factor * value);
      output_t_h[col * N + row] = static_cast<OT>(scale_factor * value);
      bias_acc[col] += value;
    }
  }
  *amax_h = max_abs;
  for (size_t col = 0; col < H; col++) {
    dbias_h[col] = static_cast<IT>(bias_acc[col]);
  }
}
template <typename IType, typename OType>
void performTest(const size_t N, const size_t H) {
  // End-to-end check of nvte_cast_transpose_dbias against the CPU reference.
  using namespace test;
  using CType = fp32;  // compute/statistics type

  DType itype = TypeInfo<IType>::dtype;
  DType otype = TypeInfo<OType>::dtype;
  DType ctype = TypeInfo<CType>::dtype;

  Tensor input({N, H}, itype);
  Tensor scale({1}, ctype);
  Tensor output_c({N, H}, otype);
  Tensor output_t({ H, N}, otype);
  Tensor amax({1}, ctype);
  Tensor scale_inv({1}, ctype);
  // dbias has the same data type with "output grad"
  Tensor dbias({H}, itype);

  fillUniform(input);
  fillUniform(scale);

  std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N*H);
  std::unique_ptr<OType[]> ref_output_t = std::make_unique<OType[]>(N*H);
  std::unique_ptr<IType[]> ref_output_dbias = std::make_unique<IType[]>(H);

  CType ref_amax;
  compute_ref_cast_transpose_dbias(input.cpu_dptr<IType>(),
                                   scale.cpu_dptr<CType>(),
                                   ref_output_c.get(),
                                   ref_output_t.get(),
                                   &ref_amax,
                                   ref_output_dbias.get(),
                                   N, H);

  // First call with an empty workspace tensor: appears to only query the
  // required workspace shape/dtype -- TODO confirm against the
  // nvte_cast_transpose_dbias API documentation.
  Tensor workspace;
  nvte_cast_transpose_dbias(input.data(),
                            scale.data(),
                            output_c.data(),
                            output_t.data(),
                            amax.data(),
                            dbias.data(),
                            scale_inv.data(),
                            workspace.data(),
                            0);
  // Allocate the requested workspace, then run the real computation.
  workspace = Tensor(workspace.shape(), workspace.dtype());
  nvte_cast_transpose_dbias(input.data(),
                            scale.data(),
                            output_c.data(),
                            output_t.data(),
                            amax.data(),
                            dbias.data(),
                            scale_inv.data(),
                            workspace.data(),
                            0);

  // Wait for the kernel and surface any launch/runtime error.
  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
  compareResults("amax", amax, &ref_amax, atol_amax, rtol_amax);
  float ref_scale_inv = 1.f / (*scale.cpu_dptr<float>());
  compareResults("scale_inv", scale_inv, &ref_scale_inv, atol_amax, rtol_amax);
  auto [atol, rtol] = getTolerances(otype);
  compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
  compareResults("output_t", output_t, ref_output_t.get(), atol, rtol);
  auto [atol_dbias, rtol_dbias] = getTolerances(itype);
  // The dbias reduction accumulates rounding error; relax its tolerance.
  rtol_dbias *= 4;
  compareResults("output_dbias", dbias, ref_output_dbias.get(), atol_dbias, rtol_dbias);
}
// Tested (rows, cols) shapes, including a non-power-of-two case.
std::vector<std::pair<size_t, size_t>> test_cases = {{64, 400},
                                                     {2048, 12288},
                                                     {768, 1024},
                                                     {256, 65536},
                                                     {65536, 128},
                                                     {256, 256}};

}  // namespace;

// Parameterized over (input type, output type, matrix shape).
class CTDBiasTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
                                                                    transformer_engine::DType,
                                                                    std::pair<size_t, size_t>>> {};

TEST_P(CTDBiasTestSuite, TestCTDBias) {
  using namespace transformer_engine;
  using namespace test;

  const DType input_type = std::get<0>(GetParam());
  const DType output_type = std::get<1>(GetParam());
  const auto size = std::get<2>(GetParam());

  // Dispatch runtime DTypes to the concrete template instantiation.
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
      performTest<InputType, OutputType>(size.first, size.second);
    );
  );
}

// Full cross product with readable generated test names.
INSTANTIATE_TEST_SUITE_P(
    OperatorTest,
    CTDBiasTestSuite,
    ::testing::Combine(
        ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
        ::testing::ValuesIn(test::all_fp_types),
        ::testing::ValuesIn(test_cases)),
    [](const testing::TestParamInfo<CTDBiasTestSuite::ParamType>& info) {
      std::string name = test::typeName(std::get<0>(info.param)) + "X" +
                         test::typeName(std::get<1>(info.param)) + "X" +
                         std::to_string(std::get<2>(info.param).first) + "X" +
                         std::to_string(std::get<2>(info.param).second);
      return name;
    });
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/transpose.h>
#include <transformer_engine/logging.h>
#include <gtest/gtest.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <memory>
#include <iostream>
#include <iomanip>
#include <random>
#include <cstring>
#include <cmath>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
template <typename CType>
CType dgelu(const CType cval) {
  // Derivative of the tanh-approximated GELU at cval.
  const CType x_sq = cval * cval;
  const CType inner = 0.79788456f * cval * (1.f + 0.044715f * x_sq);
  const CType t = tanhf(inner);
  const CType sech_sq = 1.f - t * t;
  return 0.5f * cval * (sech_sq * (0.79788456f + 0.1070322243f * x_sq)) +
         0.5f * (1.f + t);
}
template <typename IT, typename OT, typename CT>
void compute_ref_cast_transpose_dbias_dgelu(const IT *input,
                                            const IT *gelu_input,
                                            const CT *scale_h,
                                            OT *output_c,
                                            OT *output_t,
                                            CT *amax_h,
                                            IT *dbias,
                                            const size_t N,
                                            const size_t H) {
  // CPU reference: multiply the incoming gradient by the GELU derivative
  // (evaluated at gelu_input), then scale-and-cast the result into output_c
  // and its transpose into output_t, tracking the unscaled absolute maximum
  // (amax_h) and accumulating the per-column bias gradient (dbias).
  CT max_abs = 0.;
  const CT scale_factor = *scale_h;
  std::vector<CT> bias_acc(H, 0.);
  for (size_t row = 0; row < N; row++) {
    for (size_t col = 0; col < H; col++) {
      const size_t idx = row * H + col;
      CT grad = static_cast<CT>(input[idx]);
      const CT pre_act = static_cast<CT>(gelu_input[idx]);
      grad = dgelu(pre_act) * grad;
      if (std::abs(grad) > max_abs) {
        max_abs = std::abs(grad);
      }
      output_c[idx] = static_cast<OT>(scale_factor * grad);
      output_t[col * N + row] = static_cast<OT>(scale_factor * grad);
      bias_acc[col] += grad;
    }
  }
  *amax_h = max_abs;
  for (size_t col = 0; col < H; col++) {
    dbias[col] = static_cast<IT>(bias_acc[col]);
  }
}
template <typename IType, typename OType>
void performTest(const size_t N, const size_t H) {
  // End-to-end check of nvte_cast_transpose_dbias_dgelu against the CPU
  // reference above.
  using namespace test;
  using CType = fp32;  // compute/statistics type

  DType itype = TypeInfo<IType>::dtype;
  DType otype = TypeInfo<OType>::dtype;
  DType ctype = TypeInfo<CType>::dtype;

  Tensor input({N, H}, itype);       // incoming gradient
  Tensor gelu_input({N, H}, itype);  // activation input of the forward GELU
  Tensor scale({1}, ctype);
  Tensor output_c({N, H}, otype);
  Tensor output_t({ H, N}, otype);
  Tensor amax({1}, ctype);
  Tensor scale_inv({1}, ctype);
  // dbias has the same data type with "output grad"
  Tensor dbias({H}, itype);

  fillUniform(input);
  fillUniform(gelu_input);
  fillUniform(scale);

  std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N*H);
  std::unique_ptr<OType[]> ref_output_t = std::make_unique<OType[]>(N*H);
  std::unique_ptr<IType[]> ref_output_dbias = std::make_unique<IType[]>(H);

  CType ref_amax;
  compute_ref_cast_transpose_dbias_dgelu(input.cpu_dptr<IType>(),
                                         gelu_input.cpu_dptr<IType>(),
                                         scale.cpu_dptr<CType>(),
                                         ref_output_c.get(),
                                         ref_output_t.get(),
                                         &ref_amax,
                                         ref_output_dbias.get(),
                                         N, H);

  // First call with an empty workspace tensor: appears to only query the
  // required workspace shape/dtype -- TODO confirm against the
  // nvte_cast_transpose_dbias_dgelu API documentation.
  Tensor workspace;
  nvte_cast_transpose_dbias_dgelu(input.data(),
                                  gelu_input.data(),
                                  scale.data(),
                                  output_c.data(),
                                  output_t.data(),
                                  amax.data(),
                                  dbias.data(),
                                  scale_inv.data(),
                                  workspace.data(),
                                  0);
  // Allocate the requested workspace, then run the real computation.
  workspace = Tensor(workspace.shape(), workspace.dtype());
  nvte_cast_transpose_dbias_dgelu(input.data(),
                                  gelu_input.data(),
                                  scale.data(),
                                  output_c.data(),
                                  output_t.data(),
                                  amax.data(),
                                  dbias.data(),
                                  scale_inv.data(),
                                  workspace.data(),
                                  0);

  // Wait for the kernel and surface any launch/runtime error.
  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
  compareResults("amax", amax, &ref_amax, atol_amax, rtol_amax);
  float ref_scale_inv = 1.f / (*scale.cpu_dptr<float>());
  compareResults("scale_inv", scale_inv, &ref_scale_inv, atol_amax, rtol_amax);
  auto [atol, rtol] = getTolerances(otype);
  compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
  compareResults("output_t", output_t, ref_output_t.get(), atol, rtol);
  auto [atol_dbias, rtol_dbias] = getTolerances(itype);
  // The dbias reduction accumulates rounding error; relax its tolerance.
  rtol_dbias *= 4;
  compareResults("output_dbias", dbias, ref_output_dbias.get(), atol_dbias, rtol_dbias);
}
// Tested (rows, cols) shapes, including a non-power-of-two case.
std::vector<std::pair<size_t, size_t>> test_cases = {{64, 400},
                                                     {2048, 12288},
                                                     {768, 1024},
                                                     {256, 65536},
                                                     {65536, 128},
                                                     {256, 256}};

}  // namespace;

// Parameterized over (input type, output type, matrix shape).
class CTDBiasDGeluTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
                                                                         transformer_engine::DType,
                                                                         std::pair<size_t,
                                                                                   size_t>>> {};

TEST_P(CTDBiasDGeluTestSuite, TestCTDBiasDgelu) {
  using namespace transformer_engine;
  using namespace test;

  const DType input_type = std::get<0>(GetParam());
  const DType output_type = std::get<1>(GetParam());
  const auto size = std::get<2>(GetParam());

  // Dispatch runtime DTypes to the concrete template instantiation.
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
      performTest<InputType, OutputType>(size.first, size.second);
    );
  );
}

// Full cross product with readable generated test names.
INSTANTIATE_TEST_SUITE_P(
    OperatorTest,
    CTDBiasDGeluTestSuite,
    ::testing::Combine(
        ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
        ::testing::ValuesIn(test::all_fp_types),
        ::testing::ValuesIn(test_cases)),
    [](const testing::TestParamInfo<CTDBiasDGeluTestSuite::ParamType>& info) {
      std::string name = test::typeName(std::get<0>(info.param)) + "X" +
                         test::typeName(std::get<1>(info.param)) + "X" +
                         std::to_string(std::get<2>(info.param).first) + "X" +
                         std::to_string(std::get<2>(info.param).second);
      return name;
    });
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/activation.h>
#include <transformer_engine/logging.h>
#include <gtest/gtest.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <cmath>
#include <memory>
#include <iostream>
#include <iomanip>
#include <random>
#include <cstring>
#include <type_traits>
#include "../test_common.h"
using namespace transformer_engine;
template <typename IT, typename OT, typename CT>
void compute_ref_gelu_cast(const IT *input_h,
                           OT *output_h,
                           const CT *scale_h,
                           CT *amax_h,
                           const size_t N,
                           const size_t H) {
  // CPU reference for GELU (tanh approximation) with an optional FP8 cast.
  // amax_h receives the absolute maximum of the *unscaled* activations.
  CT max_abs = 0.;
  // Only FP8 outputs are scaled; all other output types pass through.
  CT scale_factor = 1;
  const bool fp8_output = std::is_same<OT, test::fp8e4m3>::value ||
                          std::is_same<OT, test::fp8e5m2>::value;
  if (fp8_output) {
    scale_factor = *scale_h;
  }
  for (size_t row = 0; row < N; row++) {
    for (size_t col = 0; col < H; col++) {
      CT act = CT(input_h[row * H + col]);
      act = 0.5f * act * (1.0f + tanhf(0.79788456F * act *
                                       (1.0f + 0.044715f * act * act)));
      output_h[row * H + col] = OT(scale_factor * act);
      if (std::abs(act) > max_abs) {
        max_abs = std::abs(act);
      }
    }
  }
  *amax_h = max_abs;
}
template <typename IType, typename OType>
void performTestGelu(const size_t N, const size_t H) {
  // End-to-end check of nvte_gelu against the CPU reference above.
  using namespace test;
  using CType = fp32;  // compute/statistics type

  DType itype = TypeInfo<IType>::dtype;
  DType otype = TypeInfo<OType>::dtype;
  DType ctype = TypeInfo<CType>::dtype;

  Tensor input({ N, H }, itype);
  Tensor output({ N, H }, otype);
  Tensor scale({ 1 }, ctype);
  Tensor amax({ 1 }, ctype);
  Tensor scale_inv({ 1 }, ctype);

  fillUniform(input);
  fillUniform(scale);

  std::unique_ptr<OType[]> ref_output = std::make_unique<OType[]>(N*H);

  // Launch on the default stream (0).
  nvte_gelu(input.data(), output.data(), scale.data(),
            amax.data(), scale_inv.data(), 0);

  float ref_amax;
  compute_ref_gelu_cast(input.cpu_dptr<IType>(), ref_output.get(),
                        scale.cpu_dptr<float>(),
                        &ref_amax, N, H);

  // Wait for the kernel and surface any launch/runtime error.
  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  // amax/scale_inv are only meaningful for FP8 outputs (see the reference,
  // which only applies the scale for FP8 types).
  if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) {
    auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
    compareResults("amax", amax, &ref_amax, atol_amax, rtol_amax);
    float ref_scale_inv = 1.f / (*scale.cpu_dptr<float>());
    compareResults("scale_inv", scale_inv, &ref_scale_inv, atol_amax, rtol_amax);
  }
  auto [atol, rtol] = getTolerances(otype);
  compareResults("output_gelu", output, ref_output.get(), atol, rtol);
}
// Parameterized over (input type, output type, matrix shape).
class GELUTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
                                                                 transformer_engine::DType,
                                                                 std::pair<size_t, size_t>>> {};

TEST_P(GELUTestSuite, TestGELU) {
  using namespace transformer_engine;
  using namespace test;

  const DType input_type = std::get<0>(GetParam());
  const DType output_type = std::get<1>(GetParam());
  const auto size = std::get<2>(GetParam());

  // Dispatch runtime DTypes to the concrete template instantiation.
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
      performTestGelu<InputType, OutputType>(size.first, size.second);
    );
  );
}

namespace {

// Tested (rows, cols) shapes; the last two are deliberately odd-sized to
// exercise non-aligned kernel paths.
std::vector<std::pair<size_t, size_t>> gelu_test_cases = {{2048, 12288},
                                                          {768, 1024},
                                                          {256, 65536},
                                                          {65536, 128},
                                                          {256, 256},
                                                          {257, 259},
                                                          {128, 128+1}};

}  // namespace

// Full cross product with readable generated test names.
INSTANTIATE_TEST_SUITE_P(
    OperatorTest,
    GELUTestSuite,
    ::testing::Combine(
        ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
        ::testing::ValuesIn(test::all_fp_types),
        ::testing::ValuesIn(gelu_test_cases)),
    [](const testing::TestParamInfo<GELUTestSuite::ParamType>& info) {
      std::string name = test::typeName(std::get<0>(info.param)) + "X" +
                         test::typeName(std::get<1>(info.param)) + "X" +
                         std::to_string(std::get<2>(info.param).first) + "X" +
                         std::to_string(std::get<2>(info.param).second);
      return name;
    });
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/layer_norm.h>
#include <transformer_engine/transformer_engine.h>
#include <gtest/gtest.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <memory>
#include <iostream>
#include <iomanip>
#include <random>
#include <cstring>
#include <cmath>
#include "../test_common.h"
using namespace transformer_engine;
using namespace test;
namespace {
template <typename InputType>
void compute_ref_stats(const InputType *data, float *mu, float *rsigma,
                       const size_t N, const size_t H, const double epsilon) {
  // CPU reference for per-row LayerNorm statistics: mean (mu) and the
  // reciprocal standard deviation (rsigma) of each length-H row.
  using compute_t = float;
  for (size_t row = 0; row < N; ++row) {
    const InputType *row_data = data + row * H;
    compute_t acc = 0;
    for (size_t col = 0; col < H; ++col) {
      acc += static_cast<compute_t>(row_data[col]);
    }
    mu[row] = acc / H;
    const compute_t mean = mu[row];
    acc = 0;
    for (size_t col = 0; col < H; ++col) {
      const compute_t centered = static_cast<compute_t>(row_data[col]) - mean;
      acc += centered * centered;
    }
    acc = acc / H;
    rsigma[row] = rsqrtf(acc + epsilon);
  }
}
template <typename InputType, typename OutputType>
void compute_ref_output(const InputType *data, const InputType *gamma, const InputType *beta,
                        OutputType *output, const float *mu, const float *rsigma,
                        const size_t N, const size_t H,
                        float *amax, float scale) {
  // CPU reference for the LayerNorm forward output:
  //   out = ((x - mu) * rsigma * gamma + beta) * scale, cast to OutputType.
  // amax receives the absolute maximum of the *unscaled* normalized values.
  using compute_t = float;
  compute_t max_abs = -1e100;
  for (size_t row = 0; row < N; ++row) {
    for (size_t col = 0; col < H; ++col) {
      const compute_t x = static_cast<compute_t>(data[row * H + col]);
      const compute_t normed = (x - mu[row]) * rsigma[row] * static_cast<compute_t>(gamma[col]) +
                               static_cast<compute_t>(beta[col]);
      output[row * H + col] = static_cast<OutputType>(normed * scale);
      max_abs = fmaxf(max_abs, fabsf(normed));
    }
  }
  *amax = max_abs;
}
template <typename InputType, typename OutputType>
void compute_ref_backward(const OutputType *output_grad, const InputType *data,
                          const float *mu, const float *rsigma,
                          const InputType *gamma,
                          InputType *data_grad,
                          InputType *gamma_grad, InputType *beta_grad,
                          const size_t N, const size_t H) {
  // CPU reference for the LayerNorm backward pass: computes the input
  // gradient per element and accumulates the gamma/beta gradients over rows.
  using compute_t = float;
  std::vector<compute_t> dgamma_acc(H, 0.f);
  std::vector<compute_t> dbeta_acc(H, 0.f);
  for (size_t row = 0; row < N; ++row) {
    // First pass over the row: accumulate weight gradients and the two
    // row-level reductions needed by the input gradient.
    compute_t mean_dy = 0, mean_dyy = 0;
    for (size_t col = 0; col < H; ++col) {
      const compute_t x = static_cast<compute_t>(data[row * H + col]);
      const compute_t y = (x - mu[row]) * rsigma[row];
      const compute_t g = static_cast<compute_t>(gamma[col]);
      const compute_t dz = static_cast<compute_t>(output_grad[row * H + col]);
      const compute_t dy = g * dz;
      dgamma_acc[col] += y * dz;
      dbeta_acc[col] += dz;
      mean_dy += dy;
      mean_dyy += dy * y;
    }
    mean_dy /= H;
    mean_dyy /= H;
    // Second pass: per-element input gradient.
    for (size_t col = 0; col < H; ++col) {
      const compute_t x = static_cast<compute_t>(data[row * H + col]);
      const compute_t y = (x - mu[row]) * rsigma[row];
      const compute_t g = static_cast<compute_t>(gamma[col]);
      const compute_t dz = static_cast<compute_t>(output_grad[row * H + col]);
      const compute_t dy = g * dz;
      const compute_t dx = rsigma[row] * (dy - mean_dyy * y - mean_dy);
      data_grad[row * H + col] = static_cast<InputType>(dx);
    }
  }
  // Cast the accumulated weight gradients into the output arrays.
  for (size_t col = 0; col < H; ++col) {
    gamma_grad[col] = static_cast<InputType>(dgamma_acc[col]);
    beta_grad[col] = static_cast<InputType>(dbeta_acc[col]);
  }
}
template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H) {
if (sizeof(InputType) < sizeof(OutputType)) {
GTEST_SKIP() << "LN kernel does not support OutputType > InputType";
return;
}
using WeightType = InputType;
DType itype = TypeInfo<InputType>::dtype;
DType wtype = TypeInfo<WeightType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
if ((itype == DType::kBFloat16 && otype == DType::kFloat16) ||
(itype == DType::kFloat16 && otype == DType::kBFloat16)) {
GTEST_SKIP() << "LN kernel does not support mixing Float16 and BFloat16";
return;
}
Tensor input({ N, H }, itype);
Tensor z({ N, H }, otype);
Tensor gamma({ H }, wtype);
Tensor beta({ H }, wtype);
Tensor scale({ 1 }, DType::kFloat32);
Tensor amax({ 1 }, DType::kFloat32);
Tensor scale_inv({ 1 }, DType::kFloat32);
Tensor mu({ N }, DType::kFloat32);
Tensor rsigma({ N }, DType::kFloat32);
Tensor dz({ N, H }, wtype);
Tensor dx({ N, H }, itype);
Tensor dgamma({ H }, wtype);
Tensor dbeta({ H }, wtype);
Tensor workspace, barrier, dgamma_part, dbeta_part;
fillUniform(input);
fillUniform(gamma);
fillUniform(beta);
fillUniform(scale);
fillUniform(dz);
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N * H);
std::unique_ptr<float[]> ref_mu = std::make_unique<float[]>(N);
std::unique_ptr<float[]> ref_rsigma = std::make_unique<float[]>(N);
std::unique_ptr<InputType[]> ref_dx = std::make_unique<InputType[]>(N * H);
std::unique_ptr<WeightType[]> ref_dgamma = std::make_unique<InputType[]>(H);
std::unique_ptr<WeightType[]> ref_dbeta = std::make_unique<InputType[]>(H);
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
// Forward kernel
float epsilon = 1e-5;
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), scale.data(), epsilon,
z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount,
workspace.data(), barrier.data(), amax.data(), scale_inv.data(), true);
workspace = Tensor(workspace.shape(), workspace.dtype());
barrier = Tensor(barrier.shape(), barrier.dtype());
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), scale.data(), epsilon,
z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount,
workspace.data(), barrier.data(), amax.data(), scale_inv.data(), true);
// Backward kernel
nvte_layernorm_bwd(dz.data(), input.data(),
mu.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), dbeta.data(),
dgamma_part.data(), dbeta_part.data(),
0, prop.multiProcessorCount,
workspace.data(), barrier.data());
workspace = Tensor(workspace.shape(), workspace.dtype());
barrier = Tensor(barrier.shape(), barrier.dtype());
dgamma_part = Tensor(dgamma_part.shape(), dgamma_part.dtype());
dbeta_part = Tensor(dbeta_part.shape(), dbeta_part.dtype());
nvte_layernorm_bwd(dz.data(), input.data(),
mu.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), dbeta.data(),
dgamma_part.data(), dbeta_part.data(),
0, prop.multiProcessorCount,
workspace.data(), barrier.data());
// Reference implementations
// use the GPU stats to tighten the tolerances
mu.to_cpu();
rsigma.to_cpu();
float ref_amax;
compute_ref_stats(input.cpu_dptr<InputType>(), ref_mu.get(),
ref_rsigma.get(), N, H, epsilon);
compute_ref_output(input.cpu_dptr<InputType>(),
gamma.cpu_dptr<WeightType>(),
beta.cpu_dptr<WeightType>(),
ref_output.get(),
mu.cpu_dptr<float>(),
rsigma.cpu_dptr<float>(),
N, H,
&ref_amax,
*(scale.cpu_dptr<float>()));
compute_ref_backward(dz.cpu_dptr<WeightType>(), input.cpu_dptr<InputType>(),
mu.cpu_dptr<float>(), rsigma.cpu_dptr<float>(),
gamma.cpu_dptr<WeightType>(),
ref_dx.get(), ref_dgamma.get(), ref_dbeta.get(),
N, H);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", amax, &ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / (*scale.cpu_dptr<float>());
compareResults("scale_inv", scale_inv, &ref_scale_inv, atol_amax, rtol_amax);
auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
rtol_stats = 5e-5;
compareResults("mu", mu, ref_mu.get(), atol_stats, rtol_stats);
compareResults("rsigma", rsigma, ref_rsigma.get(), atol_stats, rtol_stats);
auto [atol, rtol] = getTolerances(otype);
if (otype == DType::kFloat32) {
atol = 5e-7;
}
compareResults("output", z, ref_output.get(), atol, rtol);
double atol_bwd = 1e-4;
double rtol_bwd = 1e-4;
compareResults("dx", dx, ref_dx.get(), atol_bwd, rtol_bwd);
compareResults("dgamma", dgamma, ref_dgamma.get(), atol_bwd, rtol_bwd);
compareResults("dbeta", dbeta, ref_dbeta.get(), atol_bwd, rtol_bwd);
}
// Test shapes (rows N, hidden size H). The last three use prime dimensions
// to exercise the non-vectorized kernel paths.
std::vector<std::pair<size_t, size_t>> test_cases = {{2048, 12288},
                                                     {768, 1024},
                                                     {256, 65536},
                                                     {128, 6144},
                                                     {64, 2304},
                                                     {229, 541},    // 50th and 100th primes
                                                     {71, 3571},    // 20th and 500th primes
                                                     {29, 17389}};  // 10th and 2000th primes
}  // namespace

// Parameters: (input dtype, output dtype, (N, H)).
class LNTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
                                                               transformer_engine::DType,
                                                               std::pair<size_t, size_t>>> {};

TEST_P(LNTestSuite, TestLN) {
  using namespace transformer_engine;
  using namespace test;
  const DType input_type = std::get<0>(GetParam());
  const DType output_type = std::get<1>(GetParam());
  const auto size = std::get<2>(GetParam());

  // Dispatch the runtime dtypes to concrete template instantiations.
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
      performTest<InputType, OutputType>(size.first, size.second);
    );
  );
}

// Test-name generator produces names of the form <itype>X<otype>X<N>X<H>.
INSTANTIATE_TEST_SUITE_P(
  OperatorTest,
  LNTestSuite,
  ::testing::Combine(
      ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
      ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3),
      ::testing::ValuesIn(test_cases)),
  [](const testing::TestParamInfo<LNTestSuite::ParamType>& info) {
    std::string name = test::typeName(std::get<0>(info.param)) + "X" +
                       test::typeName(std::get<1>(info.param)) + "X" +
                       std::to_string(std::get<2>(info.param).first) + "X" +
                       std::to_string(std::get<2>(info.param).second);
    return name;
  });
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "gtest/gtest.h"
#include <transformer_engine/cast.h>
#include <transformer_engine/logging.h>
#include <gtest/gtest.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <memory>
#include <iostream>
#include <iomanip>
#include <random>
#include <cstring>
#include "../test_common.h"
#include "transformer_engine/transformer_engine.h"
using namespace transformer_engine;
namespace {
template <typename InputType, typename OutputType>
// CPU reference for FP8 quantization: records amax = max |x| and writes
// scale * x when the output type is FP8 (otherwise a plain cast).
void compute_ref_q(const InputType *data, OutputType *output,
                   const size_t N,
                   float *amax, float scale) {
  using compute_t = float;
  // amax reduces fabsf(...) values, which are >= 0, so 0 is the correct
  // identity element. The previous initializer -1e100 does not fit in a
  // float (out-of-range conversion is undefined behavior).
  compute_t current_max = 0;
  for (size_t i = 0; i < N; ++i) {
    compute_t current = static_cast<compute_t>(data[i]);
    current_max = fmaxf(current_max, fabsf(current));
    if (std::is_same<OutputType, test::fp8e4m3>::value ||
        std::is_same<OutputType, test::fp8e5m2>::value) {
      output[i] = OutputType(scale * current);
    } else {
      output[i] = OutputType(current);
    }
  }
  *amax = current_max;
}
template <typename InputType, typename OutputType>
// CPU reference for FP8 dequantization: out[i] = scale_inv * in[i],
// computed in fp32 and cast to the output type.
void compute_ref_dq(const InputType *data, OutputType *output,
                    const size_t N, float scale_inv) {
  for (size_t idx = 0; idx < N; ++idx) {
    const float value = static_cast<float>(data[idx]);
    output[idx] = OutputType(scale_inv * value);
  }
}
template <typename InputType, typename OutputType>
void performTestQ(const size_t N) {
using namespace test;
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
Tensor input({ N }, itype);
Tensor output({ N }, otype);
Tensor scale({ 1 }, DType::kFloat32);
Tensor amax({ 1 }, DType::kFloat32);
Tensor scale_inv({ 1 }, DType::kFloat32);
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N);
fillUniform(input);
fillUniform(scale);
nvte_fp8_quantize(input.data(), scale.data(), output.data(), amax.data(), scale_inv.data(), 0);
float ref_amax;
compute_ref_q<InputType, OutputType>(input.cpu_dptr<InputType>(), ref_output.get(),
N, &ref_amax, *(scale.cpu_dptr<float>()));
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", amax, &ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / (*scale.cpu_dptr<float>());
compareResults("scale_inv", scale_inv, &ref_scale_inv, atol_amax, rtol_amax);
auto [atol, rtol] = getTolerances(otype);
compareResults("output_q", output, ref_output.get(), atol, rtol);
}
template <typename InputType, typename OutputType>
void performTestDQ(const size_t N) {
using namespace test;
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
Tensor input({ N }, itype);
Tensor output({ N }, otype);
Tensor scale_inv({ 1 }, DType::kFloat32);
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N);
fillUniform(input);
fillUniform(scale_inv);
nvte_fp8_dequantize(input.data(), scale_inv.data(), output.data(), 0);
compute_ref_dq<InputType, OutputType>(input.cpu_dptr<InputType>(), ref_output.get(),
N, *(scale_inv.cpu_dptr<float>()));
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol, rtol] = getTolerances(otype);
compareResults("output_dq", output, ref_output.get(), atol, rtol);
}
// Flat element counts; the last two are deliberately not multiples of a
// power of two to exercise the tail-handling paths.
std::vector<size_t> qdq_test_cases = {2048* 12288,
                                      768 * 1024,
                                      256 * 65536,
                                      65536 * 128,
                                      257 * 259,
                                      128*128+1};
}  // namespace

// Parameters: (non-FP8 dtype, FP8 dtype, number of elements).
class QDQTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
                                                                transformer_engine::DType,
                                                                size_t>> {};

// Quantize: input_type (fp32/fp16/bf16) -> output_type (fp8).
TEST_P(QDQTestSuite, TestQ) {
  using namespace transformer_engine;
  using namespace test;
  const DType input_type = std::get<0>(GetParam());
  const DType output_type = std::get<1>(GetParam());
  const size_t N = std::get<2>(GetParam());

  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
      performTestQ<InputType, OutputType>(N);
    );
  );
}

// Dequantize: the template arguments are intentionally swapped relative to
// TestQ — here the FP8 type is the *input* and the wide type is the output.
TEST_P(QDQTestSuite, TestDQ) {
  using namespace transformer_engine;
  using namespace test;
  const DType input_type = std::get<0>(GetParam());
  const DType output_type = std::get<1>(GetParam());
  const size_t N = std::get<2>(GetParam());

  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
      performTestDQ<OutputType, InputType>(N);
    );
  );
}

// Test-name generator produces names of the form <type1>X<type2>X<N>.
INSTANTIATE_TEST_SUITE_P(
  OperatorTest,
  QDQTestSuite,
  ::testing::Combine(
      ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
      ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
      ::testing::ValuesIn(qdq_test_cases)),
  [](const testing::TestParamInfo<QDQTestSuite::ParamType>& info) {
    std::string name = test::typeName(std::get<0>(info.param)) + "X" +
                       test::typeName(std::get<1>(info.param)) + "X" +
                       std::to_string(std::get<2>(info.param));
    return name;
  });
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/transpose.h>
#include <transformer_engine/logging.h>
#include <gtest/gtest.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <memory>
#include <iostream>
#include <iomanip>
#include <random>
#include <cstring>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
template <typename Type>
// CPU reference transpose: `data` is N x H (row-major), `output` is H x N.
void compute_ref(const Type *data, Type *output,
                 const size_t N, const size_t H) {
  for (size_t row = 0; row < N; ++row) {
    for (size_t col = 0; col < H; ++col) {
      output[col * N + row] = data[row * H + col];
    }
  }
}
template <typename Type>
// Transpose an N x H tensor on the GPU and validate against the CPU reference.
void performTest(const size_t N, const size_t H) {
  using namespace test;
  DType dtype = TypeInfo<Type>::dtype;

  Tensor input({ N, H }, dtype);
  Tensor output({ H, N }, dtype);
  std::unique_ptr<Type[]> ref_output = std::make_unique<Type[]>(N * H);

  fillUniform(input);

  nvte_transpose(input.data(), output.data(), 0);
  compute_ref<Type>(input.cpu_dptr<Type>(), ref_output.get(), N, H);

  cudaDeviceSynchronize();
  const auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  auto [atol, rtol] = getTolerances(dtype);
  compareResults("output", output, ref_output.get(), atol, rtol);
}
// Test shapes (N rows, H columns); includes non-power-of-two sizes and a
// tiny 8x8 case.
std::vector<std::pair<size_t, size_t>> test_cases = {{2048, 12288},
                                                     {768, 1024},
                                                     {256, 65536},
                                                     {65536, 128},
                                                     {256, 256},
                                                     {120, 2080},
                                                     {8, 8}};
}  // namespace

// Parameters: (dtype, (N, H)).
class TTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
                                                              std::pair<size_t, size_t>>> {};

TEST_P(TTestSuite, TestTranspose) {
  using namespace transformer_engine;
  using namespace test;
  const DType type = std::get<0>(GetParam());
  const auto size = std::get<1>(GetParam());

  // Dispatch the runtime dtype to a concrete template instantiation.
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
    performTest<T>(size.first, size.second);
  );
}

// Test-name generator produces names of the form <type>X<N>X<H>.
INSTANTIATE_TEST_SUITE_P(
  OperatorTest,
  TTestSuite,
  ::testing::Combine(
      ::testing::ValuesIn(test::all_fp_types),
      ::testing::ValuesIn(test_cases)),
  [](const testing::TestParamInfo<TTestSuite::ParamType>& info) {
    std::string name = test::typeName(std::get<0>(info.param)) + "X" +
                       std::to_string(std::get<1>(info.param).first) + "X" +
                       std::to_string(std::get<1>(info.param).second);
    return name;
  });
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "test_common.h"
#include "transformer_engine/logging.h"
#include "transformer_engine/transformer_engine.h"
#include <gtest/gtest.h>
#include <algorithm>
#include <memory>
#include <random>
namespace test {
// All floating-point dtypes the operator tests iterate over (declared extern
// in test_common.h).
std::vector<DType> all_fp_types = {DType::kFloat32,
                                   DType::kFloat16,
                                   DType::kBFloat16,
                                   DType::kFloat8E5M2,
                                   DType::kFloat8E4M3};
// Two shapes are equal iff their ranks match and every dimension agrees.
bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2) {
  return s1.ndim == s2.ndim &&
         std::equal(s1.data, s1.data + s1.ndim, s2.data);
}
// Size in bytes of a single element of the given dtype. Every switch branch
// returns; the macro's default case raises NVTE_ERROR, so control never
// falls off the end (compilers may still warn about a missing return).
size_t typeToSize(DType type) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
    {
      return TypeInfo<T>::size;
    });
}
// Human-readable name for a dtype, used to build gtest test names.
// Throws std::out_of_range (via map::at) for an unknown dtype.
const std::string &typeName(DType type) {
  static const std::unordered_map<DType, std::string> name_map = {
    {DType::kByte, "byte"},
    {DType::kInt32, "int32"},
    {DType::kFloat32, "float32"},
    {DType::kFloat16, "float16"},
    {DType::kBFloat16, "bfloat16"},
    {DType::kFloat8E4M3, "float8e4m3"},
    {DType::kFloat8E5M2, "float8e5m2"}};
  return name_map.at(type);
}
// Total number of elements described by the shape (1 for a 0-D shape).
size_t product(const NVTEShape &shape) {
  size_t n = 1;
  for (size_t d = 0; d < shape.ndim; ++d) {
    n *= shape.data[d];
  }
  return n;
}
// Allocates a zero-initialized device buffer plus a matching host staging
// buffer for the given shape/dtype. A zero-element shape keeps both pointers
// null. NOTE(review): cudaMalloc/cudaMemset results are not checked —
// a failed allocation would surface later as a CUDA error in the test body.
Tensor::Tensor(const NVTEShape &shape, const DType type) {
  size_t s = typeToSize(type);
  size_t total_size = product(shape) * s;
  void *dptr = nullptr;
  cpu_data_ = nullptr;
  if (total_size != 0) {
    cudaMalloc((void**)&dptr, total_size);  // NOLINT(*)
    cudaMemset(dptr, 0, total_size);
    cpu_data_ = std::make_unique<unsigned char[]>(total_size);
  }
  tensor_ = TensorWrapper(dptr, shape, type);
}
// Synchronously copy the device buffer into the host staging buffer.
void Tensor::to_cpu() const {
  const NVTEShape s = tensor_.shape();
  const size_t size = product(s) * typeToSize(tensor_.dtype());
  cudaMemcpy(cpu_data_.get(), tensor_.dptr(), size, cudaMemcpyDeviceToHost);
}

// Synchronously copy the host staging buffer back to the device.
void Tensor::from_cpu() const {
  const NVTEShape s = tensor_.shape();
  const size_t size = product(s) * typeToSize(tensor_.dtype());
  cudaMemcpy(tensor_.dptr(), cpu_data_.get(), size, cudaMemcpyHostToDevice);
}
using std::to_string;

template <typename T>
// Renders a vector as "[a, b, c]" using to_string on each element.
std::string to_string(const std::vector<T> &v) {
  // Guard the empty case: the previous version unconditionally popped the
  // trailing ", ", which for an empty vector called pop_back() on a
  // 1-character and then an empty string (undefined behavior).
  if (v.empty()) {
    return "[]";
  }
  std::string s = "[";
  for (const auto x : v) {
    s += to_string(x) + ", ";
  }
  // Drop the trailing ", ".
  s.pop_back();
  s.pop_back();
  return s + "]";
}
// Converts a flat (row-major) index into per-dimension coordinates, used to
// report the location of a mismatch in compareResults.
std::vector<size_t> unravel(const size_t i, const NVTEShape &shape) {
  std::vector<size_t> ret;
  // Guard: for ndim == 0 the unsigned "ndim - 1" below would wrap around
  // and index shape.data out of bounds.
  if (shape.ndim == 0) {
    return ret;
  }
  size_t current_i = i;
  for (size_t current = shape.ndim - 1;
       current > 0;
       --current) {
    ret.push_back(current_i % shape.data[current]);
    current_i /= shape.data[current];
  }
  ret.push_back(current_i);
  std::reverse(ret.begin(), ret.end());
  return ret;
}
// Element-wise comparison of a GPU tensor against a host reference buffer.
// A raw mismatch fails immediately for fp32; for lower-precision types a
// mismatch is forgiven when it can be explained by round-to-nearest picking
// different sides of the true value.
void compareResults(const std::string &name, const Tensor &test, const void *ref,
                    double atol, double rtol) {
  test.to_cpu();
  const size_t N = product(test.shape());
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T,
    const T *test_data = test.cpu_dptr<T>();
    const T *ref_data = reinterpret_cast<const T*>(ref);
    for (size_t i = 0; i < N; ++i) {
      double t = static_cast<double>(test_data[i]);
      double r = static_cast<double>(ref_data[i]);
      // Mismatch requires both absolute and relative tolerance violations
      // (relative check is skipped when the reference is exactly zero).
      bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
      /* For Float32 the floating point comparison is enough to error out */
      bool assertion = mismatch && test.dtype() == DType::kFloat32;
      if (mismatch && !assertion) {
        /* Check if it is just a failure of round to nearest choosing different
           side of the real value */
        // Perturb the midpoint of t and r by 1 ulp-ish in each direction and
        // cast back to T: if that reproduces {min(t,r), max(t,r)}, the two
        // values straddle a representable midpoint and the diff is benign.
        const double mean = (t + r) / 2;
        const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
        const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
        const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
        const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
        assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
      }
      ASSERT_FALSE(assertion) << "Error in tensor " << name << std::endl
                              << "Mismatch at place " << to_string(unravel(i, test.shape()))
                              << " (" << std::to_string(i) << "): " << t << " vs " << r;
    }
  );
}
std::pair<double, double> getTolerances(const DType type) {
switch(type) {
case DType::kFloat32:
return {1e-6, 5e-6};
case DType::kFloat16:
return {1e-5, 1e-3};
case DType::kBFloat16:
return {1e-5, 1e-2};
case DType::kFloat8E4M3:
case DType::kFloat8E5M2:
return {1e-2, 1e-2};
default:
NVTE_CHECK("Invalid type!");
}
return {0, 0};
}
// Fills the host staging buffer with uniform random values in [-2, 1) and
// uploads it to the device. The RNG is function-static with a fixed seed, so
// results are deterministic but depend on the order of fillUniform calls.
void fillUniform(const Tensor &t) {
  const size_t size = product(t.shape());
  static std::mt19937 gen(12345);
  std::uniform_real_distribution<> dis(-2.0, 1.0);
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t.dtype(), T, {
    T *data = t.cpu_dptr<T>();
    for (size_t i = 0; i < size; ++i) {
      // Conversion through T's constructor handles fp16/bf16/fp8 narrowing.
      data[i] = T(dis(gen));
    }
  });
  t.from_cpu();
}
} // namespace test
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#pragma once
#include <memory>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/logging.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime_api.h>
#include <vector>
#include <iostream>
namespace test {
using namespace transformer_engine;
// Maps a byte count to an unsigned integer type of that width.
template <size_t i>
struct BytesToType {};

template <>
struct BytesToType<1> {
  using Type = uint8_t;
};

template <>
struct BytesToType<2> {
  using Type = uint16_t;
};

template <>
struct BytesToType<4> {
  using Type = uint32_t;
};

template <>
struct BytesToType<8> {
  using Type = uint64_t;
};

// Short aliases for the concrete C++ types backing each DType.
using byte = uint8_t;
using int32 = int32_t;
using fp32 = float;
using fp16 = half;
using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;

// Compile-time mapping from a C++ type T to its DType enum value and size.
// NOTE: the tuple order below must match the numeric order of the DType enum,
// since Helper walks the enum by casting successive integers.
template <typename T>
struct TypeInfo{
    using types = std::tuple<byte,
                             int32,
                             fp32,
                             fp16,
                             bf16,
                             fp8e4m3,
                             fp8e5m2>;

    // Recursively compares U against the tuple element at index `current`.
    template <typename U, DType current>
    struct Helper {
        constexpr static DType getType() {
            constexpr int i = static_cast<int>(current);
            if (std::is_same<U, typename std::tuple_element<i, types>::type>::value) {
                return current;
            } else {
                return Helper<U, static_cast<DType>(i + 1)>::getType();
            }
        }
    };

    // Recursion terminator: no match found.
    template <typename U>
    struct Helper<U, DType::kNumTypes> {
        constexpr static DType getType() {
            return DType::kNumTypes;
        }
    };

    template <typename U>
    constexpr static DType getType() {
        return Helper<U, DType::kByte>::getType();
    }

    constexpr static DType dtype = getType<T>();
    constexpr static size_t size = sizeof(T);
};
// RAII wrapper around an NVTETensor: owns a device allocation and a host
// staging buffer (see Tensor::Tensor in test_common.cpp). Move-only, since
// the destructor frees the device pointer.
class Tensor {
 public:
  Tensor(const NVTEShape &shape, const DType type);

  Tensor(const std::vector<size_t> &shape, const DType type) :
    Tensor(NVTEShape{shape.data(), shape.size()}, type) {}

  // Default-constructed tensor has no allocation (used for workspace-size
  // queries by the kernels).
  Tensor() {}

  Tensor& operator=(const Tensor &other) = delete;
  Tensor(const Tensor &other) = delete;

  Tensor(Tensor &&other) = default;
  Tensor& operator=(Tensor &&other) = default;

  ~Tensor() {
    if (tensor_.dptr() != nullptr) {
      cudaFree(tensor_.dptr());
    }
  }

  // Opaque handle passed to the nvte_* C API.
  NVTETensor data() const noexcept {
    return tensor_.data();
  }

  const NVTEShape shape() const noexcept {
    return tensor_.shape();
  }

  DType dtype() const noexcept {
    return tensor_.dtype();
  }

  // Raw device pointer.
  void *dptr() const noexcept {
    return tensor_.dptr();
  }

  // Typed access to the host staging buffer; T must match the tensor dtype.
  template <typename T>
  T *cpu_dptr() const {
    NVTE_CHECK(TypeInfo<T>::dtype == tensor_.dtype(), "Invalid type!");
    return reinterpret_cast<T *>(cpu_data_.get());
  }

  void to_cpu() const;
  void from_cpu() const;

 private:
  TensorWrapper tensor_;
  std::unique_ptr<unsigned char[]> cpu_data_;
};

size_t typeToSize(DType type);

size_t product(const NVTEShape &shape);

bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2);

void compareResults(const std::string &name, const Tensor &test, const void *ref,
                    double atol = 1e-5, double rtol = 1e-8);

std::pair<double, double> getTolerances(const DType type);

void fillUniform(const Tensor &t);

constexpr int THREADS_PER_WARP = 32;

const std::string &typeName(DType type);

extern std::vector<DType> all_fp_types;

}  // namespace test
// Dispatches on a runtime DType: binds the matching concrete C++ type to the
// alias `type` and executes __VA_ARGS__ for the selected case only.
// Unknown enum values raise NVTE_ERROR. Note the `using namespace` placed
// inside the switch before the first case — it is in scope for every branch.
// (Comments cannot go inside the macro body: `//` would swallow the
// line-continuation backslashes.)
#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \
    switch (dtype) { \
        using namespace transformer_engine; \
        case DType::kByte: \
            { \
                using type = byte; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kInt32: \
            { \
                using type = int32; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat32: \
            { \
                using type = float; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat16: \
            { \
                using type = fp16; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kBFloat16: \
            { \
                using type = bf16; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat8E4M3: \
            { \
                using type = fp8e4m3; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat8E5M2: \
            { \
                using type = fp8e5m2; \
                {__VA_ARGS__} \
            } \
        break; \
        default: \
            NVTE_ERROR("Invalid type."); \
    }
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
import pytest
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
)
from transformer_engine.pytorch import (
LayerNormLinear,
Linear,
LayerNormMLP,
TransformerLayer,
)
from transformer_engine.common import recipe
def custom_amax_to_scale(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: torch.Tensor,
    recipe: recipe.DelayedScaling,
) -> torch.Tensor:
    """Custom func to test recipe."""
    # Keep the previous scale wherever amax is unusable, i.e. non-positive
    # or non-finite; otherwise use fp8_max / amax.
    candidate = fp8_max / amax
    usable = (amax > 0.0) & torch.isfinite(amax)
    return torch.where(usable, candidate, scale)
def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
    """Custom func to test recipe."""
    # Reduce the history over its first (time) dimension with a minimum.
    values, _ = amax_history.min(dim=0)
    return values
class ModelConfig:
    """Hyperparameters of a toy transformer model used by the sanity tests."""

    def __init__(
        self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
    ):
        self.hidden_size = hidden_size                    # model hidden dimension
        self.eps = eps                                    # layernorm epsilon
        self.num_attention_heads = num_attention_heads    # attention heads
        self.embed = embed                                # per-head kv channels
        self.num_layers = num_layers                      # used for init scaling
        self.seq_len = seq_len                            # sequence length


model_configs = {
    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
}

# FP8 recipes exercising each format/history/custom-callback combination.
# NOTE(review): the first two positional args are presumably
# (margin, interval) of recipe.DelayedScaling — confirm against its signature.
fp8_recipes = [
    recipe.DelayedScaling(0, 1, recipe.Format.E4M3),
    recipe.DelayedScaling(0, 1, recipe.Format.HYBRID),
    recipe.DelayedScaling(
        0, 1, recipe.Format.E4M3, override_linear_precision=(False, False, True)
    ),
    recipe.DelayedScaling(
        0, 1, recipe.Format.E4M3, amax_history_len=16, amax_compute_algo="most_recent"
    ),
    recipe.DelayedScaling(
        0, 1, recipe.Format.E4M3, amax_history_len=16, amax_compute_algo="max"
    ),
    recipe.DelayedScaling(
        0,
        1,
        recipe.Format.E4M3,
        amax_history_len=16,
        amax_compute_algo=custom_amax_compute,
    ),
    recipe.DelayedScaling(
        0,
        1,
        recipe.Format.E4M3,
        amax_history_len=16,
        scaling_factor_compute_algo=custom_amax_to_scale,
    ),
]

# Parametrization axes shared by all tests in this file.
param_types = [torch.float32, torch.bfloat16, torch.float16]
batch_sizes = [1, 2]
def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe):
    """Forward/backward smoke test under torch AMP autocast combined with
    FP8 autocast; also checks the output dtype matches the AMP dtype."""
    hidden_states = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=torch.float32, requires_grad=True
    ).cuda()
    attn_mask = torch.rand(1, 1, config.seq_len, config.seq_len).cuda().bool()

    with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
        with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            te_out = block(hidden_states, attn_mask)
            loss = te_out.sum()

    assert te_out.dtype == dtype
    loss.backward()
    torch.cuda.synchronize()
def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe):
    """Forward/backward smoke test of a transformer block under FP8 autocast."""
    hidden_states = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    attn_mask = torch.rand(1, 1, config.seq_len, config.seq_len).cuda().bool()

    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        te_out = block(hidden_states, attn_mask)

    te_out.sum().backward()
    torch.cuda.synchronize()
def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe):
    """Forward/backward smoke test for a decoder block: the same hidden
    states are also fed in as the encoder output (cross-attention input)."""
    hidden_states = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    attn_mask = torch.rand(1, 1, config.seq_len, config.seq_len).cuda().bool()

    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        te_out = block(
            hidden_states, attn_mask, encoder_output=hidden_states
        )

    te_out.sum().backward()
    torch.cuda.synchronize()
def _test_sanity_common(block, bs, dtype, config, fp8_recipe):
    """Forward/backward smoke test for a single-input module under FP8.
    Handles modules that return a tuple by taking the first element."""
    inp = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()

    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        te_out = block(inp)

    if isinstance(te_out, tuple):
        te_out = te_out[0]
    te_out.sum().backward()
    torch.cuda.synchronize()
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model):
    """Smoke-test the fused LayerNorm + Linear module under FP8."""
    config = model_configs[model]
    init_method = init_method_normal(0.023)

    block = LayerNormLinear(
        config.hidden_size,
        config.hidden_size * 3,
        eps=config.eps,
        init_method=init_method,
    )
    block = block.to(dtype=dtype).cuda()
    _test_sanity_common(block, bs, dtype, config, fp8_recipe)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_linear(dtype, bs, fp8_recipe, model):
    """Smoke-test a bare Linear module under FP8."""
    config = model_configs[model]
    output_layer_init_method = scaled_init_method_normal(0.023, config.num_layers)

    block = Linear(
        config.hidden_size, config.hidden_size, init_method=output_layer_init_method
    )
    block = block.to(dtype=dtype).cuda()
    _test_sanity_common(block, bs, dtype, config, fp8_recipe)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model):
    """Smoke-test the fused LayerNorm + MLP module under FP8."""
    config = model_configs[model]
    init_method = init_method_normal(0.023)
    output_layer_init_method = scaled_init_method_normal(0.023, config.num_layers)

    block = LayerNormMLP(
        config.hidden_size,
        4 * config.hidden_size,
        eps=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
    )
    block = block.to(dtype=dtype).cuda()
    _test_sanity_common(block, bs, dtype, config, fp8_recipe)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_gpt(dtype, bs, fp8_recipe, model):
    """Smoke-test a GPT-style TransformerLayer (no post-LN residual,
    no output layernorm)."""
    config = model_configs[model]
    init_method = init_method_normal(0.023)
    output_layer_init_method = scaled_init_method_normal(0.023, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.embed,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
    )
    block = block.to(dtype=dtype).cuda()
    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_bert(dtype, bs, fp8_recipe, model):
    """Smoke-test a BERT-style TransformerLayer (post-LN residual and
    output layernorm enabled)."""
    config = model_configs[model]
    init_method = init_method_normal(0.023)
    output_layer_init_method = scaled_init_method_normal(0.023, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.embed,
        apply_residual_connection_post_layernorm=True,
        output_layernorm=True,
    )
    block = block.to(dtype=dtype).cuda()
    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_T5(dtype, bs, fp8_recipe, model):
    """Smoke-test a decoder TransformerLayer with cross-attention."""
    config = model_configs[model]
    init_method = init_method_normal(0.023)
    output_layer_init_method = scaled_init_method_normal(0.023, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.embed,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="decoder",
    )
    block = block.to(dtype=dtype).cuda()
    _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model):
    """Smoke-test a TransformerLayer under torch AMP autocast. The module
    itself stays in float32; AMP handles the dtype conversion."""
    config = model_configs[model]
    init_method = init_method_normal(0.023)
    output_layer_init_method = scaled_init_method_normal(0.023, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.embed,
    )
    block = block.to(dtype=torch.float32).cuda()
    _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_drop_path(dtype, bs, fp8_recipe, model):
    """Smoke-test a TransformerLayer with stochastic depth fully enabled
    (drop_path_rate=1.0)."""
    config = model_configs[model]
    init_method = init_method_normal(0.023)
    output_layer_init_method = scaled_init_method_normal(0.023, config.num_layers)

    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.embed,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        drop_path_rate=1.0,
    )
    block = block.to(dtype=dtype).cuda()
    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model):
    """Sanity-check a TransformerLayer using a single fused QKV parameter."""
    config = model_configs[model]
    init_std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.embed,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
    )
    block = block.to(dtype=dtype).cuda()
    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe)
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
import pytest
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
)
from transformer_engine.pytorch import (
LayerNormLinear,
Linear,
LayerNormMLP,
TransformerLayer,
)
class ModelConfig:
    """Hyperparameters describing one transformer model used by the tests."""

    def __init__(
        self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
    ):
        # Plain attribute copy; configs are read-only bags of numbers.
        self.hidden_size = hidden_size              # transformer hidden dimension
        self.eps = eps                              # layernorm epsilon
        self.num_attention_heads = num_attention_heads
        self.embed = embed                          # per-head (kv channel) size
        self.num_layers = num_layers                # scales the output init std
        self.seq_len = seq_len                      # test sequence length
# Single model configuration exercised by every test: a GPT-"126m"-sized
# layer (hidden 768, 12 heads, 64 kv-channels, 12 layers, seq len 2048).
model_configs = {
    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
}

# Dtypes and batch sizes swept via @pytest.mark.parametrize below.
param_types = [torch.float32, torch.bfloat16, torch.float16]
batch_sizes = [1, 2]
def _test_sanity_e2e_amp(block, bs, dtype, config):
    """Run forward/backward through `block` under torch.cuda.amp autocast.

    Inputs are created in float32; autocast is responsible for running the
    layer in `dtype`, which is asserted on the output.

    Args:
        block: module under test (called as block(hidden_states, attn_mask)).
        bs: batch size.
        dtype: autocast target dtype; also the expected output dtype.
        config: ModelConfig providing seq_len and hidden_size.
    """
    # Skip (rather than silently pass via `return`) when bf16 is unsupported,
    # so the outcome is visible in the pytest report.
    if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
        pytest.skip("bfloat16 is not supported on this device")

    te_inp_hidden_states = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=torch.float32, requires_grad=True
    ).cuda()
    # Random boolean [1, 1, seq, seq] mask, broadcast over batch and heads.
    te_inp_attn_mask = (
        torch.rand(
            (
                1,
                1,
                config.seq_len,
                config.seq_len,
            )
        )
        .cuda()
        .bool()
    )

    with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
        te_out = block(te_inp_hidden_states, te_inp_attn_mask)
        loss = te_out.sum()
        # Autocast must have produced the requested compute dtype.
        assert te_out.dtype == dtype
        loss.backward()
    torch.cuda.synchronize()
def _test_sanity_e2e(block, bs, dtype, config):
    """Run one forward and backward pass through `block` on random GPU data."""
    hidden = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    # Boolean [1, 1, seq, seq] attention mask.
    mask = torch.rand((1, 1, config.seq_len, config.seq_len)).cuda().bool()
    block(hidden, mask).sum().backward()
    torch.cuda.synchronize()
def _test_sanity_e2e_T5(block, bs, dtype, config):
    """Forward/backward a decoder `block`, reusing the input as encoder output."""
    hidden = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    # Boolean [1, 1, seq, seq] attention mask.
    mask = torch.rand((1, 1, config.seq_len, config.seq_len)).cuda().bool()
    out = block(hidden, mask, encoder_output=hidden)
    out.sum().backward()
    torch.cuda.synchronize()
def _test_sanity_common(block, bs, dtype, config):
    """Forward/backward a standalone module on a random hidden-state tensor."""
    inp = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    out = block(inp)
    # Some modules return (output, extra...) tuples; reduce on the first item.
    out = out[0] if isinstance(out, tuple) else out
    out.sum().backward()
    torch.cuda.synchronize()
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_layernorm_linear(dtype, bs, model):
    """Sanity-check the fused LayerNormLinear module (3x-wide, QKV-style)."""
    config = model_configs[model]
    block = LayerNormLinear(
        config.hidden_size,
        config.hidden_size * 3,
        eps=config.eps,
        init_method=init_method_normal(0.023),
    )
    block = block.to(dtype=dtype).cuda()
    _test_sanity_common(block, bs, dtype, config)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_linear(dtype, bs, model):
    """Sanity-check the standalone Linear module (square projection)."""
    config = model_configs[model]
    block = Linear(
        config.hidden_size,
        config.hidden_size,
        init_method=scaled_init_method_normal(0.023, config.num_layers),
    )
    block = block.to(dtype=dtype).cuda()
    _test_sanity_common(block, bs, dtype, config)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_layernorm_mlp(dtype, bs, model):
    """Sanity-check the fused LayerNormMLP module (4x FFN expansion)."""
    config = model_configs[model]
    init_std = 0.023
    block = LayerNormMLP(
        config.hidden_size,
        4 * config.hidden_size,
        eps=config.eps,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
    )
    block = block.to(dtype=dtype).cuda()
    _test_sanity_common(block, bs, dtype, config)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_gpt(dtype, bs, model):
    """Sanity-check a GPT-style TransformerLayer (pre-LN, no output LN)."""
    config = model_configs[model]
    init_std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.embed,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
    )
    block = block.to(dtype=dtype).cuda()
    _test_sanity_e2e(block, bs, dtype, config)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_bert(dtype, bs, model):
    """Sanity-check a BERT-style TransformerLayer (post-LN residual + output LN)."""
    config = model_configs[model]
    init_std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.embed,
        apply_residual_connection_post_layernorm=True,
        output_layernorm=True,
    )
    block = block.to(dtype=dtype).cuda()
    _test_sanity_e2e(block, bs, dtype, config)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_T5(dtype, bs, model):
    """Sanity-check a T5-style decoder TransformerLayer with cross-attention."""
    config = model_configs[model]
    init_std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.embed,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        layer_type="decoder",
    )
    block = block.to(dtype=dtype).cuda()
    _test_sanity_e2e_T5(block, bs, dtype, config)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_amp_and_nvfuser(dtype, bs, model):
    """Sanity-check a TransformerLayer under torch AMP autocast.

    The layer itself stays in float32; `dtype` only selects the autocast
    target used inside the helper.
    """
    config = model_configs[model]
    init_std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.embed,
    )
    block = block.to(dtype=torch.float32).cuda()
    _test_sanity_e2e_amp(block, bs, dtype, config)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_drop_path(dtype, bs, model):
    """Sanity-check a TransformerLayer with stochastic depth fully enabled.

    drop_path_rate=1.0 forces the residual branch to always be dropped,
    exercising the drop-path code path end to end.
    """
    config = model_configs[model]
    init_std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.embed,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        drop_path_rate=1.0,
    )
    block = block.to(dtype=dtype).cuda()
    _test_sanity_e2e(block, bs, dtype, config)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_fused_qkv_params(dtype, bs, model):
    """Sanity-check a TransformerLayer using a single fused QKV parameter."""
    config = model_configs[model]
    init_std = 0.023
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
        config.num_attention_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method_normal(init_std),
        output_layer_init_method=scaled_init_method_normal(init_std, config.num_layers),
        hidden_dropout=0.1,
        attention_dropout=0.1,
        kv_channels=config.embed,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        fuse_qkv_params=True,
    )
    block = block.to(dtype=dtype).cuda()
    _test_sanity_e2e(block, bs, dtype, config)
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Top level package"""
from . import common
from . import pytorch
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
cmake_minimum_required(VERSION 3.18)

# Default target architectures (Volta, Ampere, Hopper) unless the user set
# their own; must be chosen before project() enables the CUDA language.
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
  set(CMAKE_CUDA_ARCHITECTURES 70 80 90)
endif()

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)

project(transformer_engine LANGUAGES CUDA CXX)

# CMAKE_CUDA_FLAGS is a space-separated command-line string, not a CMake
# list: list(APPEND) would splice a literal ';' into the nvcc invocation.
# Use string(APPEND) with a leading space instead.
string(APPEND CMAKE_CUDA_FLAGS " --threads 4")
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
  # Embed device-side debug info in debug builds.
  string(APPEND CMAKE_CUDA_FLAGS " -G")
endif()

add_library(transformer_engine SHARED
            transformer_engine.cpp
            transpose/cast_transpose.cu
            transpose/transpose.cu
            transpose/cast_transpose_fusion.cu
            activation/gelu.cu
            gemm/cublaslt_gemm.cu
            layer_norm/ln_api.cpp
            layer_norm/ln_bwd_semi_cuda_kernel.cu
            layer_norm/ln_fwd_cuda_kernel.cu
            util/cast.cu)
target_include_directories(transformer_engine PUBLIC "${PROJECT_SOURCE_DIR}/include")

find_package(CUDAToolkit REQUIRED cublas)
list(APPEND transformer_engine_LINKER_LIBS CUDA::cublas CUDA::cudart)
target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS})
target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""FW agnostic user-end APIs"""
def get_te_path():
    """Find TE path using pip.

    Returns the install location of the ``transformer_engine`` package as
    reported by ``pip show transformer_engine``.

    Raises:
        ValueError: if pip does not report a Location field.
    """
    import os

    pip_output = os.popen("pip show transformer_engine").read()
    # Parse line-by-line and split each line on its FIRST colon only.
    # The original approach (replace("\n", ":").split(":")) broke on any
    # install path containing a colon, e.g. Windows drive letters ("C:\...").
    for line in pip_output.splitlines():
        key, _, value = line.partition(":")
        if key.strip() == "Location":
            return value.strip()
    raise ValueError("Could not determine transformer_engine install location")
def _load_library():
    """Load the TE shared library and return its ctypes handle."""
    import os
    import ctypes
    import platform

    # Map the host OS to its shared-library suffix.
    system = platform.system()
    if system == "Linux":
        extension = "so"
    elif system == "Darwin":
        extension = "dylib"
    elif system == "Windows":
        extension = "dll"
    else:
        # The original `raise "..."` raised a bare string, which is itself a
        # TypeError in Python 3; raise a proper exception instead.
        raise RuntimeError("Unsupported operating system " + system + ".")
    lib_name = "libtransformer_engine." + extension
    dll_path = get_te_path()
    dll_path = os.path.join(dll_path, lib_name)
    # RTLD_GLOBAL so framework extensions loaded later can resolve TE symbols.
    return ctypes.CDLL(dll_path, mode=ctypes.RTLD_GLOBAL)
_TE_LIB_CTYPES = _load_library()
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/activation.h>
#include <cuda_runtime.h>
#include <cfloat>
#include <iostream>
#include "../utils.cuh"
#include "../common.h"
#include <cstdlib>
#include <../util/vectorized_pointwise.h>
namespace transformer_engine {
namespace detail {

// Empty parameter struct: gelu() needs no extra state, but the vectorized
// kernel launcher expects a parameter type.
struct GELUParam {};

// Tanh-approximation GELU: x * 0.5 * (1 + tanh(sqrt(2/pi)*(x + 0.044715 x^3))).
// 0.79788456 ~= sqrt(2/pi); 0.03567741 ~= 0.044715 * sqrt(2/pi).
__device__ inline fp32 gelu(fp32 value, const GELUParam &) {
  return value * (0.5F + 0.5F * tanhf(value * (0.79788456F + 0.03567741F * value * value)));
}

}  // namespace detail

// Apply GELU elementwise to `input`, writing the (possibly narrower-typed)
// result to `output`, scaled by `scale`; records the running absolute max in
// `amax` and the reciprocal scale in `scale_inv`. All tensors must be
// pre-allocated 2-D device tensors; the scalar tensors must be fp32.
void gelu_cast(const Tensor &input,
               const Tensor &scale,
               Tensor *output,
               Tensor *amax,
               Tensor *scale_inv,
               cudaStream_t stream) {
  // Shape/dtype validation up front: 2-D in/out of matching shape,
  // single-element fp32 scalars for amax/scale/scale_inv.
  NVTE_CHECK(input.shape.size() == 2, "Input must have 2 dimensions.");
  NVTE_CHECK(output->shape.size() == 2, "Output must have 2 dimensions.");
  NVTE_CHECK(input.shape == output->shape, "Input and output shapes must match.");
  const size_t tot_elts = input.shape[1] * input.shape[0];
  NVTE_CHECK(amax->shape == std::vector<size_t>{ 1 }, "AMAX tensor must have 1 element.");
  NVTE_CHECK(amax->dtype == DType::kFloat32, "AMAX tensor must have Float32 type.");
  NVTE_CHECK(scale.shape == std::vector<size_t>{ 1 }, "Scale tensor must have 1 element.");
  NVTE_CHECK(scale.dtype == DType::kFloat32, "Scale tensor must have Float32 type.");
  NVTE_CHECK(scale_inv->shape == std::vector<size_t>{ 1 },
             "scale_inv tensor must have 1 element.");
  NVTE_CHECK(scale_inv->dtype == DType::kFloat32, "scale_inv tensor must have Float32 type.");
  NVTE_CHECK(input.dptr != nullptr, "Input is not allocated.");
  NVTE_CHECK(scale.dptr != nullptr, "Scale is not allocated.");
  NVTE_CHECK(output->dptr != nullptr, "Output is not allocated.");
  NVTE_CHECK(amax->dptr != nullptr, "AMAX tensor is not allocated.");
  NVTE_CHECK(scale_inv->dptr != nullptr, "scale_inv tensor is not allocated.");

  // Instantiate the kernel for the (runtime) input/output dtype pair and
  // launch the elementwise GELU with 32-byte vectorized loads.
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->dtype, OType,
      constexpr int nvec = 32 / sizeof(IType);  // elements per 32-byte vector
      VectorizedUnaryKernelLauncher<nvec, detail::GELUParam, detail::gelu>(
        reinterpret_cast<const IType*>(input.dptr),
        reinterpret_cast<OType*>(output->dptr),
        reinterpret_cast<const fp32*>(scale.dptr),
        reinterpret_cast<fp32*>(scale_inv->dptr),
        reinterpret_cast<fp32*>(amax->dptr),
        tot_elts,
        {},
        stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}
}  // namespace transformer_engine
// C API entry point: unwraps the opaque NVTETensor handles into internal
// Tensor structs and forwards to gelu_cast on the given CUDA stream.
void nvte_gelu(const NVTETensor input,
               NVTETensor output,
               const NVTETensor scale,
               NVTETensor amax,
               NVTETensor scale_inv,
               cudaStream_t stream) {
  using namespace transformer_engine;
  // Note: gelu_cast takes (input, scale, output, amax, scale_inv) — the
  // scale/output order differs from this function's parameter order.
  gelu_cast(*reinterpret_cast<const Tensor*>(input),
            *reinterpret_cast<const Tensor*>(scale),
            reinterpret_cast<Tensor*>(output),
            reinterpret_cast<Tensor*>(amax),
            reinterpret_cast<Tensor*>(scale_inv),
            stream);
}
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_COMMON_H_
#define TRANSFORMER_ENGINE_COMMON_COMMON_H_
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/logging.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime_api.h>
#include <type_traits>
#include <unordered_map>
#include <functional>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>
namespace transformer_engine {
// Minimal internal tensor descriptor: a raw device pointer plus shape and
// element-type tag. Does not manage the memory it points to (no destructor).
struct Tensor {
    void* dptr;                 // device data pointer (may be nullptr)
    std::vector<size_t> shape;  // dimension sizes
    DType dtype;                // element type tag

    // Default-construct an empty fp32 tensor with no storage.
    Tensor() : dptr(nullptr), shape(), dtype(DType::kFloat32) {}
};
// Ceiling integer division: the number of y-sized chunks needed to cover x.
template <typename T>
constexpr T DIVUP(const T &x, const T &y) {
    return (x + y - 1) / y;
}
// Short aliases for the element types handled by the library. Their order
// of use elsewhere (TypeInfo::types) mirrors the DType enum.
using byte = uint8_t;
using int32 = int32_t;
using fp32 = float;
using fp16 = half;           // CUDA __half
using bf16 = nv_bfloat16;    // CUDA bfloat16
using fp8e4m3 = __nv_fp8_e4m3;  // FP8, 4 exponent / 3 mantissa bits
using fp8e5m2 = __nv_fp8_e5m2;  // FP8, 5 exponent / 2 mantissa bits
// Compile-time mapping from a C++ element type T to its DType enum value and
// size. Works by recursively walking the `types` tuple (whose element order
// must match the DType enum, starting at kByte) until T is found.
template <typename T>
struct TypeInfo{
    // Tuple index i corresponds to DType value i.
    using types = std::tuple<byte,
                             int32,
                             fp32,
                             fp16,
                             bf16,
                             fp8e4m3,
                             fp8e5m2>;

    // Recursive case: compare U against the tuple element at `current`;
    // on mismatch, recurse with the next enum value.
    template <typename U, DType current>
    struct Helper {
        constexpr static DType getType() {
            constexpr int i = static_cast<int>(current);
            if (std::is_same<U, typename std::tuple_element<i, types>::type>::value) {
                return current;
            } else {
                return Helper<U, static_cast<DType>(i + 1)>::getType();
            }
        }
    };

    // Base case: U matched nothing — report the sentinel kNumTypes.
    template <typename U>
    struct Helper<U, DType::kNumTypes> {
        constexpr static DType getType() {
            return DType::kNumTypes;
        }
    };

    // Entry point: start the search at the first enum value.
    template <typename U>
    constexpr static DType getType() {
        return Helper<U, DType::kByte>::getType();
    }

    constexpr static DType dtype = getType<T>();  // DType tag for T
    constexpr static size_t size = sizeof(T);     // element size in bytes
};
// Runtime-DType -> compile-time-type dispatch macros. Each expands to a
// switch that binds `type` to the concrete C++ type and runs __VA_ARGS__.
//
// NOTE(review): kByte and kInt32 bind `type` to float here rather than
// byte/int32 — this looks suspicious; confirm it is intentional before
// relying on SWITCH_ALL for those dtypes.
#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \
    switch (dtype) { \
        using namespace transformer_engine; \
        case DType::kByte: \
            { \
                using type = float; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kInt32: \
            { \
                using type = float; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat32: \
            { \
                using type = float; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat16: \
            { \
                using type = fp16; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kBFloat16: \
            { \
                using type = bf16; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat8E4M3: \
            { \
                using type = fp8e4m3; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat8E5M2: \
            { \
                using type = fp8e5m2; \
                {__VA_ARGS__} \
            } \
        break; \
        default: \
            NVTE_ERROR("Invalid type."); \
    }

// Dispatch for kernel OUTPUT dtypes: float/half/bf16 plus both FP8 formats.
#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) \
    switch (dtype) { \
        using namespace transformer_engine; \
        case DType::kFloat32: \
            { \
                using type = float; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat16: \
            { \
                using type = fp16; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kBFloat16: \
            { \
                using type = bf16; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat8E5M2: \
            { \
                using type = fp8e5m2; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat8E4M3: \
            { \
                using type = fp8e4m3; \
                {__VA_ARGS__} \
            } \
        break; \
        default: \
            NVTE_ERROR("Invalid type."); \
    }

// Dispatch restricted to the two FP8 formats; anything else is an error.
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \
    switch (dtype) { \
        using namespace transformer_engine; \
        case DType::kFloat8E5M2: \
            { \
                using type = fp8e5m2; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat8E4M3: \
            { \
                using type = fp8e4m3; \
                {__VA_ARGS__} \
            } \
        break; \
        default: \
            NVTE_ERROR("Invalid type."); \
    }

// Dispatch for kernel INPUT dtypes: float/half/bf16 only — FP8 inputs are
// explicitly rejected (kernels are not instantiated for them).
#define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \
    switch (dtype) { \
        using namespace transformer_engine; \
        case DType::kFloat32: \
            { \
                using type = float; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat16: \
            { \
                using type = fp16; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kBFloat16: \
            { \
                using type = bf16; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat8E5M2: \
        case DType::kFloat8E4M3: \
            { \
                NVTE_ERROR("FP8 type not instantiated for input."); \
            } \
        break; \
        default: \
            NVTE_ERROR("Invalid type."); \
    }
// Dense 2-bit identifier per supported compute type; used below to pack
// several type choices into a single integer key.
template<typename T>
struct TypeId{};

template<>
struct TypeId<fp16>{
    constexpr static uint32_t Value = 0;
};

template<>
struct TypeId<bf16>{
    constexpr static uint32_t Value = 1;
};

template<>
struct TypeId<fp32>{
    constexpr static uint32_t Value = 2;
};

template<>
struct TypeId<fp8e4m3>{
    constexpr static uint32_t Value = 3;
};
////////////////////////////////////////////////////////////////////////////////////////////////////

// Places a type's 2-bit TypeId at bit offset S within a packed key.
template<typename T, int S>
struct Type2Key{
    constexpr static uint32_t Value = TypeId<T>::Value << S;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Bit layout of the packed type key (2 bits each):
//   bits 0-1 weight, 2-3 input, 4-5 output, 6-7 compute.
template<typename T>
struct WeightType2Key : public Type2Key<T, 0>{};

template<typename T>
struct InputType2Key : public Type2Key<T, 2>{};

template<typename T>
struct OutputType2Key : public Type2Key<T, 4>{};

template<typename T>
struct ComputeType2Key : public Type2Key<T, 6>{};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Combines the four type ids; get() appends the hidden size in the low
// 32 bits: key = (type bits << 32) | hidden_size.
template<typename W, typename I, typename O, typename C>
struct Types2Key{
    constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value |
                                      OutputType2Key<O>::Value | ComputeType2Key<C>::Value;
    constexpr static inline uint64_t get(const uint64_t hidden_size){
        constexpr uint64_t type_key = Value;
        return (type_key << 32) | hidden_size;
    }
};
// Total number of elements implied by `shape` (an empty shape yields 1).
inline size_t product(const std::vector<size_t> &shape) {
    size_t total = 1;
    for (size_t i = 0; i < shape.size(); ++i) {
        total *= shape[i];
    }
    return total;
}
// Trait: true only for the two FP8 element types.
template <typename T>
struct is_fp8 : std::false_type {};

template <>
struct is_fp8<fp8e4m3> : std::true_type {};

template <>
struct is_fp8<fp8e5m2> : std::true_type {};

// Returns the element size in bytes for a runtime DType (defined elsewhere).
size_t typeToSize(const DType type);
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment