Commit c520cba3 authored by yuguo

[DCU] Preliminary adaptation

parent 5b6ef054
@@ -19,6 +19,8 @@ from typing import List, Optional, Type
import setuptools
from .utils import (
rocm_build,
rocm_path,
cmake_bin,
debug_build_enabled,
found_ninja,
@@ -155,26 +157,34 @@ def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel
ext.extra_compile_args[target] = []
# Define new _compile method that redirects to NVCC for .cu and .cuh files.
# Also redirect .hip files to HIPCC
original_compile_fn = self.compiler._compile
self.compiler.src_extensions += [".cu", ".cuh", ".hip"]
def _compile_fn(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None:
# Copy before we make any modifications.
cflags = copy.deepcopy(extra_postargs)
original_compiler = self.compiler.compiler_so
try:
if rocm_build():
_, nvcc_bin = rocm_path()
else:
_, nvcc_bin = cuda_path()
original_compiler = self.compiler.compiler_so
if os.path.splitext(src)[1] in [".cu", ".cuh", ".hip"]:
self.compiler.set_executable("compiler_so", str(nvcc_bin))
if isinstance(cflags, dict):
cflags = cflags["nvcc"]
# Add -fPIC if not already specified
if not any("-fPIC" in flag for flag in cflags):
if rocm_build():
cflags.append("-fPIC")
else:
cflags.extend(["--compiler-options", "'-fPIC'"])
if not rocm_build():
# Forward unknown options
if not any("--forward-unknown-opts" in flag for flag in cflags):
cflags.append("--forward-unknown-opts")
...
@@ -9,6 +9,8 @@ from pathlib import Path
import setuptools
from .utils import (
rocm_build,
hipify,
all_files_in_dir,
cuda_archs,
cuda_version,
@@ -37,11 +39,27 @@ def setup_pytorch_extension(
csrc_header_files,
]
if rocm_build():
current_file_path = Path(__file__).parent.resolve()
base_dir = current_file_path.parent
sources = hipify(base_dir, csrc_source_files, sources, include_dirs)
# Compiler flags
cxx_flags = [
"-O3",
"-fvisibility=hidden",
]
if rocm_build():
nvcc_flags = [
"-O3",
"-U__HIP_NO_HALF_OPERATORS__",
"-U__HIP_NO_HALF_CONVERSIONS__",
"-U__HIP_NO_BFLOAT16_OPERATORS__",
"-U__HIP_NO_BFLOAT16_CONVERSIONS__",
"-U__HIP_NO_BFLOAT162_OPERATORS__",
"-U__HIP_NO_BFLOAT162_CONVERSIONS__",
]
else:
nvcc_flags = [
"-O3",
"-U__CUDA_NO_HALF_OPERATORS__",
@@ -54,13 +72,14 @@ def setup_pytorch_extension(
"--expt-extended-lambda",
"--use_fast_math",
]
# Version-dependent CUDA options
if rocm_build():
##TODO: Figure out which hipcc version starts to support this parallel compilation
nvcc_flags.extend(["-parallel-jobs=4"])
else:
cuda_architectures = cuda_archs()
if "70" in cuda_architectures:
nvcc_flags.extend(["-gencode", "arch=compute_70,code=sm_70"])
# Version-dependent CUDA options
try:
version = cuda_version()
except FileNotFoundError:
@@ -80,6 +99,9 @@ def setup_pytorch_extension(
continue # Already handled
nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"])
# Libraries
library_dirs = []
libraries = []
if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
assert (
os.getenv("MPI_HOME") is not None
@@ -88,6 +110,8 @@ def setup_pytorch_extension(
include_dirs.append(mpi_path / "include")
cxx_flags.append("-DNVTE_UB_WITH_MPI")
nvcc_flags.append("-DNVTE_UB_WITH_MPI")
library_dirs.append(mpi_path / "lib")
libraries.append("mpi")
# Construct PyTorch CUDA extension
sources = [str(path) for path in sources]
@@ -102,4 +126,6 @@ def setup_pytorch_extension(
"cxx": cxx_flags,
"nvcc": nvcc_flags,
},
libraries=[str(lib) for lib in libraries],
library_dirs=[str(lib_dir) for lib_dir in library_dirs],
)
@@ -7,6 +7,28 @@ import os
from pathlib import Path
import subprocess
DAS_VERSION="1.6"
def abi_value():
try:
return (
subprocess.check_output("echo '#include <string>' | gcc -x c++ -E -dM - | fgrep _GLIBCXX_USE_CXX11_ABI", shell=True)
.decode('ascii')
.strip()[-1]
)
except Exception:
return "UNKNOWN"
def dtk_version_value():
try:
dtk_path=os.getenv('ROCM_PATH')
dtk_version_path = os.path.join(dtk_path, '.info', "version-dev")
with open(dtk_version_path, 'r',encoding='utf-8') as file:
lines = file.readlines()
dtk_version="dtk"+lines[0][:].replace(".", "")
return dtk_version
except Exception:
return "UNKNOWN"
def te_version() -> str:
"""Transformer Engine version string
@@ -33,5 +55,7 @@ def te_version() -> str:
pass
else:
commit = output.stdout.strip()
version += "+das" + DAS_VERSION + f".git{commit}" + ".abi" + str(abi_value()) + "." + str(dtk_version_value())
else:
version += "+das" + DAS_VERSION + ".opt1" + "." + str(dtk_version_value())
return version
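For readability, here is a hedged sketch of the version suffixes these branches produce; the base version, commit hash, ABI digit, and DTK release below are placeholder values, not taken from this commit:

# Illustrative only: placeholder base version, commit hash, ABI flag, and DTK release.
base = "1.9.0"                                   # value normally read from VERSION.txt
with_git = base + "+das" + "1.6" + ".git" + "0123abc" + ".abi" + "1" + ".dtk2404"
without_git = base + "+das" + "1.6" + ".opt1" + "." + "dtk2404"
print(with_git)     # 1.9.0+das1.6.git0123abc.abi1.dtk2404
print(without_git)  # 1.9.0+das1.6.opt1.dtk2404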
@@ -161,6 +161,44 @@ def found_pybind11() -> bool:
return False
@functools.lru_cache(maxsize=None)
def rocm_build() -> bool:
""" ROCm build should be performed if:
- It is configured with NVTE_USE_ROCM=1 env
OR:
- HIP compiler is found and CUDA one is not
"""
if bool(int(os.getenv("NVTE_USE_ROCM", "0"))):
return True
try:
cuda_path()
return False
except FileNotFoundError:
pass
_, hipcc_bin = rocm_path()
return hipcc_bin.is_file()
@functools.lru_cache(maxsize=None)
def rocm_path() -> Tuple[str, str]:
"""ROCm root path and HIPCC binary path as a tuple.
If the ROCm installation is not specified via ROCM_PATH, fall back to hipcc on PATH and finally to the default /opt/dtk path."""
hipcc_bin = None  # default when ROCM_PATH is not set
if os.getenv("ROCM_PATH"):
rocm_home = Path(os.getenv("ROCM_PATH"))
hipcc_bin = rocm_home / "bin" / "hipcc"
if hipcc_bin is None:
hipcc_bin = shutil.which("hipcc")
if hipcc_bin is not None:
hipcc_bin = Path(hipcc_bin)
rocm_home = hipcc_bin.parent.parent
if hipcc_bin is None:
rocm_home = Path("/opt/dtk/")
hipcc_bin = rocm_home / "bin" / "hipcc"
return rocm_home, hipcc_bin
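A minimal usage sketch under assumed environment settings; both helpers are wrapped in functools.lru_cache, so the environment must be configured before the first call, and the paths shown are hypothetical:

import os
os.environ["ROCM_PATH"] = "/opt/dtk"   # hypothetical DTK install location
os.environ["NVTE_USE_ROCM"] = "1"      # force a ROCm build even if CUDA is present
rocm_home, hipcc = rocm_path()         # -> (Path("/opt/dtk"), Path("/opt/dtk/bin/hipcc"))
assert rocm_build()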
@functools.lru_cache(maxsize=None)
def cuda_path() -> Tuple[str, str]:
"""CUDA root path and NVCC binary path as a tuple.
@@ -228,6 +266,9 @@ def get_frameworks() -> List[str]:
_frameworks.extend(arg.replace("--framework=", "").split(","))
sys.argv.remove(arg)
if rocm_build():
_requested_frameworks = [framework.lower() for framework in _frameworks]
# Detect installed frameworks if not explicitly specified
if not _frameworks:
try:
@@ -255,6 +296,28 @@ def get_frameworks() -> List[str]:
if framework not in supported_frameworks:
raise ValueError(f"Transformer Engine does not support framework={framework}")
if rocm_build():
_unsupported_frameworks = []
if "pytorch" in _frameworks:
try:
from torch.utils.cpp_extension import IS_HIP_EXTENSION
except ImportError:
IS_HIP_EXTENSION=False
if not IS_HIP_EXTENSION:
if "pytorch" in _requested_frameworks:
_unsupported_frameworks.append("pytorch")
_frameworks.remove("pytorch")
if "jax" in _frameworks:
if not any(re.match(r'jax-rocm\d+-plugin', d.metadata['Name']) for d in importlib.metadata.distributions()):
try:
import jaxlib.rocm #pre JAX 0.4.30 way
except ImportError:
if "jax" in _requested_frameworks:
_unsupported_frameworks.append("jax")
_frameworks.remove("jax")
if _unsupported_frameworks:
raise ValueError(f"ROCm is not supported by requested frameworks: {_unsupported_frameworks}")
return _frameworks
@@ -293,6 +356,41 @@ def copy_common_headers(
shutil.copy(path, new_path)
def hipify(base_dir, src_dir, sources, include_dirs):
hipify_path = base_dir / "3rdparty" / "hipify_torch"
cwd = os.getcwd()
os.chdir(hipify_path)
from hipify_torch.hipify_python import hipify as do_hipify
os.chdir(cwd)
hipify_result = do_hipify(
project_directory=src_dir,
output_directory=src_dir,
includes=["*"],
ignores=["*/amd_detail/*", "*/aotriton/*", "*/ck_fused_attn/*"],
header_include_dirs=include_dirs,
custom_map_list=base_dir / "hipify_custom_map.json",
extra_files=[],
is_pytorch_extension=True,
hipify_extra_files_only=False,
show_detailed=False)
# Because hipify output_directory == project_directory,
# the original sources list may contain results of a previous hipify run and end up with duplicated entries.
# Keep unique entries only.
hipified_sources = set()
for fname in sources:
fname = os.path.abspath(str(fname))
if fname in hipify_result:
file_result = hipify_result[fname]
if file_result.hipified_path is not None:
fname = hipify_result[fname].hipified_path
# setup() arguments must *always* be /-separated paths relative to the setup.py directory,
# *never* absolute paths
hipified_sources.add(os.path.relpath(fname, cwd))
return list(hipified_sources)
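A rough sketch, with made-up paths and a stand-in result object, of what the de-duplication loop above does with the hipify output (the real result objects come from hipify_torch):

from types import SimpleNamespace
import os
# Hypothetical hipify result: foo.cu was rewritten, bar.cpp was left untouched.
hipify_result = {
    "/repo/transformer_engine/foo.cu": SimpleNamespace(hipified_path="/repo/transformer_engine/foo_hip.cu"),
    "/repo/transformer_engine/bar.cpp": SimpleNamespace(hipified_path=None),
}
cwd = "/repo"
sources = ["/repo/transformer_engine/foo.cu", "/repo/transformer_engine/bar.cpp"]
hipified = {os.path.relpath(hipify_result[s].hipified_path or s, cwd) for s in sources}
# -> {"transformer_engine/foo_hip.cu", "transformer_engine/bar.cpp"}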
def install_and_import(package):
"""Install a package via pip (if not already installed) and import into globals."""
main_package = package.split("[")[0]
...
{
"custom_map" : {
"<cuda_bf16.h>" : "<hip/hip_bf16.h>",
"<cuda_fp8.h>" : "\"amd_detail/hip_float8.h\"",
"CUfunc_cache" : "hipFuncCache_t",
"<nvtx3/nvToolsExt.h>" : "<roctracer/roctx.h>",
"cudaLaunchKernelExC" : "hipLaunchKernelExC",
"cudaLaunchConfig_t" : "hipLaunchConfig_t",
"cudaLaunchAttributeClusterDimension" : "hipLaunchAttributeClusterDimension",
"cudaLaunchAttributeCooperative" : "hipLaunchAttributeCooperative",
"cudaLaunchAttribute": "hipLaunchAttribute"
}
}
\ No newline at end of file
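Roughly, each entry maps a CUDA token to its HIP replacement during hipification; a toy illustration of the effect (the real substitution is performed by hipify_torch, not by this snippet):

# Hypothetical illustration of how a custom-map entry rewrites one source line.
custom_map = {
    "<cuda_bf16.h>": "<hip/hip_bf16.h>",
    "cudaLaunchConfig_t": "hipLaunchConfig_t",
}
line = "#include <cuda_bf16.h>  // cudaLaunchConfig_t config;"
for cuda_token, hip_token in custom_map.items():
    line = line.replace(cuda_token, hip_token)
print(line)  # #include <hip/hip_bf16.h>  // hipLaunchConfig_t config;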
@@ -16,6 +16,7 @@ from wheel.bdist_wheel import bdist_wheel
from build_tools.build_ext import CMakeExtension, get_build_ext
from build_tools.te_version import te_version
from build_tools.utils import (
rocm_build,
cuda_archs,
found_cmake,
found_ninja,
@@ -57,6 +58,9 @@ class TimedBdist(bdist_wheel):
def setup_common_extension() -> CMakeExtension:
"""Setup CMake extension for common library"""
if rocm_build():
cmake_flags = []
else:
cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)] cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)]
if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
assert ( assert (
@@ -69,6 +73,11 @@ def setup_common_extension() -> CMakeExtension:
# Project directory root
root_path = Path(__file__).resolve().parent
if rocm_build():
if os.getenv("NVTE_USE_HIPBLASLT") is not None:
cmake_flags.append("-DUSE_HIPBLASLT=ON")
if os.getenv("NVTE_USE_ROCBLAS") is not None:
cmake_flags.append("-DUSE_ROCBLAS=ON")
return CMakeExtension(
name="transformer_engine",
...
@@ -4,20 +4,59 @@
cmake_minimum_required(VERSION 3.18)
option(USE_CUDA "Use CUDA" ON)
option(USE_ROCM "Use ROCm" OFF)
if(((EXISTS "/opt/dtk/") OR (EXISTS $ENV{ROCM_PATH})) AND NOT (EXISTS "/bin/nvcc"))
message("hcu detected.")
set(USE_ROCM ON)
set(USE_CUDA OFF)
# Add HIP to the CMAKE Module Path
# set(CMAKE_MODULE_PATH ${HIP_PATH}/cmake ${CMAKE_MODULE_PATH})
# Disable Asserts In Code (Can't use asserts on HIP stack.)
add_definitions(-DNDEBUG)
add_definitions(-DUSE_ROCM)
if(NOT DEFINED ENV{NVTE_ROCM_ARCH})
SET(CMAKE_HIP_ARCHITECTURES gfx906;gfx926;gfx928;gfx936)
else()
SET(CMAKE_HIP_ARCHITECTURES $ENV{NVTE_ROCM_ARCH})
endif()
else()
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
else ()
set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90)
endif()
endif()
endif()
set(message_line
"-------------------------------------------------------------")
message("${message_line}")
message(STATUS "USE_CUDA ${USE_CUDA}")
message(STATUS "USE_ROCM ${USE_ROCM}")
if(USE_CUDA)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
project(transformer_engine_tests LANGUAGES CUDA CXX)
else()
set(CMAKE_CXX_STANDARD 17)
project(transformer_engine_tests LANGUAGES HIP CXX)
# Ask hcc to generate device code during compilation so we can use
# host linker to link.
set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -fno-gpu-rdc -Wno-defaulted-function-deleted")
foreach(rocm_arch ${CMAKE_HIP_ARCHITECTURES})
# if CMAKE_CXX_FLAGS has --offload-arch set already, better to rm first
set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} --offload-arch=${rocm_arch}")
endforeach()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${HIP_HCC_FLAGS}")
endif()
add_subdirectory(../../3rdparty/googletest ${PROJECT_BINARY_DIR}/googletest)
@@ -37,8 +76,12 @@ include_directories(../../transformer_engine/common/include)
include_directories(../../transformer_engine/common)
include_directories(${CMAKE_SOURCE_DIR})
if(USE_CUDA)
find_package(CUDAToolkit REQUIRED)
include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)
else()
find_package(hip REQUIRED)
endif()
add_subdirectory(operator)
add_subdirectory(util)
@@ -2,7 +2,7 @@
#
# See LICENSE for license information.
list(APPEND test_cuda_sources
test_cast.cu
test_cast_current_scaling.cu
test_cast_dbias.cu
@@ -26,12 +26,52 @@ add_executable(test_operator
test_causal_softmax.cu
test_swizzle.cu
../test_common.cu)
if(USE_ROCM)
list(APPEND test_cuda_sources
test_cublaslt_gemm.cu)
endif()
if(USE_CUDA)
add_executable(test_operator ${test_cuda_sources})
else()
message("${message_line}")
message(STATUS "CMAKE_CURRENT_SOURCE_DIR: ${CMAKE_CURRENT_SOURCE_DIR}")
message(STATUS "PROJECT_SOURCE_DIR: ${PROJECT_SOURCE_DIR}")
set(TE ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
set(THIRDPARTY ${TE}/3rdparty)
list(APPEND CMAKE_MODULE_PATH "${THIRDPARTY}/hipify_torch/cmake")
include(Hipify)
message(STATUS "CMAKE_MODULE_PATH: ${CMAKE_MODULE_PATH}")
file(REAL_PATH ../../../transformer_engine/common/include header_include_dir1)
file(REAL_PATH ../../../transformer_engine/common header_include_dir2)
set(header_include_dir ${header_include_dir1} ${header_include_dir2})
message(STATUS "CUDA_SOURCE_DIR: ${PROJECT_SOURCE_DIR}")
message(STATUS "HEADER_INCLUDE_DIR: ${header_include_dir}")
set(cuda_source_dir ${PROJECT_SOURCE_DIR} )
hipify(CUDA_SOURCE_DIR ${cuda_source_dir}
HEADER_INCLUDE_DIR ${header_include_dir}
CUSTOM_MAP_FILE "${TE}/hipify_custom_map.json"
)
get_hipified_list("${test_cuda_sources}" test_hip_sources)
message("${message_line}")
message(STATUS "nvte tests hipified sources: ${test_hip_sources}")
add_executable(test_operator ${test_hip_sources})
endif()
find_package(OpenMP REQUIRED)
if(USE_CUDA)
list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn)
target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS} OpenMP::OpenMP_CXX)
target_compile_options(test_operator PRIVATE -O2 -fopenmp)
else()
target_link_libraries(test_operator PUBLIC hip::host hip::device GTest::gtest_main ${TE_LIB} OpenMP::OpenMP_CXX)
target_compile_options(test_operator PRIVATE -O2 -fopenmp)
endif()
include(GoogleTest)
gtest_discover_tests(test_operator DISCOVERY_TIMEOUT 600)
/*************************************************************************
* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
*
* License for AMD contributions = MIT. See LICENSE for more information
************************************************************************/
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>
#include <gtest/gtest.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <memory>
#include <iostream>
#include <iomanip>
#include <random>
#include <cstring>
#include <cmath>
#include "../test_common.h"
using namespace transformer_engine;
using namespace test;
namespace {
//m, k, n
std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes = {
{2304, 768, 4096},
{768, 768, 4096},
{768, 3072, 4096},
{229, 541, 541}, //primes
{71, 71, 3571}, //primes
{29, 29, 17389}, //primes
};
// A, B, Bias, Gelu, D
// Bias type is chosen as bf16 when FP8 inputs are used, and as D_type otherwise
// Gelu type is the same as Bias_Type
// {DType::kFloat32, DType::kFloat32, DType::kFloat32, DType::kFloat32, DType::kFloat32},
// {DType::kFloat16, DType::kFloat16, DType::kFloat16, DType::kFloat16, DType::kFloat16},
// {DType::kBFloat16, DType::kBFloat16, DType::kBFloat16, DType::kBFloat16, DType::kBFloat16},
// {DType::kFloat8E4M3, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kFloat32},
// {DType::kFloat8E4M3, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kFloat16},
// {DType::kFloat8E4M3, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kBFloat16},
// {DType::kFloat8E4M3, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kFloat8E4M3},
// {DType::kFloat8E4M3, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kFloat8E5M2},
// {DType::kFloat8E4M3, DType::kFloat8E5M2, DType::kBFloat16, DType::kBFloat16, DType::kFloat32},
// {DType::kFloat8E4M3, DType::kFloat8E5M2, DType::kBFloat16, DType::kBFloat16, DType::kFloat16},
// {DType::kFloat8E4M3, DType::kFloat8E5M2, DType::kBFloat16, DType::kBFloat16, DType::kBFloat16},
// {DType::kFloat8E4M3, DType::kFloat8E5M2, DType::kBFloat16, DType::kBFloat16, DType::kFloat8E4M3},
// {DType::kFloat8E4M3, DType::kFloat8E5M2, DType::kBFloat16, DType::kBFloat16, DType::kFloat8E5M2},
// {DType::kFloat8E5M2, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kFloat32},
// {DType::kFloat8E5M2, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kFloat16},
// {DType::kFloat8E5M2, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kBFloat16},
// {DType::kFloat8E5M2, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kFloat8E4M3},
// {DType::kFloat8E5M2, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kFloat8E5M2},
} // namespace
// <A_type, B_type, Bias_Type, Gelu_Type D_type>, <m, k, n>
class GEMMTestSuite
:public ::testing::TestWithParam<std::tuple<
std::tuple<size_t, size_t, size_t>, bool, bool>>{};
float ref_gelu(float x){
float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x))));
return x * cdf;
}
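For reference, the constant 0.7978845608028654 above is sqrt(2/pi), so ref_gelu implements the standard tanh approximation of GELU:

\mathrm{GELU}(x) \approx \frac{x}{2}\left(1 + \tanh\!\left(\sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right)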
template <typename A_Type, typename B_Type, typename Bias_Type, typename Gelu_Type, typename D_Type>
void compute_ref(
const A_Type* a_data,
const B_Type* b_data,
const float a_scale_inv,
const float b_scale_inv,
const Bias_Type* bias_data, //bias is of dim m
const float d_scale,
size_t m, size_t k, size_t n,
D_Type* ref_d_data,
float* ref_d_amax,
Gelu_Type* ref_gelu_data){
*ref_d_amax = 0;
for(size_t ii = 0; ii < m; ii++){
for(size_t jj = 0; jj < n; jj++){
float val = 0;
for(size_t kk = 0; kk < k; kk++){
val += a_scale_inv*b_scale_inv*((float)a_data[ii + kk*m])*((float)b_data[kk + jj*k]);
}
if(bias_data){
val += (float)bias_data[ii];
}
if(ref_gelu_data){
ref_gelu_data[ii + jj*m] = (Gelu_Type)(val);
val = ref_gelu(val);
}
ref_d_data[ii+jj*m] = (D_Type)(val*d_scale);
// update ref_d_amax if in fp8
DType dtype = TypeInfo<D_Type>::dtype;
if(isFp8Type(dtype)){
*ref_d_amax = std::max<float>(*ref_d_amax, std::fabs(val));
}
}
}
}
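In index form, with A stored column-major as m-by-k and B as k-by-n (matching the loops above), the reference output is

D_{ij} = d_{\mathrm{scale}} \cdot g\!\left(\sum_{l=0}^{k-1} a_{\mathrm{scale\_inv}}\, b_{\mathrm{scale\_inv}}\, A_{il}\, B_{lj} + \mathrm{bias}_i\right),

where g is the tanh-approximated GELU when use_gelu is set and the identity otherwise; the pre-GELU buffer stores the argument of g, and for FP8 outputs ref_d_amax tracks the largest |g(...)| before the output scale is applied.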
template <typename A_Type, typename B_Type, typename Bias_Type, typename Gelu_Type, typename D_Type>
void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, const size_t n) {
DType atype = TypeInfo<A_Type>::dtype;
DType btype = TypeInfo<B_Type>::dtype;
DType bias_type = TypeInfo<Bias_Type>::dtype;
DType gelu_type = TypeInfo<Gelu_Type>::dtype;
DType dtype = TypeInfo<D_Type>::dtype;
// pytorch tensor storage is row-major while cublas/rocblas is column-major
Tensor A({ k, m }, atype);
Tensor B({ n, k }, btype);
Tensor D({ n, m }, dtype);
Tensor bias;
if(use_bias){
bias = Tensor({m}, bias_type);
}
Tensor pre_gelu_out;
if(use_gelu){
pre_gelu_out = Tensor({ n, m }, gelu_type);
}
//initialize the data and scale inv of A, B
fillUniform(&A);
fillUniform(&B);
if(use_bias){
fillUniform(&bias);
}
//initialize the scale of D
if(isFp8Type(dtype)){
setRandomScale(&D);
}
bool transa = false;
bool transb = false;
bool grad = false;
bool accumulate = false;
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
#ifdef __HIP_PLATFORM_AMD__
if ((isFp8Type(atype) || isFp8Type(btype)) &&
!(prop.major == 9 && prop.minor >= 4))
{
GTEST_SKIP() << "FP8 is not supported on this HW";
}
#endif
Tensor Workspace({ 33554432 }, DType::kByte);
//perform the GEMM on the GPU
nvte_cublas_gemm(A.data(),
B.data(),
D.data(),
bias.data(),
pre_gelu_out.data(),
transa,
transb,
grad,
Workspace.data(),
accumulate,
false,
prop.multiProcessorCount,
//default stream
0);
//copy the output results from GPU memory to CPU memory
D.to_cpu();
if(use_gelu){
pre_gelu_out.to_cpu();
}
//perform the reference GEMM on the CPU
std::unique_ptr<D_Type[]> ref_D = std::make_unique<D_Type[]>(m*n);
std::unique_ptr<Gelu_Type[]> ref_pre_gelu_out;
if(use_gelu){
ref_pre_gelu_out = std::make_unique<Gelu_Type[]>(m*n);
}
float ref_amax_d;
compute_ref<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(
A.cpu_dptr<A_Type>(),
B.cpu_dptr<B_Type>(),
A.scale_inv(),
B.scale_inv(),
use_bias? bias.cpu_dptr<Bias_Type>(): nullptr,
D.scale(),
m, k, n,
ref_D.get(),
&ref_amax_d,
use_gelu? ref_pre_gelu_out.get(): nullptr);
// check that no GPU error occurred during the run
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
//compare results
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
if (isFp8Type(dtype)) {
compareResults("D_amax", D.amax(), ref_amax_d, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(dtype);
//relax for certain prime number gemm
if (dtype == DType::kFloat32) {
atol = 1e-5;
}
compareResults("D", D, ref_D.get(), atol, rtol);
if(use_gelu){
auto [atol, rtol] = getTolerances(gelu_type);
//relax for certain prime number gemm
if (dtype == DType::kFloat32) {
atol = 5e-6;
}
compareResults("gelu", pre_gelu_out, ref_pre_gelu_out.get(), atol, rtol);
}
}
using fp32=float;
using fp8=fp8e4m3;
using bf8=fp8e5m2;
TEST_P(GEMMTestSuite, Testfp32xfp32xfp32xfp32xfp32) {
using namespace transformer_engine;
using namespace test;
const size_t m = std::get<0>(std::get<0>(GetParam()));
const size_t k = std::get<1>(std::get<0>(GetParam()));
const size_t n = std::get<2>(std::get<0>(GetParam()));
const bool use_bias = std::get<1>(GetParam());
const bool use_gelu = std::get<2>(GetParam());
using A_Type = fp32;
using B_Type = fp32;
using Bias_Type = fp32;
using Gelu_Type = fp32;
using D_Type = fp32;
performTest<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(use_bias, use_gelu, m, k, n);
}
TEST_P(GEMMTestSuite, Testfp16xfp16xfp16xfp16xfp16) {
using namespace transformer_engine;
using namespace test;
const size_t m = std::get<0>(std::get<0>(GetParam()));
const size_t k = std::get<1>(std::get<0>(GetParam()));
const size_t n = std::get<2>(std::get<0>(GetParam()));
const bool use_bias = std::get<1>(GetParam());
const bool use_gelu = std::get<2>(GetParam());
using A_Type = fp16;
using B_Type = fp16;
using Bias_Type = fp16;
using Gelu_Type = fp16;
using D_Type = fp16;
performTest<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(use_bias, use_gelu, m, k, n);
}
TEST_P(GEMMTestSuite, Testbf16xbf16xbf16xbf16xbf16) {
using namespace transformer_engine;
using namespace test;
const size_t m = std::get<0>(std::get<0>(GetParam()));
const size_t k = std::get<1>(std::get<0>(GetParam()));
const size_t n = std::get<2>(std::get<0>(GetParam()));
const bool use_bias = std::get<1>(GetParam());
const bool use_gelu = std::get<2>(GetParam());
using A_Type = bf16;
using B_Type = bf16;
using Bias_Type = bf16;
using Gelu_Type = bf16;
using D_Type = bf16;
performTest<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(use_bias, use_gelu, m, k, n);
}
TEST_P(GEMMTestSuite, Testfp8xfp8xbf16xbf16xfp32) {
using namespace transformer_engine;
using namespace test;
const size_t m = std::get<0>(std::get<0>(GetParam()));
const size_t k = std::get<1>(std::get<0>(GetParam()));
const size_t n = std::get<2>(std::get<0>(GetParam()));
const bool use_bias = std::get<1>(GetParam());
const bool use_gelu = std::get<2>(GetParam());
using A_Type = fp8;
using B_Type = fp8;
using Bias_Type = bf16;
using Gelu_Type = bf16;
using D_Type = fp32;
performTest<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(use_bias, use_gelu, m, k, n);
}
TEST_P(GEMMTestSuite, Testfp8xfp8xbf16xbf16xfp16) {
using namespace transformer_engine;
using namespace test;
const size_t m = std::get<0>(std::get<0>(GetParam()));
const size_t k = std::get<1>(std::get<0>(GetParam()));
const size_t n = std::get<2>(std::get<0>(GetParam()));
const bool use_bias = std::get<1>(GetParam());
const bool use_gelu = std::get<2>(GetParam());
using A_Type = fp8;
using B_Type = fp8;
using Bias_Type = bf16;
using Gelu_Type = bf16;
using D_Type = fp16;
performTest<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(use_bias, use_gelu, m, k, n);
}
TEST_P(GEMMTestSuite, Testfp8xfp8xbf16xbf16xbf16) {
using namespace transformer_engine;
using namespace test;
const size_t m = std::get<0>(std::get<0>(GetParam()));
const size_t k = std::get<1>(std::get<0>(GetParam()));
const size_t n = std::get<2>(std::get<0>(GetParam()));
const bool use_bias = std::get<1>(GetParam());
const bool use_gelu = std::get<2>(GetParam());
using A_Type = fp8;
using B_Type = fp8;
using Bias_Type = bf16;
using Gelu_Type = bf16;
using D_Type = bf16;
performTest<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(use_bias, use_gelu, m, k, n);
}
TEST_P(GEMMTestSuite, Testfp8xfp8xbf16xbf16xfp8) {
using namespace transformer_engine;
using namespace test;
const size_t m = std::get<0>(std::get<0>(GetParam()));
const size_t k = std::get<1>(std::get<0>(GetParam()));
const size_t n = std::get<2>(std::get<0>(GetParam()));
const bool use_bias = std::get<1>(GetParam());
const bool use_gelu = std::get<2>(GetParam());
using A_Type = fp8;
using B_Type = fp8;
using Bias_Type = bf16;
using Gelu_Type = bf16;
using D_Type = fp8;
performTest<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(use_bias, use_gelu, m, k, n);
}
TEST_P(GEMMTestSuite, Testfp8xfp8xbf16xbf16xbf8) {
using namespace transformer_engine;
using namespace test;
const size_t m = std::get<0>(std::get<0>(GetParam()));
const size_t k = std::get<1>(std::get<0>(GetParam()));
const size_t n = std::get<2>(std::get<0>(GetParam()));
const bool use_bias = std::get<1>(GetParam());
const bool use_gelu = std::get<2>(GetParam());
using A_Type = fp8;
using B_Type = fp8;
using Bias_Type = bf16;
using Gelu_Type = bf16;
using D_Type = bf8;
performTest<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(use_bias, use_gelu, m, k, n);
}
TEST_P(GEMMTestSuite, Testfp8xbf8xbf16xbf16xfp32) {
using namespace transformer_engine;
using namespace test;
const size_t m = std::get<0>(std::get<0>(GetParam()));
const size_t k = std::get<1>(std::get<0>(GetParam()));
const size_t n = std::get<2>(std::get<0>(GetParam()));
const bool use_bias = std::get<1>(GetParam());
const bool use_gelu = std::get<2>(GetParam());
using A_Type = fp8;
using B_Type = bf8;
using Bias_Type = bf16;
using Gelu_Type = bf16;
using D_Type = fp32;
performTest<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(use_bias, use_gelu, m, k, n);
}
TEST_P(GEMMTestSuite, Testfp8xbf8xbf16xbf16xfp16) {
using namespace transformer_engine;
using namespace test;
const size_t m = std::get<0>(std::get<0>(GetParam()));
const size_t k = std::get<1>(std::get<0>(GetParam()));
const size_t n = std::get<2>(std::get<0>(GetParam()));
const bool use_bias = std::get<1>(GetParam());
const bool use_gelu = std::get<2>(GetParam());
using A_Type = fp8;
using B_Type = bf8;
using Bias_Type = bf16;
using Gelu_Type = bf16;
using D_Type = fp16;
performTest<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(use_bias, use_gelu, m, k, n);
}
TEST_P(GEMMTestSuite, Testfp8xbf8xbf16xbf16xbf16) {
using namespace transformer_engine;
using namespace test;
const size_t m = std::get<0>(std::get<0>(GetParam()));
const size_t k = std::get<1>(std::get<0>(GetParam()));
const size_t n = std::get<2>(std::get<0>(GetParam()));
const bool use_bias = std::get<1>(GetParam());
const bool use_gelu = std::get<2>(GetParam());
using A_Type = fp8;
using B_Type = bf8;
using Bias_Type = bf16;
using Gelu_Type = bf16;
using D_Type = bf16;
performTest<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(use_bias, use_gelu, m, k, n);
}
TEST_P(GEMMTestSuite, Testfp8xbf8xbf16xbf16xfp8) {
using namespace transformer_engine;
using namespace test;
const size_t m = std::get<0>(std::get<0>(GetParam()));
const size_t k = std::get<1>(std::get<0>(GetParam()));
const size_t n = std::get<2>(std::get<0>(GetParam()));
const bool use_bias = std::get<1>(GetParam());
const bool use_gelu = std::get<2>(GetParam());
using A_Type = fp8;
using B_Type = bf8;
using Bias_Type = bf16;
using Gelu_Type = bf16;
using D_Type = fp8;
performTest<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(use_bias, use_gelu, m, k, n);
}
TEST_P(GEMMTestSuite, Testfp8xbf8xbf16xbf16xbf8) {
using namespace transformer_engine;
using namespace test;
const size_t m = std::get<0>(std::get<0>(GetParam()));
const size_t k = std::get<1>(std::get<0>(GetParam()));
const size_t n = std::get<2>(std::get<0>(GetParam()));
const bool use_bias = std::get<1>(GetParam());
const bool use_gelu = std::get<2>(GetParam());
using A_Type = fp8;
using B_Type = bf8;
using Bias_Type = bf16;
using Gelu_Type = bf16;
using D_Type = bf8;
performTest<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(use_bias, use_gelu, m, k, n);
}
TEST_P(GEMMTestSuite, Testbf8xfp8xbf16xbf16xfp32) {
using namespace transformer_engine;
using namespace test;
const size_t m = std::get<0>(std::get<0>(GetParam()));
const size_t k = std::get<1>(std::get<0>(GetParam()));
const size_t n = std::get<2>(std::get<0>(GetParam()));
const bool use_bias = std::get<1>(GetParam());
const bool use_gelu = std::get<2>(GetParam());
using A_Type = bf8;
using B_Type = fp8;
using Bias_Type = bf16;
using Gelu_Type = bf16;
using D_Type = fp32;
performTest<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(use_bias, use_gelu, m, k, n);
}
TEST_P(GEMMTestSuite, Testbf8xfp8xbf16xbf16xfp16) {
using namespace transformer_engine;
using namespace test;
const size_t m = std::get<0>(std::get<0>(GetParam()));
const size_t k = std::get<1>(std::get<0>(GetParam()));
const size_t n = std::get<2>(std::get<0>(GetParam()));
const bool use_bias = std::get<1>(GetParam());
const bool use_gelu = std::get<2>(GetParam());
using A_Type = bf8;
using B_Type = fp8;
using Bias_Type = bf16;
using Gelu_Type = bf16;
using D_Type = fp16;
performTest<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(use_bias, use_gelu, m, k, n);
}
TEST_P(GEMMTestSuite, Testbf8xfp8xbf16xbf16xbf16) {
using namespace transformer_engine;
using namespace test;
const size_t m = std::get<0>(std::get<0>(GetParam()));
const size_t k = std::get<1>(std::get<0>(GetParam()));
const size_t n = std::get<2>(std::get<0>(GetParam()));
const bool use_bias = std::get<1>(GetParam());
const bool use_gelu = std::get<2>(GetParam());
using A_Type = bf8;
using B_Type = fp8;
using Bias_Type = bf16;
using Gelu_Type = bf16;
using D_Type = bf16;
performTest<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(use_bias, use_gelu, m, k, n);
}
TEST_P(GEMMTestSuite, Testbf8xfp8xbf16xbf16xfp8) {
using namespace transformer_engine;
using namespace test;
const size_t m = std::get<0>(std::get<0>(GetParam()));
const size_t k = std::get<1>(std::get<0>(GetParam()));
const size_t n = std::get<2>(std::get<0>(GetParam()));
const bool use_bias = std::get<1>(GetParam());
const bool use_gelu = std::get<2>(GetParam());
using A_Type = bf8;
using B_Type = fp8;
using Bias_Type = bf16;
using Gelu_Type = bf16;
using D_Type = fp8;
performTest<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(use_bias, use_gelu, m, k, n);
}
TEST_P(GEMMTestSuite, Testbf8xfp8xbf16xbf16xbf8) {
using namespace transformer_engine;
using namespace test;
const size_t m = std::get<0>(std::get<0>(GetParam()));
const size_t k = std::get<1>(std::get<0>(GetParam()));
const size_t n = std::get<2>(std::get<0>(GetParam()));
const bool use_bias = std::get<1>(GetParam());
const bool use_gelu = std::get<2>(GetParam());
using A_Type = bf8;
using B_Type = fp8;
using Bias_Type = bf16;
using Gelu_Type = bf16;
using D_Type = bf8;
performTest<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(use_bias, use_gelu, m, k, n);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
GEMMTestSuite,
::testing::Combine(
::testing::ValuesIn(test_case_sizes),
::testing::Values(false, true), //use bias
::testing::Values(false, true)), //use_gelu
[](const testing::TestParamInfo<GEMMTestSuite::ParamType>& info) {
std::string name = std::to_string(std::get<0>(std::get<0>(info.param))) + "X" +
std::to_string(std::get<1>(std::get<0>(info.param))) + "X" +
std::to_string(std::get<2>(std::get<0>(info.param))) + "X" +
std::to_string(std::get<1>(info.param)) + "X" +
std::to_string(std::get<2>(info.param));
return name;
});
\ No newline at end of file
@@ -55,7 +55,11 @@ void compute_ref_stats(NormType norm_type,
current = static_cast<compute_t>(data[i * H + j]);
sum_sq += (current - m) * (current - m);
}
#ifdef __HIP_PLATFORM_AMD__
rsigma[i] = 1.0/sqrtf((sum_sq / H) + epsilon);
#else
rsigma[i] = rsqrtf((sum_sq / H) + epsilon);
#endif
}
}
...
@@ -481,8 +481,13 @@ void compareResults_sequential(const std::string &name, const Tensor &test,
const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>();
const T *ref_data = reinterpret_cast<const T*>(ref);
for (size_t i = 0; i < N; ++i) {
#ifndef __HIP_PLATFORM_AMD__
double t = static_cast<double>(test_data[i]);
double r = static_cast<double>(ref_data[i]);
#else
double t = static_cast<double>(static_cast<float>(test_data[i]));
double r = static_cast<double>(static_cast<float>(ref_data[i]));
#endif
bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
/* For Float32 the floating point comparison is enough to error out */
bool assertion = mismatch && test.dtype() == DType::kFloat32;
@@ -492,9 +497,19 @@ void compareResults_sequential(const std::string &name, const Tensor &test,
const double mean = (t + r) / 2;
const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
#ifndef __HIP_PLATFORM_AMD__
const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
#else
const double cast_mean_p = static_cast<double>(static_cast<float>(static_cast<T>(static_cast<float>(mean_p))));
const double cast_mean_m = static_cast<double>(static_cast<float>(static_cast<T>(static_cast<float>(mean_m))));
#endif
#ifdef __HIP_PLATFORM_AMD__
assertion = !(cast_mean_m == std::min<double>(t,r) && cast_mean_p == std::max<double>(t,r));
#else
assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
#endif
}
std::string direction = rowwise ? "rowwise" : "columnwise";
ASSERT_FALSE(assertion) << "Error in tensor " << name << " in "
@@ -518,8 +533,14 @@ static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, con
continue;
}
#ifndef __HIP_PLATFORM_AMD__
double t = static_cast<double>(test_data[i]);
double r = static_cast<double>(ref_data[i]);
#else
double t = static_cast<double>(static_cast<float>(test_data[i]));
double r = static_cast<double>(static_cast<float>(ref_data[i]));
#endif
bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
/* For Float32 the floating point comparison is enough to error out */
@@ -530,9 +551,19 @@ static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, con
const double mean = (t + r) / 2;
const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
#ifndef __HIP_PLATFORM_AMD__
const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
#else
const double cast_mean_p = static_cast<double>(static_cast<float>(static_cast<T>(static_cast<float>(mean_p))));
const double cast_mean_m = static_cast<double>(static_cast<float>(static_cast<T>(static_cast<float>(mean_m))));
#endif
#ifdef __HIP_PLATFORM_AMD__
assertion = !(cast_mean_m == std::min<double>(t,r) && cast_mean_p == std::max<double>(t,r));
#else
assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
#endif
}
if (assertion && i < first_mismatch_idx) {
first_mismatch_idx = i;
...
@@ -11,10 +11,16 @@
#include <array>
#include <random>
#include <cuda_runtime_api.h>
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#else
#include <hip/hip_bf16.h>
#include "amd_detail/hip_float8.h"
#endif
#include <cuda_fp16.h>
#include <transformer_engine/transformer_engine.h>
#include "util/logging.h"
@@ -50,9 +56,15 @@ using int32 = int32_t;
using int64 = int64_t;
using fp32 = float;
using fp16 = half;
#ifndef USE_ROCM
using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
#else
using bf16 = __hip_bfloat16;
using fp8e4m3 = hip_f8<hip_f8_type::fp8>;
using fp8e5m2 = hip_f8<hip_f8_type::bf8>;
#endif
using fp8e8m0 = uint8_t;
template <typename T>
...
@@ -24,6 +24,7 @@ from transformer_engine.common.recipe import (
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer
from run_layer_with_overlap import _compare_tensors
from torch.utils.cpp_extension import IS_HIP_EXTENSION
SEQ_LEN, BATCH_SIZE = 16, 16
HIDDEN_SIZE = 64
...
@@ -27,7 +27,7 @@ import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex
from torch.utils.cpp_extension import IS_HIP_EXTENSION
# Check what quantization schemes are supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -687,7 +687,7 @@ def _test_fp8_scale_update(
"""Expected absmax and FP8 scale"""
amax = ref.abs().amax()
max_val = {
"forward": 448.0, "forward": 448.0 if not IS_HIP_EXTENSION else 240.0,
"backward": 57344.0, "backward": 57344.0,
}[stage] }[stage]
scale = (max_val / amax) / (2**margin) scale = (max_val / amax) / (2**margin)
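As background (not stated in the diff itself): the 240.0 cap on the ROCm path reflects the FNUZ variant of FP8 E4M3 used on these devices, whose largest finite value differs from the OCP E4M3 format assumed on the CUDA path:

\max|x|_{\mathrm{E4M3\ (OCP)}} = 2^{8}\times 1.75 = 448, \qquad \max|x|_{\mathrm{E4M3FNUZ}} = 2^{7}\times 1.875 = 240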
...
@@ -12,6 +12,7 @@ from contextlib import contextmanager
import pytest
import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.common import recipe
from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init
@@ -387,8 +388,24 @@ def test_dpa_checkpoint(dtype, model_configs, model):
"""Test DotProductAttention module with checkpointing"""
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)
if IS_HIP_EXTENSION:
model_configs_mla = {
# test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
"mla_1_0": ModelConfig(
8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias", head_dim_v=128
), # self , 0
"mla_1_1": ModelConfig(
4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128
), # cross, 0
"mla_2_0": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64
), # self , 1
"mla_2_1": ModelConfig(
1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64
), # cross, 1
}
else:
model_configs_mla = {
# test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
"mla_1_0": ModelConfig(
8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias", head_dim_v=128
@@ -408,10 +425,10 @@ model_configs_mla = {
"mla_3_1": ModelConfig(
8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128
), # inference
}
@pytest.mark.skipif(not IS_HIP_EXTENSION and get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_mla])
@pytest.mark.parametrize("model", model_configs_mla.keys())
@@ -592,7 +609,7 @@ model_configs_swa = {
}
@pytest.mark.skipif((not IS_HIP_EXTENSION) and (not FlashAttentionUtils.v2_3_plus), reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_swa])
@pytest.mark.parametrize("model", model_configs_swa.keys())
@@ -614,7 +631,7 @@ model_configs_alibi_slopes = {
}
@pytest.mark.skipif((not IS_HIP_EXTENSION) and (not FlashAttentionUtils.v2_3_plus), reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
@pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
@@ -1130,11 +1147,16 @@ def test_transformer_layer(
tols = dict(atol=5e-2, rtol=5e-2)
workspace_opt = True
qkv_layout="sbh3d" if fused_qkv_params else "sb3hd"
# override the qkv_layout in mqa gqa mode in ROCm TE
if IS_HIP_EXTENSION and model_configs[model].num_gqa_groups != model_configs[model].num_heads:
qkv_layout = "sbhd_sbhd_sbhd"
# Test backend availability
available_backends, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
@@ -1434,7 +1456,7 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
)
)
@pytest.mark.skipif(IS_HIP_EXTENSION, reason="FP8 Fused attention is not supported on ROCm")
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@@ -1641,7 +1663,7 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoP
return out, param_names, tuple(x.grad for x in params)
return out, param_names, tuple(None for x in params)
@pytest.mark.skipif(IS_HIP_EXTENSION, reason="FP8 Fused attention is not supported on ROCm")
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@@ -1900,7 +1922,7 @@ cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1"))
models_v0 = ["fp8_1", "fp8_2", "fp8_5", "fp8_6"]
models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]
@pytest.mark.skipif(IS_HIP_EXTENSION, reason="FP8 Fused attention is not supported on ROCm")
@pytest.mark.skipif(
(
get_cudnn_version() < (8, 9, 3)
...
@@ -13,6 +13,7 @@ from transformer_engine.pytorch.utils import (
)
from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils
from test_fused_attn import ModelConfig
from torch.utils.cpp_extension import IS_HIP_EXTENSION
model_configs_flash_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
@@ -51,7 +52,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
@pytest.mark.skipif(not FlashAttentionUtils.v2_plus, reason="Flash-attn 2.0+ is required.")
@pytest.mark.skipif(not IS_HIP_EXTENSION and get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
@@ -111,7 +112,7 @@ model_configs_fused_attn = {
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.")
@pytest.mark.skipif(IS_HIP_EXTENSION or get_device_compute_capability() < (8, 0), reason="DTK does not support fused attn for now; CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
...
@@ -23,7 +23,10 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common import recipe
from torch.utils.cpp_extension import IS_HIP_EXTENSION
if IS_HIP_EXTENSION:
import os
from functools import cache
# Check if FP8 is supported. # Check if FP8 is supported.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
...@@ -73,6 +76,15 @@ def reset_global_fp8_state(): ...@@ -73,6 +76,15 @@ def reset_global_fp8_state():
yield yield
FP8GlobalStateManager.reset() FP8GlobalStateManager.reset()
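# ROCm only: CUDA graph tests are exercised with the hipBLASLt backend; the rocBLAS path does not support graph capture.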
if IS_HIP_EXTENSION:
@cache
def use_hipblaslt() -> bool:
return (os.getenv("NVTE_USE_HIPBLASLT") is not None
or os.getenv("NVTE_USE_ROCBLAS") is None )
@pytest.fixture(autouse=True)
def skip_rocblas():
if not use_hipblaslt():
pytest.skip("CUDA graph capture not supported with rocBLAS path")
def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> bool: def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> bool:
"""Check that two lists of tensors match exactly.""" """Check that two lists of tensors match exactly."""
......
...@@ -10,6 +10,7 @@ from typing import Optional ...@@ -10,6 +10,7 @@ from typing import Optional
import pytest import pytest
import torch import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine import transformer_engine
import transformer_engine.common.recipe import transformer_engine.common.recipe
...@@ -27,6 +28,14 @@ from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8 ...@@ -27,6 +28,14 @@ from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8
from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex import transformer_engine_torch as tex
if IS_HIP_EXTENSION:
import os
from functools import cache
@cache
def use_hipblaslt() -> bool:
return (os.getenv("NVTE_USE_HIPBLASLT") is not None
or os.getenv("NVTE_USE_ROCBLAS") is None )
# Check if FP8 is supported # Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
...@@ -770,6 +779,9 @@ class TestBasicOps: ...@@ -770,6 +779,9 @@ class TestBasicOps:
pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs") pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs")
if quantization == "mxfp8" and quantized_grad_input: if quantization == "mxfp8" and quantized_grad_input:
pytest.skip("MXFP8 grad input is not supported with MXFP8 GEMMs") pytest.skip("MXFP8 grad input is not supported with MXFP8 GEMMs")
if (IS_HIP_EXTENSION and not use_hipblaslt() and
accumulate_into_main_grad and dtype != torch.float32 and not quantized_compute):
pytest.skip("Parameter combination is not supported by rocBLAS")
# Random data # Random data
x_ref, x_test = make_reference_and_test_tensors( x_ref, x_test = make_reference_and_test_tensors(
......
# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
# License for AMD contributions = MIT. See LICENSE for more information
import os, sys
import copy
import pytest
import tempfile
import shutil
import subprocess
import csv
import warnings
import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.cpp_extensions import gemm
from transformer_engine.pytorch.module.base import get_workspace
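# True when the hipBLASLt backend is selected: NVTE_USE_HIPBLASLT is set, or NVTE_USE_ROCBLAS is not.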
def use_hipblaslt():
    return (os.getenv("NVTE_USE_HIPBLASLT") is not None
            or os.getenv("NVTE_USE_ROCBLAS") is None)
storage_fname = "te_algo"
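# Print the contents of the cached algo CSV file (debug helper).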
def dump_storage(fname):
    print("========")
    with open(fname, "r") as ifile:
        for row in ifile:
            print(row)
    print("========")
def analyse_storage(fname):
    with open(fname, "r") as ifile:
        reader = csv.DictReader(ifile)
        next(reader)
        head = reader.fieldnames
        assert ("m" in head and "algo_id" in head and "ws_min" in head and "ws_max" in head
                and "aidx" in head), "Invalid CSV format"
    return head
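# Load all cached algo records from the CSV file.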
def read_storage(fname):
    data = []
    with open(fname, "r") as ifile:
        reader = csv.DictReader(ifile)
        for row in reader:
            data.append(row)
    return data
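# Write algo records back to the CSV file with the given header.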
def write_storage(fname, head, data):
    with open(fname, "w") as ofile:
        writer = csv.DictWriter(ofile, fieldnames=head, lineterminator="\n")
        writer.writeheader()
        writer.writerows(data)
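# test_gemm_autotune runs run_gemm() in a subprocess with the hipBLASLt algo cache file
# controlled via TE_HIPBLASLT_ALGO_LOAD/SAVE, then edits the cached CSV to check how
# stale shapes, workspace ranges, and algo indices are handled on subsequent runs.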
@pytest.mark.skipif(not use_hipblaslt(), reason="Autotune requires hipBLASLt")
@pytest.mark.skipif(not IS_HIP_EXTENSION, reason="Autotune requires ROCm TE")
def test_gemm_autotune():
    storage_dir = tempfile.mkdtemp()
    fname = storage_dir + "/" + storage_fname
    script = os.path.abspath(__file__)
    try:
        os.environ["TE_HIPBLASLT_ALGO_LOAD"] = fname
        os.environ["TE_HIPBLASLT_ALGO_SAVE"] = fname
        run_args = ["python", script, "--run"]
        # Initial algo creation
        subprocess.run(run_args)
        head = analyse_storage(fname)
        algos = read_storage(fname)
        assert len(algos) == 1, "Expected 1 cached record"
        algo0 = copy.copy(algos[0])
        ofile = fname + ".1"
        os.environ["TE_HIPBLASLT_ALGO_SAVE"] = ofile
        # Unused cache entries
        algos[0]["m"] = "999" + algos[0]["m"]  # fake record for a different shape
        write_storage(fname, head, algos)
        subprocess.run(run_args)
        algos = read_storage(ofile)
        assert len(algos) == 2, "Expected 2 cached records"
        assert algo0 == algos[1], "Invalid algo"
        # Adjust workspace size
        ws_max = int(algo0["ws_max"])
        if ws_max > 0:
            algos = [copy.copy(algo0)]
            algos[0]["ws_max"] = str(ws_max - 1)  # narrowing the WS range should restore its size
            ws_min = int(algos[0]["ws_min"])
            if ws_max - ws_min > 1:
                ws_min = ws_min + 1
                algos[0]["ws_min"] = str(ws_min)
            write_storage(fname, head, algos)
            subprocess.run(run_args)
            algos = read_storage(ofile)
            assert len(algos) == 1, "Expected 1 cached record"
            assert (str(ws_min), str(ws_max)) == (algos[0]["ws_min"], algos[0]["ws_max"]), "Invalid WS size"
        else:
            warnings.warn("Cached algo workspace size is 0")
        # Modify algo index
        algo_index = int(algo0["aidx"])
        algos = [copy.copy(algo0)]
        algos[0]["aidx"] = str(algo_index + 1)
        write_storage(fname, head, algos)
        subprocess.run(run_args)
        algos = read_storage(ofile)
        assert len(algos) == 1, "Expected 1 cached record"
        assert (algo0["aidx"], algo0["algo_id"]) == (algos[0]["aidx"], algos[0]["algo_id"]), "Invalid algo IDX"
        # Configure the autotune range so the currently cached algo falls outside of it
        # and a new value gets cached
        os.environ["TE_HIPBLASLT_ALGO_LOAD"] = ""
        os.environ["TE_HIPBLASLT_ALGO_SAVE"] = fname
        os.environ["TE_HIPBLASLT_ALGO_SELECTION"] = str(algo_index + 1)
        subprocess.run(run_args)
        algos = read_storage(fname)
        assert len(algos) == 1, "Expected 1 cached record"
        algo1 = copy.copy(algos[0])
        assert algo0["algo_id"] != algo1["algo_id"], "Unexpected algo ID"
        # Restore the autotune range beginning; the newly cached algo should still be used
        os.environ["TE_HIPBLASLT_ALGO_LOAD"] = fname
        del os.environ["TE_HIPBLASLT_ALGO_SELECTION"]
        subprocess.run(run_args)
        algos = read_storage(fname)
        assert len(algos) == 1, "Expected 1 cached record"
        assert algo1 == algos[0], "Invalid algo ID"
    finally:
        shutil.rmtree(storage_dir)
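# Minimal GEMM run by the subprocess invocations above to populate the algo cache.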
def run_gemm():
    N = 32
    datatype = torch.float16
    inp = torch.randn((N, N), device="cuda", dtype=datatype)
    _, _, _ = gemm(A=inp, B=inp, dtype=datatype, workspace=get_workspace())
if __name__ == "__main__":
    if sys.argv[1] == "--run":
        run_gemm()
...@@ -12,6 +12,7 @@ import random ...@@ -12,6 +12,7 @@ import random
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn import Parameter from torch.nn import Parameter
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.fp8 import ( from transformer_engine.pytorch.fp8 import (
FP8GlobalStateManager, FP8GlobalStateManager,
......