Commit c520cba3 authored by yuguo

[DCU] Preliminary adaptation

parent 5b6ef054
......@@ -6,6 +6,7 @@ from typing import Iterable, Optional
import pytest
import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
......@@ -258,7 +259,7 @@ class TestFP8Recipe:
# Compute scale
max_val = {
"forward": 448.0,
"forward": 448.0 if not IS_HIP_EXTENSION else 240.0,
"backward": 57344.0,
}[stage]
ref_scale = (max_val / ref_amax) / (2**margin)
......
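Note on the 448-vs-240 switch above: OCP FP8 E4M3 saturates at 448, while the FNUZ E4M3 variant used on the HIP/DCU stack saturates at 240 (matching E4M3_AMAX in hip_float8.h later in this commit). A minimal Python sketch of the reference-scale computation, with hypothetical names:

def ref_fp8_scale(ref_amax: float, margin: int, is_hip: bool) -> float:
    # Forward-pass amax: FNUZ E4M3 (HIP/DCU) tops out at 240, OCP E4M3 at 448.
    max_val = 240.0 if is_hip else 448.0
    return (max_val / ref_amax) / (2 ** margin)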
......@@ -9,6 +9,7 @@ from contextlib import nullcontext
import torch
import pytest
import os
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.fp8 import (
fp8_autocast,
......@@ -51,6 +52,12 @@ def create_meta(scale_factor: float, size: int = 1):
meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor
return meta
if IS_HIP_EXTENSION:
from functools import cache
@cache
def use_hipblaslt() -> bool:
return (os.getenv("NVTE_USE_HIPBLASLT") is not None
or os.getenv("NVTE_USE_ROCBLAS") is None)
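A usage sketch for the selector above, on a HIP build (the environment-variable names are taken from this diff; anything beyond them is an assumption): hipBLASLt is the default, and setting only NVTE_USE_ROCBLAS opts into the rocBLAS path. Because the result is memoized with functools.cache, the environment must be arranged before the first call:

import os
os.environ["NVTE_USE_ROCBLAS"] = "1"        # opt in to the rocBLAS path
os.environ.pop("NVTE_USE_HIPBLASLT", None)  # and do not force hipBLASLt
assert use_hipblaslt() is False             # first call pins the answer for the process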
def custom_amax_to_scale(
amax: torch.Tensor,
......@@ -982,6 +989,10 @@ def test_sanity_gradient_accumulation_fusion(
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
if IS_HIP_EXTENSION:
if not use_hipblaslt():
pytest.skip("CUDA graph capture not supported with rocBLAS path")
config = model_configs[model]
if fp8_recipe is not None:
......
......@@ -4,44 +4,100 @@
cmake_minimum_required(VERSION 3.21)
# Language options
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120)
else ()
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
option(USE_ROCM "Use ROCm" OFF)
option(USE_HIPBLASLT "Use HIPBLASLT" ON)
# The aotriton/CK backends are temporarily unsupported; rocBLAS is available instead
option(USE_ROCBLAS "Use ROCBLAS" OFF)
if(NOT USE_ROCM)
if(((EXISTS "/opt/dtk/") OR (EXISTS $ENV{ROCM_PATH})) AND NOT (EXISTS "/bin/nvcc"))
message("hcu detected.")
set(USE_ROCM ON)
endif()
endif()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G")
endif()
# Hide non-necessary symbols in shared object.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version")
if (USE_ROCM)
add_compile_definitions(__HIP_CLANG_ONLY__=1)
if (NOT USE_HIPBLASLT AND NOT USE_ROCBLAS)
message(FATAL_ERROR "Need specify at least one GEMM library to use: HIPBLASLT or ROCBLAS")
endif()
unset(USE_CUDA)
else()
set(USE_CUDA TRUE)
endif()
# Transformer Engine library
project(transformer_engine LANGUAGES CUDA CXX)
# CUDA Toolkit
find_package(CUDAToolkit REQUIRED)
if (CUDAToolkit_VERSION VERSION_LESS 12.0)
message(FATAL_ERROR "CUDA 12.0+ is required, but found CUDA ${CUDAToolkit_VERSION}")
# Language options
if(USE_CUDA)
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120)
else ()
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
endif()
endif()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G")
endif()
# Hide non-necessary symbols in shared object.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version")
# Transformer Engine library
project(transformer_engine LANGUAGES CUDA CXX)
# CUDA Toolkit
find_package(CUDAToolkit REQUIRED)
if (CUDAToolkit_VERSION VERSION_LESS 12.0)
message(FATAL_ERROR "CUDA 12.0+ is required, but found CUDA ${CUDAToolkit_VERSION}")
endif()
# cuDNN frontend API
set(CUDNN_FRONTEND_INCLUDE_DIR
"${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include")
if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}")
message(FATAL_ERROR
"Could not find cuDNN frontend API at ${CUDNN_FRONTEND_INCLUDE_DIR}. "
"Try running 'git submodule update --init --recursive' "
"within the Transformer Engine source.")
endif()
include(${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)
else()
set(CMAKE_CXX_STANDARD 17)
project(transformer_engine LANGUAGES HIP CXX)
# Disable asserts in code (asserts cannot be used on the HIP stack.)
add_definitions(-DNDEBUG)
add_definitions(-DUSE_ROCM)
# Change clang++ to hipcc
SET(CMAKE_CXX_COMPILER "${ROCM_PATH}/bin/hipcc")
if(NOT DEFINED ENV{NVTE_ROCM_ARCH})
SET(CMAKE_HIP_ARCHITECTURES gfx906;gfx926;gfx928;gfx936)
else()
SET(CMAKE_HIP_ARCHITECTURES $ENV{NVTE_ROCM_ARCH})
endif()
# -parallel-jobs is disabled: any build error would be duplicated parallel-jobs times
# set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -parallel-jobs=4")
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -g")
endif()
list(APPEND CMAKE_MODULE_PATH "/opt/dtk")
endif()
# cuDNN frontend API
set(CUDNN_FRONTEND_INCLUDE_DIR
"${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include")
if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}")
message(FATAL_ERROR
"Could not find cuDNN frontend API at ${CUDNN_FRONTEND_INCLUDE_DIR}. "
"Try running 'git submodule update --init --recursive' "
"within the Transformer Engine source.")
set(message_line "-------------------------------------------------------------")
message("${message_line}")
message(STATUS "USE_ROCM ${USE_ROCM}")
if(USE_ROCM)
message(STATUS "CMAKE_HIP_ARCHITECTURES: ${CMAKE_HIP_ARCHITECTURES}")
message(STATUS "USE_HIPBLASLT ${USE_HIPBLASLT} USE_ROCBLAS ${USE_ROCBLAS}")
endif()
include(${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)
# Python
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
......@@ -49,60 +105,159 @@ find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
# Configure Transformer Engine library
include_directories(${PROJECT_SOURCE_DIR}/..)
set(transformer_engine_SOURCES)
list(APPEND transformer_engine_SOURCES
cudnn_utils.cpp
transformer_engine.cpp
common.cu
transpose/cast_transpose.cu
transpose/transpose.cu
transpose/cast_transpose_fusion.cu
transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu
activation/gelu.cu
fused_attn/fused_attn_f16_max512_seqlen.cu
fused_attn/fused_attn_f16_arbitrary_seqlen.cu
activation/relu.cu
activation/swiglu.cu
fused_attn/fused_attn_fp8.cu
fused_attn/fused_attn.cpp
fused_attn/utils.cu
gemm/cublaslt_gemm.cu
normalization/common.cpp
normalization/layernorm/ln_api.cpp
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_api.cpp
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
permutation/permutation.cu
util/cast.cu
util/padding.cu
util/cuda_driver.cpp
util/cuda_nvml.cpp
util/cuda_runtime.cpp
util/rtc.cpp
swizzle/swizzle.cu
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
fused_rope/fused_rope.cu
recipe/current_scaling.cu
recipe/delayed_scaling.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu
comm_gemm_overlap/comm_gemm_overlap.cpp)
add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
target_include_directories(transformer_engine PUBLIC
"${CMAKE_CURRENT_SOURCE_DIR}/include")
if(USE_CUDA)
list(APPEND transformer_engine_SOURCES
cudnn_utils.cpp
transformer_engine.cpp
common.cu
transpose/cast_transpose.cu
transpose/transpose.cu
transpose/cast_transpose_fusion.cu
transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu
activation/gelu.cu
fused_attn/fused_attn_f16_max512_seqlen.cu
fused_attn/fused_attn_f16_arbitrary_seqlen.cu
activation/relu.cu
activation/swiglu.cu
fused_attn/fused_attn_fp8.cu
fused_attn/fused_attn.cpp
fused_attn/utils.cu
gemm/cublaslt_gemm.cu
normalization/common.cpp
normalization/layernorm/ln_api.cpp
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_api.cpp
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
permutation/permutation.cu
util/cast.cu
util/padding.cu
util/cuda_driver.cpp
util/cuda_nvml.cpp
util/cuda_runtime.cpp
util/rtc.cpp
swizzle/swizzle.cu
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
fused_rope/fused_rope.cu
recipe/current_scaling.cu
recipe/delayed_scaling.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu
comm_gemm_overlap/comm_gemm_overlap.cpp)
add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
else()
list(APPEND transformer_engine_SOURCES
cudnn_utils.cpp
transformer_engine.cpp
common.cu
transpose/cast_transpose.cu
transpose/transpose.cu
transpose/cast_transpose_fusion.cu
transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu
activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
gemm/cublaslt_gemm.cu
normalization/common.cpp
normalization/layernorm/ln_api.cpp
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_api.cpp
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
permutation/permutation.cu
util/cast.cu
util/padding.cu
util/cuda_driver.cpp
util/cuda_nvml.cpp
util/cuda_runtime.cpp
util/rtc.cpp
swizzle/swizzle.cu
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
fused_rope/fused_rope.cu
recipe/current_scaling.cu
recipe/delayed_scaling.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu
comm_gemm_overlap/comm_gemm_overlap.cpp)
# Hipify the CUDA source files for the ROCm build
message("${message_line}")
message(STATUS "CMAKE_CURRENT_SOURCE_DIR: ${CMAKE_CURRENT_SOURCE_DIR}")
message(STATUS "PROJECT_SOURCE_DIR: ${PROJECT_SOURCE_DIR}")
set(TE ${CMAKE_CURRENT_SOURCE_DIR}/../..)
set(THIRDPARTY ${TE}/3rdparty)
list(APPEND CMAKE_MODULE_PATH "${THIRDPARTY}/hipify_torch/cmake")
include(Hipify)
message(STATUS "CMAKE_MODULE_PATH: ${CMAKE_MODULE_PATH}")
set(header_include_dir
${CMAKE_CURRENT_SOURCE_DIR}/comm_gemm_overlap/userbuffers
${CMAKE_CURRENT_SOURCE_DIR}/activation
${CMAKE_CURRENT_SOURCE_DIR}/include
${CMAKE_CURRENT_SOURCE_DIR}/transpose
${CMAKE_CURRENT_SOURCE_DIR}/util
${CMAKE_CURRENT_SOURCE_DIR}/normalization
${CMAKE_CURRENT_SOURCE_DIR}/normalization/rmsnorm
${CMAKE_CURRENT_SOURCE_DIR}/normalization/layernorm
${CMAKE_CURRENT_SOURCE_DIR})
message(STATUS "HIPIFY CUDA_SOURCE_DIR: ${CMAKE_CURRENT_SOURCE_DIR}")
message(STATUS "HIPIFY HEADER_INCLUDE_DIR: ${header_include_dir}")
hipify(CUDA_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}
HEADER_INCLUDE_DIR ${header_include_dir}
IGNORES "*/amd_detail/*"
IGNORES "*/fused_attn/*"
CUSTOM_MAP_FILE "${TE}/hipify_custom_map.json"
)
get_hipified_list("${transformer_engine_SOURCES}" te_hip_sources)
message("${message_line}")
message(STATUS "nvte hipified sources: ${te_hip_sources}")
add_library(transformer_engine SHARED ${te_hip_sources})
endif()
# Configure dependencies
target_link_libraries(transformer_engine PUBLIC
CUDA::cublas
CUDA::cudart)
target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
if (USE_CUDA)
target_include_directories(transformer_engine PUBLIC
"${CMAKE_CURRENT_SOURCE_DIR}/include")
# Configure dependencies
target_link_libraries(transformer_engine PUBLIC
CUDA::cublas
CUDA::cudart)
target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
else()
# Aotriton is currently unsupported
set(AotritonAndCk_fused_attn "unsupported")
find_package(hip)
list(APPEND transformer_engine_LINKER_LIBS hip::host hip::device roctx64)
if(USE_HIPBLASLT)
find_package(hipblaslt)
find_package(hipblas REQUIRED PATHS ${ROCM_PATH})
target_compile_definitions(transformer_engine PUBLIC USE_HIPBLASLT)
list(APPEND transformer_engine_LINKER_LIBS roc::hipblaslt hipblas)
endif()
if(USE_ROCBLAS)
find_package(rocblas)
target_compile_definitions(transformer_engine PUBLIC USE_ROCBLAS)
list(APPEND transformer_engine_LINKER_LIBS roc::rocblas)
endif()
target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS})
endif()
# Compiling Userbuffers with native MPI bootstrapping requires linking against MPI
option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF)
......@@ -113,8 +268,10 @@ if (NVTE_UB_WITH_MPI)
target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI)
endif()
# Hack to enable dynamic loading in cuDNN frontend
target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING)
if (USE_CUDA)
# Hack to enable dynamic loading in cuDNN frontend
target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING)
endif()
# Helper functions to make header files with C++ strings
function(make_string_header STRING STRING_NAME)
......@@ -130,17 +287,34 @@ function(make_string_header_from_file file_ STRING_NAME)
endfunction()
# Header files with C++ strings
list(GET CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES 0 cuda_include_path)
make_string_header("${cuda_include_path}"
string_path_cuda_include)
make_string_header_from_file(transpose/rtc/cast_transpose_fusion.cu
string_code_transpose_rtc_cast_transpose_fusion_cu)
make_string_header_from_file(transpose/rtc/cast_transpose.cu
string_code_transpose_rtc_cast_transpose_cu)
make_string_header_from_file(transpose/rtc/transpose.cu
string_code_transpose_rtc_transpose_cu)
make_string_header_from_file(utils.cuh
string_code_utils_cuh)
if(USE_CUDA)
list(GET CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES 0 cuda_include_path)
make_string_header("${cuda_include_path}"
string_path_cuda_include)
make_string_header_from_file(transpose/rtc/cast_transpose_fusion.cu
string_code_transpose_rtc_cast_transpose_fusion_cu)
make_string_header_from_file(transpose/rtc/cast_transpose.cu
string_code_transpose_rtc_cast_transpose_cu)
make_string_header_from_file(transpose/rtc/transpose.cu
string_code_transpose_rtc_transpose_cu)
make_string_header_from_file(utils.cuh
string_code_utils_cuh)
else()
make_string_header_from_file(utils_hip.cuh
string_code_utils_cuh)
make_string_header_from_file(transpose/rtc/cast_transpose_fusion.hip
string_code_transpose_rtc_cast_transpose_fusion_cu)
make_string_header_from_file(transpose/rtc/cast_transpose.hip
string_code_transpose_rtc_cast_transpose_cu)
make_string_header_from_file(transpose/rtc/transpose.hip
string_code_transpose_rtc_transpose_cu)
make_string_header_from_file(amd_detail/hip_float8.h
string_code_amd_detail_hip_float8_h)
make_string_header_from_file(amd_detail/hip_f8_impl.h
string_code_amd_detail_hip_f8_impl_h)
endif()
make_string_header_from_file(util/math.h
string_code_util_math_h)
target_include_directories(transformer_engine PRIVATE
......@@ -160,8 +334,22 @@ if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
PROPERTIES
COMPILE_OPTIONS "--use_fast_math")
endif()
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
if(USE_CUDA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
else()
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -O3")
set(HIP_HCC_FLAGS "${CMAKE_HIP_FLAGS} -mavx2 -mf16c -mfma -std=c++17")
# Ask hcc to generate device code during compilation so we can use
# host linker to link.
set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -fno-gpu-rdc -Wno-defaulted-function-deleted")
foreach(rocm_arch ${CMAKE_HIP_ARCHITECTURES})
# If CMAKE_CXX_FLAGS already sets --offload-arch, it should be removed first
set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} --offload-arch=${rocm_arch}")
endforeach()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${HIP_HCC_FLAGS}")
endif()
# Number of parallel build jobs
if(DEFINED ENV{MAX_JOBS})
......
......@@ -124,6 +124,9 @@ def _load_nvrtc():
if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
_CUDNN_LIB_CTYPES = _load_cudnn()
_NVRTC_LIB_CTYPES = _load_nvrtc()
try:
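# cuDNN and NVRTC may be unavailable on a ROCm/DCU stack; tolerate load failure here.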
_CUDNN_LIB_CTYPES = _load_cudnn()
_NVRTC_LIB_CTYPES = _load_nvrtc()
except OSError:
pass
_TE_LIB_CTYPES = _load_library()
/*************************************************************************
* Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
*
* License for AMD contributions = MIT. See LICENSE for more information
************************************************************************/
#include <hip/hip_runtime.h>
namespace hip_f8_impl {
HIP_HOST_DEVICE inline int clz(uint32_t x) {
#ifdef __HIP_DEVICE_COMPILE__
return __clz(x);
#else
return __builtin_clz(x);
#endif
}
template <int wm, int we, typename T, bool negative_zero_nan, bool clip>
HIP_HOST_DEVICE
uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) {
constexpr bool is_half = std::is_same<T,__half>::value;
constexpr bool is_float = std::is_same<T,float>::value;
static_assert(wm+we==7, "wm+we==7");
static_assert(is_half || is_float, "Only half and float can be cast to f8");
const int mfmt = (sizeof(T)==4) ? 23 : 10;
uint32_t x;
if(sizeof(T)==4)
x = reinterpret_cast<uint32_t&>(_x);
else
x = reinterpret_cast<uint16_t&>(_x);
uint32_t y, head, mantissa;
int exponent, bias;
uint32_t sign;
if(sizeof(T)==4) {
head = x & 0xFF800000;
mantissa = x & 0x7FFFFF;
exponent = (head>>23) & 0xFF;
sign = head >> 31;
bias = 127;
} else {
head = x & 0xFC00;
mantissa = x & 0x3FF;
exponent = (head>>10) & 0x1F;
sign = head >> 15;
bias = 15;
}
uint32_t signed_inf = (sign<<7) + (((1<<we)-1)<<wm);
// Deal with inf and NaNs
if(negative_zero_nan) {
if(sizeof(T)==4) {
if((x & 0x7F800000) == 0x7F800000)
return 0x80;
} else {
//if(__hisinf(x) || __hisnan(x))
if((x & 0x7C00)==0x7C00)
return 0x80;
}
}
else {
if(sizeof(T)==4) {
if((x & 0x7F800000) == 0x7F800000)
return signed_inf + (mantissa!=0 ? 1 : 0);
} else {
if((x & 0x7C00)==0x7C00)
return signed_inf + (mantissa!=0 ? 1 : 0);
}
}
if(x==0)
return 0;
// First we need to check whether the value is normal or denormal, as there is a difference of the
// implicit 1. Then we need to adjust the exponent to align with the F8 exponent and, meanwhile, shift
// the mantissa. For stochastic rounding, add rng to the mantissa and truncate; for
// RNE there is no rng to add. Finally we may need to check for a carry and adjust the
// exponent and mantissa again.
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits
const int f8_bias = ( 1<<(we-1) ) - 1 + ( negative_zero_nan ? 1 : 0 );
const int f8_denormal_act_exponent = 1 - f8_bias; //actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// f8_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
// the difference needs to be adjusted and mantissa shifted
int act_exponent, f8_exponent, exponent_diff;
if (exponent == 0) { // fp32/fp16 is in denormal.
/* An fp32 denormal is below 2^-127, so it is usually not a concern here; our main concern is fp16.
In this case, f8 is usually denormal, but there can be exceptions.
fp16 denormals have exponent bias 15, while bf8 with NANOO has exponent bias 16.
This means some fp16 denormals are actually bf8 (NANOO) normals - the smallest bf8 (NANOO) normal is 2^-15.
fp16 numbers with exponent==0 (actual exponent -14) whose highest mantissa bit is 1 are bf8 (NANOO) normals.
In this case, the fp16 mantissa should be shifted left by 1 */
act_exponent = exponent - bias + 1;
exponent_diff = f8_denormal_act_exponent - act_exponent; // actual exponent is exponent-bias+1 as it is denormal
}
else { // fp32/fp16 is normal with implicit 1
act_exponent = exponent - bias;
if (act_exponent <= f8_denormal_act_exponent) {
/* This is the case where fp32/fp16 is normal but falls in the f8 denormal range.
For example, in fp8 NANOO mode the denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, the value is actually larger due to the implicit 1;
therefore it needs to be adjusted to -6 and the mantissa shifted right by 1.
So for fp32/fp16, exponent -8 is the cut-off for converting to fp8 NANOO */
exponent_diff = f8_denormal_act_exponent - act_exponent;
}
else { //both fp32/fp16 and f8 are in normal range
exponent_diff = 0; // exponent_diff==0 does not mean the exponents are equal for this case;
// act_exponent could be larger. It just means no mantissa shift is needed.
}
mantissa += (1 << mfmt); //Add the implicit 1 into mantissa
}
bool midpoint;
if (exponent_diff<=wm+1)
// The determination of the midpoint only makes sense when wm+1 can compensate for the difference in exponents.
// Why wm+1 instead of wm? Because in addition to the wm bits kept as the f8 mantissa, there is
// also the implicit 1 (there is not always an implicit 1, but that does not matter here).
midpoint = (mantissa & ( (1 << (mfmt-wm+exponent_diff)) - 1 )) == ( 1 << (mfmt-wm+exponent_diff-1) );
else
midpoint = false;
/* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we shift right
as shift right could rip off some residual part and make something not midpoint look like midpoint.
For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than midpoint,
but after shift right by 4 bits, it would look like midpoint.
*/
if (exponent_diff>0)
// Clip exponent_diff, since a right shift by more than 31 bits is undefined behavior
mantissa >>= (exponent_diff > 31 ? 31 : exponent_diff);
else if (exponent_diff == -1)
mantissa <<= -exponent_diff;
bool implicit_one = mantissa & (1 << mfmt);
//if there is no implicit 1, the f8 value is denormal and needs the denormal exponent adjustment
f8_exponent = (act_exponent+exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one?0:1);
//Now we have the exponent and mantissa adjusted
uint32_t drop_mask = (1 << (mfmt-wm)) - 1;
//bool midpoint = (mantissa & drop_mask) == ( 1 << (mfmt-wm-1) );
bool odd = mantissa & (1<< (mfmt-wm)); // if the least significant bit that is not truncated is 1
mantissa += (stoch ? rng : (midpoint?(odd?mantissa:mantissa-1 ) :mantissa) ) & drop_mask;
//Now we deal with overflow
if (f8_exponent == 0) {
if ((1 << mfmt) & mantissa) {
f8_exponent = 1; //denormal overflow to become normal, promote exponent
//mantissa &= (1<<mfmt) -1 ; //No need to make 1 implicit now as it will be addressed later
}
}
else {
if ((1 << (mfmt+1)) & mantissa) {
mantissa >>= 1;
f8_exponent++;
//mantissa &= (1<<mfmt) -1 ; // No need to make 1 implicit now as it will be addressed later
}
}
mantissa >>= (mfmt-wm);
// above range: quantize to maximum possible float of the same sign
const int max_exp = (1<<we)-(negative_zero_nan ? 1 : 2);
if(f8_exponent > max_exp) {
if(clip) {
mantissa = (1<<wm)-1;
f8_exponent = max_exp;
} else {
return signed_inf;
}
}
if(f8_exponent == 0 && mantissa == 0)
return negative_zero_nan? 0 : (sign<<7);
mantissa &= (1<<wm)-1;
return (sign << 7) | (f8_exponent << wm) | mantissa;
}
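A hedged Python paraphrase of the drop-bit rounding implemented above (round to nearest, ties to even), under the assumption that the dropped bits sit in the low `shift` bits of `mantissa`:

def round_rne(mantissa: int, shift: int) -> int:
    drop_mask = (1 << shift) - 1
    midpoint = (mantissa & drop_mask) == (1 << (shift - 1))
    odd = bool(mantissa & (1 << shift))            # lowest bit that is kept
    if midpoint:
        mantissa += (1 << shift) if odd else 0     # ties go to the even neighbor
    else:
        mantissa += 1 << (shift - 1)               # ordinary round-to-nearest
    return mantissa >> shift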
/* RTC does not have std::conditional so implement it here*/
template<bool B, class T, class F>
struct conditional { typedef T type; };
template<class T, class F>
struct conditional<false, T, F> { typedef F type; };
template <int wm, int we, typename T, bool negative_zero_nan>
HIP_HOST_DEVICE
T cast_from_f8(uint8_t x) {
constexpr bool is_half = std::is_same<T,__half>::value;
constexpr bool is_float = std::is_same<T,float>::value;
constexpr bool is_bf16 = std::is_same<T,hip_bfloat16>::value;
static_assert(is_half || is_float, "only half and float are supported");
constexpr int weo = is_half ? 5 : 8;
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
T fInf, fNegInf, fNaN, fNeg0;
if(is_half) {
const uint16_t ihInf = 0x7C00;
const uint16_t ihNegInf = 0xFC00;
const uint16_t ihNaN = 0x7C01;
const uint16_t ihNeg0 = 0x8000;
fInf = reinterpret_cast<const __half&>(ihInf);
fNegInf = reinterpret_cast<const __half&>(ihNegInf);
fNaN = reinterpret_cast<const __half&>(ihNaN);
fNeg0 = reinterpret_cast<const __half&>(ihNeg0);
} else if(is_float) {
const uint32_t ifInf = 0x7F800000;
const uint32_t ifNegInf = 0xFF800000;
const uint32_t ifNaN = 0x7F800001;
const uint32_t ifNeg0 = 0x80000000;
fInf = reinterpret_cast<const float&>(ifInf);
fNegInf = reinterpret_cast<const float&>(ifNegInf);
fNaN = reinterpret_cast<const float&>(ifNaN);
fNeg0 = reinterpret_cast<const float&>(ifNeg0);
}
if(x==0)
return 0;
uint32_t sign = x>>7;
uint32_t mantissa = x & ((1<<wm)-1);
int exponent = (x & 0x7F) >> wm;
if(negative_zero_nan) {
if(x==0x80)
return fNaN;
} else {
if(x==0x80)
return fNeg0;
if(exponent == ((1<<we)-1))
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
}
typename conditional<sizeof(T)==2, uint16_t, uint32_t>::type retval;
if(we==5 && is_half && !negative_zero_nan) {
retval = x<<8;
return reinterpret_cast<const T&>(retval);
}
const int exp_low_cutoff = (1<<(weo-1)) - (1<<(we-1)) + 1 - (negative_zero_nan ? 1 : 0);
//subnormal input
if(exponent == 0) {
//guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int sh = 1 + clz(mantissa) - (32-wm);
mantissa <<= sh;
exponent += 1-sh;
/*
exponent++;
while(mantissa<(1<<wm)) {
mantissa <<= 1;
exponent--;
}
*/
mantissa &= ((1<<wm)-1);
}
exponent += exp_low_cutoff-1;
mantissa <<= wmo - wm;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if(exponent<=0) {
mantissa |= 1<<wmo;
mantissa >>= 1-exponent;
exponent = 0;
}
if(sizeof(T)==2)
retval = (sign<<15) | (exponent<<10) | mantissa;
else
retval = (sign<<31) | (exponent<<23) | mantissa;
return reinterpret_cast<const T&>(retval);
}
} // namespace hip_f8_impl
\ No newline at end of file
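To make the FNUZ ("NANOO") conventions concrete, here is a hedged Python reference for cast_from_f8<3, 4, float, true> above (E4M3 with negative_zero_nan): bias 8, 0x80 encodes NaN, and the largest finite value is 240:

def decode_e4m3_fnuz(b: int) -> float:
    if b == 0:
        return 0.0
    if b == 0x80:                       # negative zero encodes NaN in FNUZ mode
        return float("nan")
    sign = -1.0 if b & 0x80 else 1.0
    exp = (b >> 3) & 0xF
    man = b & 0x7
    if exp == 0:                        # subnormal: no implicit 1, exponent 1 - bias
        return sign * (man / 8.0) * 2.0 ** (1 - 8)
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 8)

assert decode_e4m3_fnuz(0x7F) == 240.0  # matches E4M3_AMAX in hip_float8.h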
/*************************************************************************
* Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
*
* License for AMD contributions = MIT. See LICENSE for more information
************************************************************************/
#pragma once
// FP8 header version 0.3, 2021/05/11
#include <hip/hip_runtime.h>
#define HIP_HOST_DEVICE __host__ __device__
#define HIP_DEVICE __device__
#define HIP_HOST __host__
#define E5M2_AMAX 57344.0
#define E4M3_AMAX 240.0
namespace hip_f8_impl {
template <int wm, int we, typename T, bool negative_zero_nan, bool clip>
HIP_HOST_DEVICE
uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0);
template <int wm, int we, typename T, bool negative_zero_nan>
HIP_HOST_DEVICE
T cast_from_f8(uint8_t x);
} // namespace hip_f8_impl
#include "hip_f8_impl.h"
enum class hip_f8_type {
bf8 = 0, // 1:5:2
fp8 = 1 // 1:4:3
};
enum class hip_f8_rounding_mode {
standard,
stochastic
};
// bias mode bit implementation
//
// For MI100 simulation purpose, we keep a copy of it on the host and device
// (MI300 HW implementation will be different)
//
// The bias mode should only be accessed via its get/set routines.
// The set routine sets both copies to the same value, keeping them in sync
// The get routine will return the device copy for device functions and
// the host copy for host functions
//
// "bias mode optimial"
// => "bias mode bit" = 1
// => bias = 16 for 152, 8 for 143
// => NAN/INF are represented as negative_zero
//
// "bias mode ieee"
// => "bias mode bit" = 0
// => bias = 15 for 152, 7 for 143
// => NAN/INF are represented as per IEEE conventions
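A hedged Python paraphrase of the NaN convention this bit controls, matching is_nan() later in this header (fp8 = 1:4:3):

def is_nan_e4m3(byte: int, optimal_bias_mode: bool) -> bool:
    if optimal_bias_mode:               # NANOO: only negative zero is NaN
        return byte == 0x80
    # IEEE conventions: all-ones exponent with a nonzero mantissa
    return (byte & 0x7F) > 0x78         # 0x79..0x7f and 0xf9..0xff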
#ifndef __HIPCC_RTC__
static bool hip_f8_bias_mode_bit_host = true;
static inline __host__ bool get_hip_f8_bias_mode() {
return hip_f8_bias_mode_bit_host;
}
#endif // __HIPCC_RTC__
#ifdef __HIPCC__
static __device__ bool hip_f8_bias_mode_bit_device = true;
static inline __device__ bool get_hip_f8_bias_mode() {
return hip_f8_bias_mode_bit_device;
}
#ifndef __HIPCC_RTC__
static __global__ void set_hip_f8_bias_mode_bit(bool v) {
hip_f8_bias_mode_bit_device = v;
}
static void set_hip_f8_bias_mode_ieee() {
hipLaunchKernelGGL(set_hip_f8_bias_mode_bit, dim3(1), dim3(1), 0, 0, false);
hip_f8_bias_mode_bit_host = false;
}
static void set_hip_f8_bias_mode_optimal() {
hipLaunchKernelGGL(set_hip_f8_bias_mode_bit, dim3(1), dim3(1), 0, 0, true);
hip_f8_bias_mode_bit_host = true;
}
#endif // __HIPCC_RTC__
#endif // __HIPCC__
template<hip_f8_type T>
struct hip_f8 {
uint8_t data;
// default constructor
HIP_HOST_DEVICE hip_f8() = default;
// constructor from bits
explicit HIP_HOST_DEVICE hip_f8(uint8_t v) {
data = v;
}
// constructor from float
#ifdef __gfx942__
explicit HIP_DEVICE hip_f8(float v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0) {
union {
float fval;
uint32_t i32val;
uint8_t i8val[4];
} val;
uint32_t ival = 0;
val.fval = v;
if (T == hip_f8_type::bf8) { // bf8
if ((val.i32val & 0x7F800000) != 0x7F800000) // propagate NAN/INF, no clipping
val.fval = __builtin_amdgcn_fmed3f(val.fval, E5M2_AMAX, -E5M2_AMAX);
if (rm == hip_f8_rounding_mode::standard) { // RNE rounding
ival = __builtin_amdgcn_cvt_pk_bf8_f32(
val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
data = val.i8val[0];
}
else { //stochastic rounding
ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
data = val.i8val[0]; // little endian
}
}
else { // fp8
if ((val.i32val & 0x7F800000) != 0x7F800000) // propagate NAN/INF, no clipping
val.fval = __builtin_amdgcn_fmed3f(val.fval, E4M3_AMAX, -E4M3_AMAX);
if (rm == hip_f8_rounding_mode::standard) { // RNE rounding
ival = __builtin_amdgcn_cvt_pk_fp8_f32(
val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
data = val.i8val[0];
}
else { //stochastic rounding
ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
data = val.i8val[0]; // little endian
}
}
}
#ifndef __HIPCC_RTC__
explicit HIP_HOST // Host code still uses SW-simulated conversion on gfx942
hip_f8(float v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0) {
if (T == hip_f8_type::bf8) {
if (get_hip_f8_bias_mode()) {
data = hip_f8_impl::cast_to_f8<2, 5, float, true/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
} else {
data = hip_f8_impl::cast_to_f8<2, 5, float, false/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
}
} else /* fp8*/ {
if (get_hip_f8_bias_mode()) {
data = hip_f8_impl::cast_to_f8<3, 4, float, true/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
} else {
data = hip_f8_impl::cast_to_f8<3, 4, float, false/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
}
}
}
#endif
#else // #ifndef __gfx942__
explicit HIP_HOST_DEVICE // On architectures other than gfx942, both host and device still use SW simulated conversion
hip_f8(float v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0) {
if (T == hip_f8_type::bf8) {
if (get_hip_f8_bias_mode()) {
data = hip_f8_impl::cast_to_f8<2, 5, float, true/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
} else {
data = hip_f8_impl::cast_to_f8<2, 5, float, false/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
}
} else /* fp8*/ {
if (get_hip_f8_bias_mode()) {
data = hip_f8_impl::cast_to_f8<3, 4, float, true/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
} else {
data = hip_f8_impl::cast_to_f8<3, 4, float, false/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
}
}
}
#endif // #ifdef __gfx942__
// constructor from half
#ifdef __gfx942__
explicit HIP_DEVICE hip_f8(half v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0)
: hip_f8((float)v, rm, rng)
{
}
#ifndef __HIPCC_RTC__
explicit HIP_HOST // Host code still uses SW-simulated conversion on gfx942
hip_f8(half v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0) {
if (T == hip_f8_type::bf8) {
if (get_hip_f8_bias_mode()) {
data = hip_f8_impl::cast_to_f8<2, 5, half, true/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
} else {
data = hip_f8_impl::cast_to_f8<2, 5, half, false/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
}
} else /* fp8*/ {
if (get_hip_f8_bias_mode()) {
data = hip_f8_impl::cast_to_f8<3, 4, half, true/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
} else {
data = hip_f8_impl::cast_to_f8<3, 4, half, false/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
}
}
}
#endif
#else // #ifndef __gfx942__
explicit HIP_HOST_DEVICE // On architectures other than gfx942, both host and device still use SW simulated conversion
hip_f8(half v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0) {
if (T == hip_f8_type::bf8) {
if (get_hip_f8_bias_mode()) {
data = hip_f8_impl::cast_to_f8<2, 5, half, true/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
} else {
data = hip_f8_impl::cast_to_f8<2, 5, half, false/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
}
} else /* fp8*/ {
if (get_hip_f8_bias_mode()) {
data = hip_f8_impl::cast_to_f8<3, 4, half, true/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
} else {
data = hip_f8_impl::cast_to_f8<3, 4, half, false/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
}
}
}
#endif // #ifdef __gfx942__
// constructor from hip_bfloat16
explicit HIP_HOST_DEVICE hip_f8(hip_bfloat16 v, hip_f8_rounding_mode r=hip_f8_rounding_mode::standard, uint32_t rng=0);
// convert to float
#ifdef __gfx942__
HIP_DEVICE operator float() const {
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // endianness-dependent
} val;
// place the 8-bit data in byte 3 (bits [31:24]), matching the byte index passed to the cvt builtins
val.i32val = 0;
val.i8val[3] = data; // little endian
// upcast
if(T == hip_f8_type::bf8)
val.fval = __builtin_amdgcn_cvt_f32_bf8(val.i32val, 3); // byte 3
else // fp8
val.fval = __builtin_amdgcn_cvt_f32_fp8(val.i32val, 3); // byte 3
return val.fval;
}
#ifndef __HIPCC_RTC__
explicit inline HIP_HOST // Host code still uses SW-simulated conversion on gfx942
operator float() const {
if (T == hip_f8_type::bf8) {
if (get_hip_f8_bias_mode()) {
return hip_f8_impl::cast_from_f8<2, 5, float, true/*negative_zero_nan*/>(data);
} else {
return hip_f8_impl::cast_from_f8<2, 5, float, false/*negative_zero_nan*/>(data);
}
} else /* fp8*/ {
if (get_hip_f8_bias_mode()) {
return hip_f8_impl::cast_from_f8<3, 4, float, true/*negative_zero_nan*/>(data);
} else {
return hip_f8_impl::cast_from_f8<3, 4, float, false/*negative_zero_nan*/>(data);
}
}
}
#endif
#else // #ifdef __gfx942__
explicit inline HIP_HOST_DEVICE // On architectures other than gfx942, both host and device still use SW simulated conversion
operator float() const {
if (T == hip_f8_type::bf8) {
if (get_hip_f8_bias_mode()) {
return hip_f8_impl::cast_from_f8<2, 5, float, true/*negative_zero_nan*/>(data);
} else {
return hip_f8_impl::cast_from_f8<2, 5, float, false/*negative_zero_nan*/>(data);
}
} else /* fp8*/ {
if (get_hip_f8_bias_mode()) {
return hip_f8_impl::cast_from_f8<3, 4, float, true/*negative_zero_nan*/>(data);
} else {
return hip_f8_impl::cast_from_f8<3, 4, float, false/*negative_zero_nan*/>(data);
}
}
}
#endif // #ifdef __gfx942__
// convert to half
#ifdef __gfx942__
explicit HIP_DEVICE inline operator half() const {
return __half(float(*this));
}
#ifndef __HIPCC_RTC__
explicit inline HIP_HOST // Host code still uses SW-simulated conversion on gfx942
operator half() const {
if (T == hip_f8_type::bf8) {
if (get_hip_f8_bias_mode()) {
return hip_f8_impl::cast_from_f8<2, 5, half, true/*negative_zero_nan*/>(data);
} else {
return hip_f8_impl::cast_from_f8<2, 5, half, false/*negative_zero_nan*/>(data);
}
} else /* fp8*/ {
if (get_hip_f8_bias_mode()) {
return hip_f8_impl::cast_from_f8<3, 4, half, true/*negative_zero_nan*/>(data);
} else {
return hip_f8_impl::cast_from_f8<3, 4, half, false/*negative_zero_nan*/>(data);
}
}
}
#endif
#else // #ifndef __gfx942__
explicit inline HIP_HOST_DEVICE // On architectures other than gfx942, both host and device still use SW simulated conversion
operator half() const {
if (T == hip_f8_type::bf8) {
if (get_hip_f8_bias_mode()) {
return hip_f8_impl::cast_from_f8<2, 5, half, true/*negative_zero_nan*/>(data);
} else {
return hip_f8_impl::cast_from_f8<2, 5, half, false/*negative_zero_nan*/>(data);
}
} else /* fp8*/ {
if (get_hip_f8_bias_mode()) {
return hip_f8_impl::cast_from_f8<3, 4, half, true/*negative_zero_nan*/>(data);
} else {
return hip_f8_impl::cast_from_f8<3, 4, half, false/*negative_zero_nan*/>(data);
}
}
}
#endif // #ifdef __gfx942__
// convert to hip_bfloat16
explicit inline HIP_HOST_DEVICE operator hip_bfloat16() const;
// check for zero
inline HIP_HOST_DEVICE bool is_zero() const {
if (get_hip_f8_bias_mode()) {
return data == 0x00;
} else {
return (data == 0x00) || (data == 0x80);
}
}
// check for nan
inline HIP_HOST_DEVICE bool is_nan() const {
if (get_hip_f8_bias_mode()) {
return data == 0x80;
} else {
if (T == hip_f8_type::bf8) {
return
(data == 0x7d) || (data == 0x7e) || (data == 0x7f) ||
(data == 0xfd) || (data == 0xfe) || (data == 0xff);
} else {
return
(data == 0x79) || (data == 0x7a) || (data == 0x7b) || (data == 0x7c) || (data == 0x7d) || (data == 0x7e) || (data == 0x7f) ||
(data == 0xf9) || (data == 0xfa) || (data == 0xfb) || (data == 0xfc) || (data == 0xfd) || (data == 0xfe) || (data == 0xff);
}
}
}
// check for inf
inline HIP_HOST_DEVICE bool is_inf() const {
if (get_hip_f8_bias_mode()) {
return data == 0x80;
} else {
if (T == hip_f8_type::bf8) {
return (data == 0x7c) || (data == 0xfc);
} else {
return (data == 0x78) || (data == 0xf8);
}
}
}
};
#ifdef __HIPCC__
template<hip_f8_type T>
struct hip_f8x4 {
// define some convenience types
typedef float float32x2 __attribute__((ext_vector_type(2)));
typedef float float32x4 __attribute__((ext_vector_type(4)));
typedef _Float16 halfx2 __attribute__((ext_vector_type(2)));
typedef _Float16 halfx4 __attribute__((ext_vector_type(4)));
typedef uint16_t hip_bfloat16x2 __attribute__((ext_vector_type(2)));
typedef uint16_t hip_bfloat16x4 __attribute__((ext_vector_type(4)));
uint32_t data;
// default constructor
HIP_HOST_DEVICE hip_f8x4() = default;
// constructor from bits
HIP_HOST_DEVICE hip_f8x4(uint32_t v);
// constructor from float
HIP_HOST_DEVICE hip_f8x4(float v0, float v1=0, float v2=0, float v3=0, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0);
HIP_HOST_DEVICE hip_f8x4(float32x2 v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0);
HIP_HOST_DEVICE hip_f8x4(float32x4 v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0);
// constructor from half
HIP_HOST_DEVICE hip_f8x4(half v0, half v1=0, half v2=0, half v3=0, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0);
HIP_HOST_DEVICE hip_f8x4(halfx2 v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0);
HIP_HOST_DEVICE hip_f8x4(halfx4 v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0);
// constructor from hip_bfloat16
HIP_HOST_DEVICE hip_f8x4(hip_bfloat16 v0, hip_bfloat16 v1=hip_bfloat16(0.0f), hip_bfloat16 v2=hip_bfloat16(0.0f), hip_bfloat16 v3=hip_bfloat16(0.0f), hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0);
HIP_HOST_DEVICE hip_f8x4(hip_bfloat16x2 v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0);
HIP_HOST_DEVICE hip_f8x4(hip_bfloat16x4 v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0);
// convert to float32x4
inline HIP_HOST_DEVICE operator float32x4() const;
// convert to halfx4
inline HIP_HOST_DEVICE operator halfx4() const;
// convert to hip_bfloat16x4
inline HIP_HOST_DEVICE operator hip_bfloat16x4() const;
};
template<hip_f8_type T>
struct hip_f8x8 {
// define some convenience types
typedef hip_f8x4<T> f8x8 __attribute__((ext_vector_type(2)));
f8x8 data;
// default constructor
HIP_HOST_DEVICE hip_f8x8() = default;
// do we need to define other constructors or any conversion routines here?
};
// If we do not end up needing either any constructors or conversion routines for the above type, then
// we can simplify the above type to the following
#if USE_SIMPLER_HIP_F8x8
template <hip_f8_type T>
using hip_f8x8 = hip_f8x4<T> __attribute__((ext_vector_type(2)));
#endif
typedef float hip_float32x4 __attribute__((ext_vector_type(4)));
typedef float hip_float32x16 __attribute__((ext_vector_type(16)));
// these are device-specific and we don't expect them to exist unless we're compiling with hip-clang for gfx942.
template<hip_f8_type T_A, hip_f8_type T_B>
__device__ hip_float32x4 mfma_f32_16x16x32(hip_f8x8<T_A> a, hip_f8x8<T_B> b, hip_float32x4 c);
template<hip_f8_type T_A, hip_f8_type T_B>
__device__ hip_float32x16 mfma_f32_32x32x16(hip_f8x8<T_A> a, hip_f8x8<T_B> b, hip_float32x16 c);
#endif //__HIPCC__
\ No newline at end of file
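One detail of the gfx942 constructors above worth spelling out: __builtin_amdgcn_fmed3f(v, AMAX, -AMAX) returns the median of its three arguments, which clamps a finite v into [-AMAX, AMAX]; NaN/Inf bypass the clamp via the explicit exponent check. A hedged Python equivalent:

def clip_for_fp8(v: float, amax: float) -> float:
    # median(v, amax, -amax) == min(max(v, -amax), amax) for finite v
    return sorted([v, amax, -amax])[1]

assert clip_for_fp8(300.0, 240.0) == 240.0
assert clip_for_fp8(-1e6, 240.0) == -240.0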
......@@ -107,8 +107,12 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
This is needed only for Hopper, which uses persistent CTA execution.
*/
int max_connection = transformer_engine::getenv<int>("CUDA_DEVICE_MAX_CONNECTIONS", 8);
#ifdef USE_ROCM
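  // ROCm/DCU has no CUDA runtime version; the placeholder keeps the Hopper-only
  // (runtime >= 12030, sm_90) branch below disabled.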
int runtime_version = 6;
#else
int runtime_version = 0;
cudaRuntimeGetVersion(&runtime_version);
#endif
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, 0);
if (runtime_version >= 12030 && deviceProp.major == 9 && max_connection > 1) {
......@@ -277,7 +281,11 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
assert(rs_output.size(0) == _ubuf.size(0) / _tp_size);
assert(rs_output.element_size() == 2);
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
#ifdef USE_ROCM
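// hip_f8<bf8> (1:5:2) stands in for __nv_fp8_e5m2 on the HIP/DCU path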
reducescatter2_userbuff_fp8<hip_f8<hip_f8_type::bf8>>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0,
#else
reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0,
#endif
comm_elements, _ub_comm, _stream_comm,
(cudaEvent_t)_comm_launch_event);
} else {
......
......@@ -361,7 +361,11 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks,
NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->flags, 2 * GPU_PAGE_SIZE));
NVTE_CHECK_CUDA(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE));
(*comm)->flags =
#ifdef USE_ROCM
reinterpret_cast<int *>((reinterpret_cast<uintptr_t>((*comm)->flags) + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
#else
reinterpret_cast<int *>(((CUdeviceptr)(*comm)->flags + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
#endif
using namespace std;
......@@ -624,7 +628,12 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
} else {
#endif
if (alloc) {
#ifdef USE_ROCM
// Following RCCL: userbuffers require fine-grained device memory
NVTE_CHECK_CUDA(hipExtMallocWithFlags(gpubuff, bytes, hipDeviceMallocFinegrained));
#else
NVTE_CHECK_CUDA(cudaMalloc(gpubuff, bytes));
#endif
NVTE_CHECK_CUDA(cudaMemset(*gpubuff, 0, bytes));
}
......
......@@ -9,8 +9,10 @@
#include <cuda_runtime.h>
#if __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#define half_dtype nv_bfloat16
#else
#include <cuda_fp16.h>
#define half_dtype half
#endif
......@@ -24,6 +26,18 @@
#define MAX_THREADS 1024
#ifdef __HIP_PLATFORM_AMD__
#define ATOMIC_CONSUMER(chunk) \
if (counters) { \
if (threadIdx.x == 0 && blockIdx.x == 0) { \
while (0 != (atomicCAS(((unsigned int *)counters) + chunk, 0, 0))) { \
} \
((unsigned int *)counters)[chunk] = 1; \
__threadfence_system(); \
} \
if (blockIdx.x == 0) __syncthreads(); \
}
#else
#define ATOMIC_CONSUMER(chunk) \
if (counters) { \
if (threadIdx.x == 0 && blockIdx.x == 0) { \
......@@ -34,6 +48,7 @@
} \
if (blockIdx.x == 0) __syncthreads(); \
}
#endif
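The ATOMIC_CONSUMER/ATOMIC_PRODUCER macros implement a flag handshake between GEMM and communication kernels: the consumer spins with atomicCAS until the chunk's counter reaches 0, then re-arms it to 1 and fences. A hedged host-side Python sketch of the same protocol, with threads standing in for GPU blocks and assuming the producer side zeroes the counter (as the consumer spin implies):

import threading, time

counters = [1]                    # 1 = idle/armed, 0 = chunk produced

def producer(chunk: int) -> None:
    counters[chunk] = 0           # signal: data ready (ATOMIC_PRODUCER side)

def consumer(chunk: int) -> None:
    while counters[chunk] != 0:   # cf. the atomicCAS(..., 0, 0) spin above
        time.sleep(0)             # yield, like the GPU busy-wait
    counters[chunk] = 1           # re-arm for the next round

t = threading.Thread(target=consumer, args=(0,)); t.start()
producer(0); t.join()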
#define ATOMIC_PRODUCER(chunk) \
if (counters) { \
......@@ -1025,7 +1040,13 @@ __global__ void __launch_bounds__(MAX_THREADS)
// reset counter for next producer.
((unsigned int *)counters)[0] = 1;
#ifdef __HIP_PLATFORM_AMD__
__threadfence_system();
// __threadfence();
// __syncthreads()
#else
asm volatile("fence.sc.gpu;\n");
#endif
}
}
__syncthreads();
......@@ -1116,7 +1137,13 @@ __global__ void __launch_bounds__(MAX_THREADS)
// reset counter for next producer.
((unsigned int *)counters)[chunk_i] = 1;
#ifdef __HIP_PLATFORM_AMD__
__threadfence_system();
// __threadfence();
// __syncthreads()
#else
asm volatile("fence.sc.gpu;\n");
#endif
}
}
__syncthreads();
......@@ -1357,6 +1384,33 @@ __global__ void __launch_bounds__(MAX_THREADS)
}
} // fp16 inplace allgather kernel (Volta,Hopper)
#ifdef __HIP_PLATFORM_AMD__
#define SETUP_LAUNCH_CONFIG(sms, threads, stream) \
cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \
cudaLaunchAttribute attribute_ub[2]; \
attribute_ub[1].id = cudaLaunchAttributeClusterDimension; \
attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1; \
attribute_ub[1].val.clusterDim.y = 1; \
attribute_ub[1].val.clusterDim.z = 1; \
attribute_ub[0].id = cudaLaunchAttributeCooperative; \
cfg.attrs = attribute_ub; \
cfg.numAttrs = 1; // DTK does not support hipLaunchAttributeClusterDimension, so only the cooperative attribute is passed
#define ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event)
#define NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH 2
#define SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, threads, stream, comm_launch_event) \
cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \
cudaLaunchAttribute attribute_ub[NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH] = {}; \
ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event) \
attribute_ub[1].id = cudaLaunchAttributeClusterDimension; \
attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1; \
attribute_ub[1].val.clusterDim.y = 1; \
attribute_ub[1].val.clusterDim.z = 1; \
attribute_ub[0].id = cudaLaunchAttributeCooperative; \
cfg.attrs = attribute_ub; \
cfg.numAttrs = 1; // DTK does not support hipLaunchAttributeClusterDimension, so only the cooperative attribute is passed
#else
#define SETUP_LAUNCH_CONFIG(sms, threads, stream) \
cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \
cudaLaunchAttribute attribute_ub[2]; \
......@@ -1389,6 +1443,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
attribute_ub[0].id = cudaLaunchAttributeCooperative; \
cfg.attrs = attribute_ub; \
cfg.numAttrs = NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH;
#endif
#define callranks_ag(x) \
if (ar_nvsize == x) { \
......@@ -1932,6 +1987,53 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const
}
}
#ifdef __HIP_PLATFORM_AMD__
template void reducescatter2_userbuff_stridedoutput_fp8<hip_f8<hip_f8_type::bf8>>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event);
template void reducescatter2_userbuff_stridedoutput_fp8<hip_f8<hip_f8_type::fp8>>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event);
template <typename fp8type>
void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset,
const int elements, communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event) {
reducescatter2_userbuff_stridedoutput_fp8<fp8type>(output, scale, handler, offset, elements, 1, 0,
comm, stream, comm_launch_event);
}
template void reducescatter2_userbuff_fp8<hip_f8<hip_f8_type::bf8>>(void *output, float *scale,
const int handler, const int offset,
const int elements, communicator *comm,
cudaStream_t stream,
cudaEvent_t comm_launch_event);
template void reducescatter2_userbuff_fp8<hip_f8<hip_f8_type::fp8>>(void *output, float *scale,
const int handler, const int offset,
const int elements, communicator *comm,
cudaStream_t stream,
cudaEvent_t comm_launch_event);
template void reducescatter2_userbuff_strided_atomic_fp8<hip_f8<hip_f8_type::fp8>>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_strided_atomic_fp8<hip_f8<hip_f8_type::bf8>>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_strided_multiatomic_fp8<hip_f8<hip_f8_type::fp8>>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_strided_multiatomic_fp8<hip_f8<hip_f8_type::bf8>>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
#else
template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e5m2>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
......@@ -1977,6 +2079,7 @@ template void reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e5m2>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
#endif
__global__ void kuserbuffers_pullsend(int myrank, int peer, int *send_id, int *flagptr) {
atomicAdd_system(flagptr, 1);
......@@ -2196,7 +2299,13 @@ __global__ void __launch_bounds__(MAX_THREADS)
// Decrement atomic val to signal current output tile finish
if (counters) {
((unsigned int *)counters)[0] = 0;
#ifdef __HIP_PLATFORM_AMD__
__threadfence_system();
// __threadfence();
// __syncthreads()
#else
asm volatile("fence.sc.gpu;\n");
#endif
}
}
}
......@@ -2267,7 +2376,13 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat
// Decrement atomic val to signal current output tile finish
if (counters) {
((unsigned int *)counters)[recv_chunk_id /*chunk_i+1*/] = 0;
#ifdef __HIP_PLATFORM_AMD__
__threadfence_system();
// __threadfence();
// __syncthreads()
#else
asm volatile("fence.sc.gpu;\n");
#endif
}
}
......@@ -2545,7 +2660,13 @@ static __global__ void producer_kernel(void *atomic_ptr, int chunk_i) {
// The COMM kernel needs to flush gmem explicitly:
// the GEMM kernel has already executed and cannot see gmem
// changes unless the COMM kernel makes them visible explicitly.
#ifdef __HIP_PLATFORM_AMD__
__threadfence_system();
// __threadfence();
// __syncthreads()
#else
asm volatile("fence.sc.gpu;\n");
#endif
}
// consumer
......@@ -2555,7 +2676,13 @@ static __global__ void consumer_kernel(void *atomic_ptr, int chunk_i) {
while (0 != (atomicCAS((unsigned int *)atomic_ptr + chunk_i, 0, 0))) {
}
((unsigned int *)atomic_ptr)[chunk_i] = 1;
#ifdef __HIP_PLATFORM_AMD__
__threadfence_system();
// __threadfence();
// __syncthreads()
#else
asm volatile("fence.sc.gpu;\n");
#endif
}
}
......@@ -2567,7 +2694,13 @@ static __global__ void consumer_batch_kernel(void *atomic_ptr, int first_chunk_i
while (0 != (atomicCAS((unsigned int *)atomic_ptr + i, 0, 0))) {
}
((unsigned int *)atomic_ptr)[i] = 1;
#ifdef __HIP_PLATFORM_AMD__
__threadfence_system();
// __threadfence();
// __syncthreads()
#else
asm volatile("fence.sc.gpu;\n");
#endif
}
}
}
......@@ -2661,12 +2794,21 @@ void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_in
num_aligned_elements_per_input, tot_input_size);
}
#ifdef __HIP_PLATFORM_AMD__
template void reduce_fp8_in_bf16_out<hip_f8<hip_f8_type::fp8>>(void *inputs, void *output, float *scale,
int num_inputs, int input_size,
cudaStream_t stream);
template void reduce_fp8_in_bf16_out<hip_f8<hip_f8_type::bf8>>(void *inputs, void *output, float *scale,
int num_inputs, int input_size,
cudaStream_t stream);
#else
template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(void *inputs, void *output, float *scale,
int num_inputs, int input_size,
cudaStream_t stream);
template void reduce_fp8_in_bf16_out<__nv_fp8_e5m2>(void *inputs, void *output, float *scale,
int num_inputs, int input_size,
cudaStream_t stream);
#endif
template <int nvec>
__global__ void __launch_bounds__(MAX_THREADS / 4)
......
......@@ -105,14 +105,22 @@ struct communicator {
int memflags[NVTE_MAX_REGIONS]; // UC,MC, user/lib allocated
#ifdef __HIP_PLATFORM_AMD__
hipMemGenericAllocationHandle_t *uchandles[NVTE_MAX_REGIONS];
#else
CUmemGenericAllocationHandle *uchandles[NVTE_MAX_REGIONS];
#endif
void *ucbase_ptr[NVTE_MAX_REGIONS]; // only for cuMem allocated memory
size_t mem_size[NVTE_MAX_REGIONS];
bool mem_dealloc[NVTE_MAX_REGIONS];
void *mc_ptr[NVTE_MAX_REGIONS];
void *mc_baseptr;
#ifdef __HIP_PLATFORM_AMD__
hipMemGenericAllocationHandle_t mc_handle;
#else
CUmemGenericAllocationHandle mc_handle;
#endif
size_t mc_offset, mc_maxsize;
int use_mc; // 1: use MC if available, 0: override not to use MC
......
......@@ -36,6 +36,9 @@ void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) {
}
void checkCuDriverContext(CUstream stream) {
#ifdef __HIP_PLATFORM_AMD__
return;
#else
CUcontext ctx;
const CUresult driver_status = cuda_driver::call("cuStreamGetCtx", stream, &ctx);
switch (driver_status) {
......@@ -54,8 +57,10 @@ void checkCuDriverContext(CUstream stream) {
cuda_driver::call("cuGetErrorString", driver_status, &desc_NVTE_CHECK_CUDA_DRIVER);
NVTE_ERROR("CUDA Error: ", desc_NVTE_CHECK_CUDA_DRIVER);
}
#endif
}
#ifndef __HIP_PLATFORM_AMD__
CUtensorMapDataType get_CUtensorMapDataType(DType dtype) {
static const std::unordered_map<DType, CUtensorMapDataType> dtypeMapping = {
{DType::kByte, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8},
......@@ -127,11 +132,16 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
// Any element that is outside of bounds will be set to zero by the TMA transfer.
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
}
#endif
bool is_supported_by_CC_100() {
#ifdef __HIP_PLATFORM_AMD__
return false;
#else
int deviceComputeCapability = cuda::sm_arch(cuda::current_device());
return deviceComputeCapability >= 100;
#endif
}
} // namespace transformer_engine
......@@ -6,8 +6,9 @@
#ifndef TRANSFORMER_ENGINE_COMMON_COMMON_H_
#define TRANSFORMER_ENGINE_COMMON_COMMON_H_
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
......@@ -217,9 +218,15 @@ using int32 = int32_t;
using int64 = int64_t;
using fp32 = float;
using fp16 = half;
#ifndef __HIP_PLATFORM_AMD__
using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
#else
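// HIP/DCU: hip_f8 1:4:3 plays the role of e4m3 and 1:5:2 (bf8) of e5m2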
using bf16 = hip_bfloat16;
using fp8e4m3 = hip_f8<hip_f8_type::fp8>;
using fp8e5m2 = hip_f8<hip_f8_type::bf8>;
#endif
#if CUDA_VERSION >= 12080
using fp8e8m0 = __nv_fp8_e8m0;
#endif
......@@ -239,9 +246,15 @@ TRANSFORMER_ENGINE_TYPE_NAME(int32_t)
TRANSFORMER_ENGINE_TYPE_NAME(int64_t)
TRANSFORMER_ENGINE_TYPE_NAME(float)
TRANSFORMER_ENGINE_TYPE_NAME(half)
#ifdef __HIP_PLATFORM_AMD__
TRANSFORMER_ENGINE_TYPE_NAME(hip_bfloat16)
TRANSFORMER_ENGINE_TYPE_NAME(hip_f8<hip_f8_type::fp8>)
TRANSFORMER_ENGINE_TYPE_NAME(hip_f8<hip_f8_type::bf8>)
#else
TRANSFORMER_ENGINE_TYPE_NAME(nv_bfloat16)
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e4m3)
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2)
#endif
#if CUDA_VERSION >= 12080
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0)
#endif
......@@ -510,6 +523,7 @@ void update_tensor_scale_inv(Tensor *t, cudaStream_t stream);
void checkCuDriverContext(CUstream stream);
#ifndef __HIP_PLATFORM_AMD__
CUtensorMapDataType get_CUtensorMapDataType(DType dtype);
// Set up parameters to create TMA descriptor.
......@@ -517,6 +531,7 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
const uint32_t shmemX, const uint32_t stride_elems,
const uint32_t offset_elems, const size_t type_size);
#endif
bool is_supported_by_CC_100();
......
......@@ -11,6 +11,7 @@
namespace transformer_engine {
#ifndef USE_ROCM
// get cuDNN data type
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
......@@ -60,6 +61,7 @@ cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t)
void nvte_cudnn_handle_init() {
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
}
#endif
} // namespace transformer_engine
......
......@@ -7,9 +7,11 @@
#ifndef TRANSFORMER_ENGINE_CUDNN_UTILS_H_
#define TRANSFORMER_ENGINE_CUDNN_UTILS_H_
#ifndef __HIP_PLATFORM_AMD__
#include <cudnn.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>
#endif
#include <cstdint>
#include <mutex>
......@@ -17,7 +19,8 @@
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine {
#ifndef __HIP_PLATFORM_AMD__
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t);
......@@ -40,6 +43,7 @@ class cudnnExecutionPlanManager {
private:
cudnnHandle_t handle_ = nullptr;
};
#endif
} // namespace transformer_engine
......
......@@ -4,9 +4,15 @@
* See LICENSE for license information.
************************************************************************/
#ifndef __HIP_PLATFORM_AMD__
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda.h>
#else
#include <iostream>
#include "hipblas_gemm.h"
#include "rocm_gemm.hip"
#endif // #ifndef __HIP_PLATFORM_AMD__
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>
......@@ -17,6 +23,7 @@
#include "../util/logging.h"
#include "common/util/cuda_runtime.h"
#ifndef __HIP_PLATFORM_AMD__
namespace {
cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
......@@ -137,9 +144,19 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
}
} // namespace
#endif // __HIP_PLATFORM_AMD__
namespace transformer_engine {
#ifdef __HIP_PLATFORM_AMD__
//Forward declaration. The implementation is in rocm_gemm.hip
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda,
int ldb, int ldd, bool transa, bool transb, bool grad,
void* workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
int math_sm_count, int m_split, int n_split, bool gemm_producer,
const Tensor *inputCounter, hipStream_t stream);
#else // Use cublasLt
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda,
int ldb, int ldd, cublasOperation_t transa, cublasOperation_t transb, bool grad,
......@@ -424,6 +441,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Adesc));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescDestroy(operationDesc));
}
#endif // __HIP_PLATFORM_AMD__
static std::once_flag init_flag;
static cudaStream_t compute_streams[num_streams];
......@@ -437,6 +455,19 @@ static void init_streams_and_events() {
}
}
// Added for batched GEMM support
static std::once_flag init_flag_batchgemm;
static cudaStream_t compute_streams_batchgemm[num_batchgemm_streams];
static cudaEvent_t cublas_event_batchgemm[num_batchgemm_streams];
// Warning: only call once per device!
static void init_streams_and_events_batchgemm() {
for (int i = 0; i < num_batchgemm_streams; i++) {
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams_batchgemm[i], cudaStreamNonBlocking, -1));
NVTE_CHECK_CUDA(cudaEventCreate(&cublas_event_batchgemm[i]));
}
}
} // namespace transformer_engine
void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
......@@ -477,10 +508,57 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
NVTE_ERROR("TT layout not allowed.");
}
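// ROCm dispatch policy: with hipBLASLt enabled, GEMMs that need a fused bias or
// GELU epilogue, FP8 inputs, or NVTE_FORCE_BLASLT=1 take the full cublas_gemm
// (hipBLASLt) path; all other GEMMs fall back to the plain hipblas_gemm below.
// The rocBLAS build always uses cublas_gemm from rocm_gemm.hip.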
#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT
const char *NVTE_BLASLT_BLAS = std::getenv("NVTE_FORCE_BLASLT");
const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
is_fp8_dtype(inputB->data.dtype);
if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr) || use_fp8 ||
    (NVTE_BLASLT_BLAS != nullptr && NVTE_BLASLT_BLAS[0] == '1')) {
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#else
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#endif //USE_HIPBLASLT
#else
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#endif //__HIP_PLATFORM_AMD__
#ifdef __HIP_PLATFORM_AMD__
transa, transb,
#else
(transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N,
#endif //__HIP_PLATFORM_AMD__
grad,
wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
math_sm_count, 0, 0, false, nullptr, stream);
#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT
}
else{
hipblas_gemm(inputA,
inputB,
outputD,
biasTensor,
outputGelu,
m, n, k,
lda, ldb, ldd,
(transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
(transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
grad, wspace->data.dptr,
wspace->data.shape[0],
accumulate, use_split_accumulator,
math_sm_count,
0,
0,
false,
nullptr,
stream);
}
#endif //USE_HIPBLASLT
#endif //__HIP_PLATFORM_AMD__
}
void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
......@@ -491,10 +569,12 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
cudaStream_t stream) {
NVTE_API_CALL(nvte_cublas_atomic_gemm);
#ifndef __HIP_PLATFORM_AMD__
int cudart_version;
NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&cudart_version));
NVTE_CHECK(cudart_version >= 12020, "Cuda version 12.2 is required for atomic gemm.");
NVTE_CHECK(cublasLtGetVersion() >= 120205, "Cublas version 12.2.5 is required for atomic gemm.");
#endif
using namespace transformer_engine;
const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
......@@ -529,10 +609,103 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
NVTE_ERROR("TT layout not allowed.");
}
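// Same ROCm hipBLASLt/hipBLAS dispatch policy as in nvte_cublas_gemm above.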
#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT
const char *NVTE_BLASLT_BLAS = std::getenv("NVTE_FORCE_BLASLT");
const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
is_fp8_dtype(inputB->data.dtype);
if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr) || use_fp8 ||
    (NVTE_BLASLT_BLAS != nullptr && NVTE_BLASLT_BLAS[0] == '1')) {
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#else
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#endif //USE_HIPBLASLT
#else
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#endif //__HIP_PLATFORM_AMD__
#ifdef __HIP_PLATFORM_AMD__
transa, transb,
#else
(transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N,
#endif //__HIP_PLATFORM_AMD__
grad,
wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream);
#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT
}
else{
hipblas_gemm(
inputA,
inputB,
outputD,
biasTensor,
outputGelu,
m, n, k,
lda, ldb, ldd,
(transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
(transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
grad, wspace->data.dptr,
wspace->data.shape[0],
accumulate, use_split_accumulator,
math_sm_count,
m_split,
n_split,
gemm_producer,
inputCounter,
stream);
}
#endif //USE_HIPBLASLT
#endif //__HIP_PLATFORM_AMD__
}
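// Always dispatches through the (cu/hip)blasLt-style cublas_gemm path, bypassing
// the plain-BLAS fallback; used by nvte_multi_stream_cublas_gemm below when
// NVTE_FORCE_BLAS_MULSTREAM is unset.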
void nvte_cublaslt_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, cudaStream_t stream) {
NVTE_API_CALL(nvte_cublaslt_gemm);
using namespace transformer_engine;
const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
Tensor *outputD = reinterpret_cast<Tensor *>(D);
const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
Tensor *wspace = reinterpret_cast<Tensor *>(workspace);
const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1];
const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0];
const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0];
int lda, ldb, ldd;
if (transa && !transb) { // TN
lda = k;
ldb = k;
ldd = m;
} else if (!transa && !transb) { // NN
lda = m;
ldb = k;
ldd = m;
} else if (!transa && transb) { // NT
lda = m;
ldb = n;
ldd = m;
} else { // TT
NVTE_ERROR("TT layout not allowed.");
}
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#ifdef __HIP_PLATFORM_AMD__
transa, transb,
#else
(transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N,
#endif //__HIP_PLATFORM_AMD__
grad,
wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
math_sm_count, 0, 0, false, nullptr, stream);
}
void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
......@@ -552,11 +725,29 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams[s], cublas_event[0]));
}
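// Per-stream backend selection:
//   NVTE_FORCE_BLAS_MULSTREAM unset -> nvte_cublaslt_gemm on every stream (default)
//   NVTE_FORCE_BLAS_MULSTREAM=1     -> nvte_cublas_gemm on every stream
//   setting it together with NVTE_FORCE_BLASLT=1 is rejected as contradictory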
const char *NVTE_BLAS_MULSTREAM = std::getenv("NVTE_FORCE_BLAS_MULSTREAM");
const char *NVTE_BLASLT_BLAS = std::getenv("NVTE_FORCE_BLASLT");
bool NVTE_FORCE_BLASLT_MULSTREAM;
if (NVTE_BLAS_MULSTREAM == nullptr) {
NVTE_FORCE_BLASLT_MULSTREAM = true;
} else if ((NVTE_BLASLT_BLAS != nullptr && NVTE_BLASLT_BLAS[0] == '1') && (NVTE_BLAS_MULSTREAM[0] == '1')) {
NVTE_ERROR("NVTE_FORCE_BLAS_MULSTREAM and NVTE_FORCE_BLASLT can't be set at the same time.");
} else {
NVTE_FORCE_BLASLT_MULSTREAM = false;
}
if (NVTE_FORCE_BLASLT_MULSTREAM){
for (int i = 0; i < num_gemms; i++) {
nvte_cublaslt_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
compute_streams[i % num_streams]);
}
} else{
for (int i = 0; i < num_gemms; i++) {
nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
compute_streams[i % num_streams]);
}
}
// record events on compute streams
......@@ -568,3 +759,107 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event[s]));
}
}
#ifdef __HIP_PLATFORM_AMD__
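// Multi-stream batched GEMM: the num_gemms problems are split evenly across
// num_batchgemm_streams streams; each stream issues one strided-batched GEMM
// over batch_count = num_gemms / num_batchgemm_streams stacked sub-problems.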
void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
const NVTETensor *bias, NVTETensor *pre_gelu_out,
const int num_gemms, bool transa, bool transb, bool grad,
NVTETensor *workspace, bool accumulate,
bool use_split_accumulator, int math_sm_count,
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_stream_cublas_batchgemm);
using namespace transformer_engine;
// num_gemms is a runtime value, so this must be a runtime check rather than a static_assert
NVTE_CHECK(num_gemms % num_batchgemm_streams == 0,
           "Need num_gemms mod num_batchgemm_streams == 0.");
const int batch_count = num_gemms / num_batchgemm_streams;
// Inits streams and events (once, globally)
std::call_once(init_flag_batchgemm, init_streams_and_events_batchgemm);
int num_stream_used = num_batchgemm_streams;
// wait for current stream to finish
NVTE_CHECK_CUDA(cudaEventRecord(cublas_event_batchgemm[0], stream));
for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams_batchgemm[s], cublas_event_batchgemm[0]));
}
for (int i = 0; i < num_stream_used; i++) {
nvte_cublas_batchgemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
workspace[i % num_batchgemm_streams], accumulate, use_split_accumulator, math_sm_count,
batch_count, compute_streams_batchgemm[i % num_batchgemm_streams]);
}
// record events on compute streams
for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(cudaEventRecord(cublas_event_batchgemm[s], compute_streams_batchgemm[s]));
}
// wait for all compute streams to finish
for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event_batchgemm[s]));
}
}
// Added for batched GEMM support
void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, int batch_count, cudaStream_t stream) {
NVTE_API_CALL(nvte_cublas_batchgemm);
using namespace transformer_engine;
const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
Tensor *outputD = reinterpret_cast<Tensor *>(D);
const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
Tensor *wspace = reinterpret_cast<Tensor *>(workspace);
int m, n, k;
// All three supported layouts (TN, NN, NT) reduce to the same shape formula:
// the batched operand is stacked along dimension 0, so each sub-matrix
// contributes shape[0] / batch_count rows. TT is rejected below.
m = transa ? inputA->data.shape[0] / batch_count : inputA->data.shape[1];
k = transa ? inputA->data.shape[1] : inputA->data.shape[0] / batch_count;
n = transb ? inputB->data.shape[1] : inputB->data.shape[0] / batch_count;
int lda, ldb, ldd;
if (transa && !transb) { // TN
lda = k;
ldb = k;
ldd = m;
} else if (!transa && !transb) { // NN
lda = m;
ldb = k;
ldd = m;
} else if (!transa && transb) { // NT
lda = m;
ldb = n;
ldd = m;
} else { // TT
NVTE_ERROR("TT layout not allowed.");
}
hipblas_batchgemm(inputA,
inputB,
outputD,
biasTensor,
outputGelu,
m, n, k,
lda, ldb, ldd,
(transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
(transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
grad, wspace->data.dptr,
wspace->data.shape[0],
accumulate, use_split_accumulator,
math_sm_count,
0,
0,
false,
nullptr,
batch_count,
stream);
}
#endif
\ No newline at end of file
/*************************************************************************
* Copyright (c) 2022-2024, S3000 qianyj. All rights reserved.
************************************************************************/
#include <hip/hip_runtime.h>
#include "hipblas_gemm.h"
#include "../common_hip.h"
#include "../util/logging.h"
namespace {
hipblasDatatype_t get_hip_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
case DType::kFloat16:
return HIPBLAS_R_16F;
case DType::kFloat32:
return HIPBLAS_R_32F;
case DType::kBFloat16:
return HIPBLAS_R_16B;
default:
NVTE_ERROR("Invalid type");
}
}
} // namespace
// Define a static handle manager
static HipblasHandleManager handleManager;
namespace transformer_engine {
void hipblas_gemm(const Tensor *inputA,
const Tensor *inputB,
Tensor *outputD,
const Tensor *inputBias,
Tensor *outputPreGelu,
int m, int n, int k,
int lda, int ldb, int ldd,
hipblasOperation_t transa,
hipblasOperation_t transb,
bool grad,
void* workspace,
size_t workspaceSize,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
const Tensor *inputCounter,
hipStream_t stream) {
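// Plain (non-fused) GEMM via hipblasGemmEx. The bias, pre-GELU, workspace,
// split and counter arguments are accepted for interface compatibility with
// the hipBLASLt path but are not used here.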
// Use static handles
int device_id;
NVTE_CHECK_CUDA(hipGetDevice(&device_id));
hipblasHandle_t handle = handleManager.get(device_id);
void *A = inputA->data.dptr;
// void *A_scale_inverse = inputA->scale_inv.dptr;
void *B = inputB->data.dptr;
// void *B_scale_inverse = inputB->scale_inv.dptr;
// Output buffer (used as both C and D): with beta != 0 the GEMM accumulates in place
void *D = outputD->data.dptr;
// Select the calculation accuracy
hipblasDatatype_t A_type = get_hip_dtype(inputA->data.dtype);
hipblasDatatype_t B_type = get_hip_dtype(inputB->data.dtype);
hipblasDatatype_t D_type = get_hip_dtype(outputD->data.dtype);
hipblasDatatype_t computeType = HIPBLAS_R_32F; // default acc is float32
// Compute type selection: fp32 accumulation is always used here. Hooks for a
// mixed-precision (HIPBLAS_R_16F) or TF32 fast path (NVTE_BLASLT_TF32 ->
// HIPBLAS_COMPUTE_32F_FAST_TF32 on fp32 inputs) could be added later.
float one = 1.0f;
float zero = 0.0f;
float beta = accumulate ? one : zero;
hipblasSetStream(handle, stream);
// execute multiply
hipblasStatus_t status = hipblasGemmEx(
handle,
transa, // transa
transb, // transb
m,
n,
k,
static_cast<const void*>(&one),
A,
A_type,
lda,
B,
B_type,
ldb,
static_cast<const void*>(&beta),
D,
D_type,
ldd,
computeType,
HIPBLAS_GEMM_DEFAULT);
if (status != HIPBLAS_STATUS_SUCCESS) {
NVTE_ERROR("hipblasGemmEx execution failed");
}
}
void hipblas_batchgemm(const Tensor *inputA,
const Tensor *inputB,
Tensor *outputD,
const Tensor *inputBias,
Tensor *outputPreGelu,
int m, int n, int k,
int lda, int ldb, int ldd,
hipblasOperation_t transa,
hipblasOperation_t transb,
bool grad,
void* workspace,
size_t workspaceSize,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
const Tensor *inputCounter,
int batch_count,
hipStream_t stream) {
// Use static handles
int device_id;
NVTE_CHECK_CUDA(hipGetDevice(&device_id));
hipblasHandle_t handle = handleManager.get(device_id);
void *A = inputA->data.dptr;
// void *A_scale_inverse = inputA->scale_inv.dptr;
void *B = inputB->data.dptr;
// void *B_scale_inverse = inputB->scale_inv.dptr;
// Output buffer (used as both C and D): with beta != 0 the GEMM accumulates in place
void *D = outputD->data.dptr;
// Select the calculation accuracy
hipblasDatatype_t A_type = get_hip_dtype(inputA->data.dtype);
hipblasDatatype_t B_type = get_hip_dtype(inputB->data.dtype);
hipblasDatatype_t D_type = get_hip_dtype(outputD->data.dtype);
hipblasDatatype_t computeType = HIPBLAS_R_32F; // default acc is float32
float one = 1.0f;
float zero = 0.0f;
float beta = accumulate ? one : zero;
hipblasSetStream(handle, stream);
// execute multiply
// Stride between consecutive matrices in the batch; use 64-bit arithmetic to
// avoid int overflow for large shapes
const long long int strideA = static_cast<long long int>(m) * k;
const long long int strideB = static_cast<long long int>(k) * n;
const long long int strideD = static_cast<long long int>(m) * n;
hipblasStatus_t status = hipblasGemmStridedBatchedEx(
handle,
transa, // transa
transb, // transb
m,
n,
k,
static_cast<const void*>(&one),
A,
A_type,
lda,
strideA,
B,
B_type,
ldb,
strideB,
static_cast<const void*>(&beta),
D,
D_type,
ldd,
strideD,
batch_count,
computeType,
HIPBLAS_GEMM_DEFAULT);
if (status != HIPBLAS_STATUS_SUCCESS) {
NVTE_ERROR("hipblasGemmEx execution failed");
}
}
} // namespace transformer_engine
\ No newline at end of file
/*************************************************************************
* Copyright (c) 2022-2024, S3000 qianyj. All rights reserved.
************************************************************************/
/*! \file hipblas_gemm.h
 *  \brief Plain-GEMM fallback paths that use hipBLAS/rocBLAS instead of hipBLASLt
 */
#ifndef TRANSFORMER_ENGINE_COMMON_HIPBLAS_GEMM_H_
#define TRANSFORMER_ENGINE_COMMON_HIPBLAS_GEMM_H_
#include <hip/hip_runtime.h>
#ifdef USE_HIPBLASLT
#include <hipblas/hipblas.h>
#include <mutex>
#include <unordered_map>
#else
#include <rocblas/rocblas.h>
#include <cassert>
#endif
#include <stdexcept>
#include <iostream>
#include "../common_hip.h"
#ifdef USE_HIPBLASLT
class HipblasHandleManager {
public:
HipblasHandleManager() {}
~HipblasHandleManager() {
// Release all handles when the manager is destroyed
for (auto& device_pair : handles_map_) {
hipblasDestroy(device_pair.second); // Only one handle per device
}
}
// Get a handle for the given device (creates if necessary)
hipblasHandle_t get(int device_id) {
std::lock_guard<std::mutex> lock(mutex_);
// Check if the handle for this device exists
auto device_it = handles_map_.find(device_id);
if (device_it != handles_map_.end()) {
return device_it->second;
}
// Create a new handle for this device if it doesn't exist
hipblasHandle_t handle;
hipblasStatus_t status = hipblasCreate(&handle);
if (status != HIPBLAS_STATUS_SUCCESS) {
throw std::runtime_error("Failed to create HIPBLAS handle");
}
// Store the handle in the map for this device
handles_map_[device_id] = handle;
return handle;
}
private:
std::unordered_map<int, hipblasHandle_t> handles_map_; // Map from device_id to hipblasHandle
std::mutex mutex_;
};
namespace transformer_engine {
void hipblas_gemm(const Tensor *inputA,
const Tensor *inputB,
Tensor *outputD,
const Tensor *inputBias,
Tensor *outputPreGelu,
int m, int n, int k,
int lda, int ldb, int ldd,
hipblasOperation_t transa,
hipblasOperation_t transb,
bool grad,
void* workspace,
size_t workspaceSize,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
const Tensor *inputCounter,
hipStream_t stream);
void hipblas_batchgemm(const Tensor *inputA,
const Tensor *inputB,
Tensor *outputD,
const Tensor *inputBias,
Tensor *outputPreGelu,
int m, int n, int k,
int lda, int ldb, int ldd,
hipblasOperation_t transa,
hipblasOperation_t transb,
bool grad,
void* workspace,
size_t workspaceSize,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
const Tensor *inputCounter,
int batch_count,
hipStream_t stream);
} // namespace transformer_engine
#else
class HipblasHandleManager {
public:
HipblasHandleManager() : handle_(nullptr) {}
~HipblasHandleManager() {
// Release the handle in the destructor to ensure cleanup when it's no longer needed
if (handle_ != nullptr) {
rocblas_destroy_handle(handle_);
}
}
// Get the handle, creating it lazily on first use
rocblas_handle get() {
if (handle_ == nullptr) {
createHandle();
}
// Check whether the handle is created successfully
assert(handle_ != nullptr && "rocblas_handle should not be null after creation");
return handle_;
}
private:
rocblas_handle handle_;
// A private method that lazily creates the rocBLAS handle
void createHandle() {
rocblas_status status = rocblas_create_handle(&handle_);
if (status != rocblas_status_success) {
// If initialization fails, an exception is thrown
throw std::runtime_error("Failed to create HIPBLAS handle");
}
}
// Copy construct and assignment operations are prohibited
HipblasHandleManager(const HipblasHandleManager&) = delete;
HipblasHandleManager& operator=(const HipblasHandleManager&) = delete;
};
#endif // #ifdef USE_HIPBLASLT
#endif // TRANSFORMER_ENGINE_COMMON_HIPBLAS_GEMM_H_
\ No newline at end of file
/*************************************************************************
* Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
*
* License for AMD contributions = MIT. See LICENSE for more information
************************************************************************/
#include <type_traits>
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>
#ifdef USE_HIPBLASLT
#include <unistd.h>
#include <vector>
#include <forward_list>
#include <mutex>
#include <unordered_map>
#include <sstream>
#include <fstream>
#include <chrono>
#include <optional>
#include <hipblaslt/hipblaslt.h>
#endif
#ifdef USE_ROCBLAS
#define ROCBLAS_BETA_FEATURES_API
#include <rocblas/rocblas.h>
#include <hipcub/hipcub.hpp>
#endif
#include <iostream>
#include <cstdlib>
#include <string>
#include <cstdint>
#include "../common.h"
#include "../util/vectorized_pointwise.h"
#include "../util/logging.h"
namespace {
#ifdef USE_HIPBLASLT
#if HIP_VERSION >= 60000000
typedef hipDataType hipblasltDatatype_t;
typedef hipblasComputeType_t hipblasLtComputeType_t;
#define HIPBLASLT_R_16F HIP_R_16F
#define HIPBLASLT_R_32F HIP_R_32F
#define HIPBLASLT_R_16B HIP_R_16BF
#define HIPBLASLT_R_8F_E4M3 HIP_R_8F_E4M3_FNUZ
#define HIPBLASLT_R_8F_E5M2 HIP_R_8F_E5M2_FNUZ
#define HIPBLASLT_COMPUTE_F32 HIPBLAS_COMPUTE_32F
#endif // #if HIP_VERSION >= 60000000
hipblasltDatatype_t get_hipblaslt_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
case DType::kFloat16:
return HIPBLASLT_R_16F;
case DType::kFloat32:
return HIPBLASLT_R_32F;
case DType::kBFloat16:
return HIPBLASLT_R_16B;
case DType::kFloat8E4M3:
return HIPBLASLT_R_8F_E4M3;
case DType::kFloat8E5M2:
return HIPBLASLT_R_8F_E5M2;
default:
NVTE_ERROR("Invalid type");
}
}
#endif
#ifdef USE_ROCBLAS
rocblas_datatype get_rocblas_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
case DType::kFloat16:
return rocblas_datatype_f16_r;
case DType::kFloat32:
return rocblas_datatype_f32_r;
case DType::kBFloat16:
return rocblas_datatype_bf16_r;
case DType::kFloat8E4M3:
return rocblas_datatype_f8_r;
case DType::kFloat8E5M2:
return rocblas_datatype_bf8_r;
default:
NVTE_ERROR("Invalid type");
}
}
#endif
} //namespace
namespace transformer_engine {
#ifdef USE_ROCBLAS
namespace detail {
struct Empty {};
__device__ inline fp32 identity(fp32 value, const Empty&) {
return value;
}
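// Tanh approximation of GELU:
//   gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
// with sqrt(2/pi) ~= 0.7978845608028654.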
__inline__ __device__
float gelu(float x, const Empty&)
{
float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x))));
return x * cdf;
}
__inline__ __device__
float gelu_forward(float x)
{
float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x))));
return x * cdf;
}
template <typename T, int THREADS_PER_BLOCK>
__global__
void gelu_forward_kernel(const float* in, T* out, float* amax, const float* scale, int m, int n) {
// fp8 output flow
if constexpr(std::is_same<T, fp8e4m3>::value ||std::is_same<T, fp8e5m2>::value){
typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
__shared__ typename BlockReduce::TempStorage block_temp_storage;
float thread_amax = 0;
for(int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x){
float x = in[id];
float y = gelu_forward(x);
out[id] = (T)((*scale)*y);
thread_amax=std::fmax(std::fabs(y), thread_amax);
}
float block_amax = BlockReduce(block_temp_storage).Reduce(thread_amax, hipcub::Max());
if(threadIdx.x==0){
atomicMaxFloat(amax, block_amax);
}
}else{
for(int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x){
float x = in[id];
float y = gelu_forward(x);
out[id] = (T)(y);
}
}
}
template <typename T>
void gelu_forward_kernelLauncher(const float* in, T* out, float* amax, const float* scale, int m, int n, hipStream_t stream) {
dim3 block, grid;
constexpr int THREADS_PER_BLOCK = 1024;
block.x = THREADS_PER_BLOCK;
grid.x = ceil(1.0*m * n / THREADS_PER_BLOCK);
hipLaunchKernelGGL(( gelu_forward_kernel<T, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0, stream, in, out, amax, scale, m, n);
}
__inline__ __device__
float gelu_backward(float x, float dy){
constexpr float kBeta = 0.7978845608028654f;
constexpr float kKappa = 0.044715f;
float x_sq = x * x;
float x_cube = x_sq * x;
float tanh_inner = tanhf((kBeta * (x + kKappa * x_cube)));
float left = 0.5f * x;
float right = 1.0f + tanh_inner;
float left_derivative = 0.5f * right;
float tanh_derivative = 1.0f - tanh_inner * tanh_inner;
float inner_derivative = kBeta * (1.0f + 3.0f * kKappa * x_sq);
float right_derivative = left * tanh_derivative * inner_derivative;
return dy * (left_derivative + right_derivative);
}
template <typename T, typename Taux>
__global__
void gelu_backward_kernel(const float* dy, T* out, const Taux* __restrict pre_gelu_out, int m, int n) {
for(int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x)
{
float x = (float)pre_gelu_out[id];
float dx = (float)gelu_backward(x, dy[id]);
out[id] = (T)(dx);
}
}
template <typename T, typename Taux>
void gelu_backward_kernelLauncher(const float* in, T* out, const Taux* pre_gelu_out, int m, int n, hipStream_t stream) {
int blocks_per_row = ceil(float(n)/1024);
dim3 grid(min(m * blocks_per_row, 65536));
dim3 block(min(n, 1024));
hipLaunchKernelGGL(( gelu_backward_kernel<T, Taux>), dim3(grid), dim3(block), 0, stream, in, out, pre_gelu_out, m, n);
}
template <typename T, typename Tb, int THREADS_PER_BLOCK>
__global__
void add_bias_kernel(const float* in, T* out, const Tb* __restrict bias, float* amax, const float* scale, int m, int n){
// fp8 output flow
if constexpr(std::is_same<T, fp8e4m3>::value ||std::is_same<T, fp8e5m2>::value){
typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
__shared__ typename BlockReduce::TempStorage block_temp_storage;
float thread_amax = 0;
for(int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x){
float reg_bias = (float)bias[id % n];
float val = in[id] + reg_bias;
out[id] = (T)((*scale)*val);
// deal with amax of D
thread_amax=std::fmax(std::fabs(val), thread_amax);
}
// num_valid can be ignored since each thread amax is set to 0
float block_amax = BlockReduce(block_temp_storage).Reduce(thread_amax, hipcub::Max());
if(threadIdx.x==0){
atomicMaxFloat(amax, block_amax);
}
}else{
for(int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x){
float reg_bias = (float)bias[id % n];
float val = in[id] + reg_bias;
out[id] = (T)(val);
}
}
}
template <typename T, typename Tb>
void add_bias_kernelLauncher(const float* in, T* out, const Tb* __restrict bias, float* amax, const float* scale, int m, int n, hipStream_t stream) {
dim3 block, grid;
constexpr int THREADS_PER_BLOCK = 1024;
block.x = THREADS_PER_BLOCK;
grid.x = ceil(1.0*m * n / THREADS_PER_BLOCK);
hipLaunchKernelGGL(( add_bias_kernel<T, Tb, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0, stream, in, out, bias, amax, scale, m, n);
}
template <typename T, typename Taux, typename Tb, int THREADS_PER_BLOCK>
__global__
void add_bias_gelu_kernel(const float* in, T* out, Taux* pre_gelu_out, const Tb* __restrict bias, float* amax, const float* scale, int m, int n){
// fp8 output flow
if constexpr(std::is_same<T, fp8e4m3>::value ||std::is_same<T, fp8e5m2>::value){
// only need to deal with amax and scale of D, no need to deal with amax and scale of pre_gelu_out
typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
__shared__ typename BlockReduce::TempStorage block_temp_storage;
float thread_amax = 0;
for(int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x){
float reg_bias = (float)bias[id % n];
float val = in[id] + reg_bias;
// pre_gelu_out guaranteed not to be fp8 type
pre_gelu_out[id] = (Taux)(val);
val = gelu_forward(val);
out[id] = (T)((*scale)*val);
// deal with amax of D
thread_amax=std::fmax(std::fabs(val), thread_amax);
}
// num_valid can be ignored since each thread amax is set to 0
float block_amax = BlockReduce(block_temp_storage).Reduce(thread_amax, hipcub::Max());
if(threadIdx.x==0){
atomicMaxFloat(amax, block_amax);
}
}else{
for(int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x){
float reg_bias = (float)bias[id % n];
float val = in[id] + reg_bias;
pre_gelu_out[id] = (Taux)(val);
out[id] = (T)(gelu_forward(val));
}
}
}
template <typename T, typename Taux, typename Tb>
void add_bias_gelu_kernelLauncher(const float* in, T* out, Taux* pre_gelu_out, const Tb* __restrict bias, float* amax, const float* scale, int m, int n, hipStream_t stream) {
dim3 block, grid;
constexpr int THREADS_PER_BLOCK = 1024;
block.x = THREADS_PER_BLOCK;
grid.x = ceil(1.0*m * n / THREADS_PER_BLOCK);
hipLaunchKernelGGL(( add_bias_gelu_kernel<T, Taux, Tb, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0, stream, in, out, pre_gelu_out, bias, amax, scale, m, n );
}
template <typename Tin, typename T>
__global__
void identity_kernel(const Tin* in, T* out, int n) {
for(int id = blockIdx.x * blockDim.x + threadIdx.x; id < n; id += blockDim.x * gridDim.x)
{
Tin val = in[id];
out[id] = (T)(val);
}
}
template <typename Tin, typename T>
void identity_kernelLauncher(const Tin* in, T* out, int n, hipStream_t stream) {
dim3 block, grid;
block.x = 1024;
grid.x = ceil( n / 1024.);
hipLaunchKernelGGL(( identity_kernel<Tin, T>), dim3(grid), dim3(block), 0, stream, in, out, n );
}
template <typename T, int THREADS_PER_BLOCK>
__global__
void identity_output_kernel(const float* in, T* out, float* amax, const float* scale, int n) {
if constexpr(std::is_same<T, fp8e4m3>::value ||std::is_same<T, fp8e5m2>::value){
typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
__shared__ typename BlockReduce::TempStorage block_temp_storage;
float thread_amax = 0;
for(int id = blockIdx.x * blockDim.x + threadIdx.x; id < n; id += blockDim.x * gridDim.x){
float val = in[id];
out[id] = (T)((*scale)*val);
// deal with amax of D
thread_amax=std::fmax(std::fabs(val), thread_amax);
}
// num_valid can be ignored since each thread amax is set to 0
float block_amax = BlockReduce(block_temp_storage).Reduce(thread_amax, hipcub::Max());
if(threadIdx.x==0){
atomicMaxFloat(amax, block_amax);
}
}else{
for(int id = blockIdx.x * blockDim.x + threadIdx.x; id < n; id += blockDim.x * gridDim.x){
float val = in[id];
out[id] = (T)(val);
}
}
}
template <typename T>
void identity_output_kernelLauncher(const float* in, T* out, float* amax, const float* scale, int n, hipStream_t stream) {
dim3 block, grid;
constexpr int THREADS_PER_BLOCK = 1024;
block.x = THREADS_PER_BLOCK;
grid.x = ceil( 1.0*n / THREADS_PER_BLOCK);
hipLaunchKernelGGL(( identity_output_kernel<T, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0, stream, in, out, amax, scale, n );
}
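// Bias gradient: dbias[j] = sum_i in[i * n + j]. Each column j is covered by
// BLOCKS_PER_COL thread blocks; every block reduces its slice of rows with
// hipcub::BlockReduce and atomically accumulates the partial sum into out[j].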
template <typename Tin, int THREADS_PER_BLOCK>
__global__
void bias_gradient_kernel(const Tin* in, float* out, int m, int n) {
typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
__shared__ typename BlockReduce::TempStorage block_temp_storage;
int BLOCKS_PER_COL = ceil(float(m)/THREADS_PER_BLOCK);
int THREADS_PER_COL = BLOCKS_PER_COL * THREADS_PER_BLOCK;
int idx = threadIdx.x + blockIdx.x * blockDim.x;
int col_idx = idx / THREADS_PER_COL;
int row_idx = idx % THREADS_PER_COL;
float thread_data = 0.0f;  // zero-init so out-of-range threads contribute nothing
if (row_idx < m)
thread_data = (float)in[row_idx * n + col_idx];
float local_sum;
if (row_idx < (BLOCKS_PER_COL-1) * THREADS_PER_BLOCK) {
local_sum = BlockReduce(block_temp_storage).Sum(thread_data);
}
else {
local_sum = BlockReduce(block_temp_storage).Sum(thread_data, m-(BLOCKS_PER_COL-1)*THREADS_PER_BLOCK);
}
if (threadIdx.x == 0)
atomicAdd(&out[col_idx], local_sum);
}
template <typename Tin>
void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool stream_order_alloc, hipStream_t stream) {
dim3 block, grid;
constexpr int THREADS_PER_BLOCK = 1024;
int BLOCKS_PER_COL = ceil(float(m)/THREADS_PER_BLOCK);
block.x = THREADS_PER_BLOCK;
grid.x = BLOCKS_PER_COL*n;
if(! stream_order_alloc){
NVTE_CHECK_CUDA( hipMemset(out, 0, n*sizeof(float)) );
}else{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA( hipMemsetAsync(out, 0, n*sizeof(float), stream) );
#else
NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
#endif
}
hipLaunchKernelGGL(( bias_gradient_kernel<Tin, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0, stream, in, out, m, n);
}
} // namespace detail
transformer_engine::DType get_transformer_engine_dtype(const rocblas_datatype t) {
using namespace transformer_engine;
switch (t) {
case rocblas_datatype_f16_r:
return DType::kFloat16;
case rocblas_datatype_f32_r:
return DType::kFloat32;
case rocblas_datatype_bf16_r:
return DType::kBFloat16;
case rocblas_datatype_f8_r:
return DType::kFloat8E4M3;
case rocblas_datatype_bf8_r:
return DType::kFloat8E5M2;
default:
NVTE_ERROR("Invalid type");
}
}
#endif //USE_ROCBLAS
#ifdef USE_HIPBLASLT
namespace {
static class HandlePool {
public:
hipblasLtHandle_t get(int device_id)
{
std::lock_guard<std::mutex> lock(mt);
if (pool.empty())
{
int device_count = 0;
NVTE_CHECK_CUDA(hipGetDeviceCount(&device_count));
pool.resize(device_count);
return nullptr;
}
if (!pool[device_id].empty())
{
hipblasLtHandle_t h = pool[device_id].front();
pool[device_id].pop_front();
return h;
}
return nullptr;
}
hipblasLtHandle_t obtain(int device_id)
{
hipblasLtHandle_t h = get(device_id);
if (h == nullptr)
{
NVTE_CHECK_HIPBLASLT(hipblasLtCreate(&h));
}
return h;
}
void store(const std::vector<hipblasLtHandle_t>& handles)
{
std::lock_guard<std::mutex> lock(mt);
if (pool.empty())
{
std::cout << "[ERROR] Attempt to store handles to invalid pool" << std::endl;
}
for (unsigned int i=0; i<pool.size(); i++)
{
if (handles[i] != nullptr)
{
pool[i].push_front(handles[i]);
}
}
}
~HandlePool() {
#if DESTROY_HIPBLASLT_HANDLES_POOL
std::lock_guard<std::mutex> lock(mt);
for (auto & hlist : pool)
{
for (auto & h : hlist)
{
hipblasLtDestroy(h);
}
}
pool.clear();
#endif
}
inline size_t get_size() const
{
return pool.size();
}
private:
std::mutex mt;
using Pool = std::vector<std::forward_list<hipblasLtHandle_t>>;
// Order of destructors between thread_local and global is not actually guaranteed
// As a simple w/a make pool storage "leaky"
// Just do not destruct it and do not destroy hipblasLt handles
// Let OS deal with it on application exit
#if DESTROY_HIPBLASLT_HANDLES_POOL
Pool pool;
#else
Pool &pool = *new Pool();
#endif
} handle_pool;
thread_local static class HandleCache {
public:
hipblasLtHandle_t get(int device_id) const
{
return d.empty() ? nullptr : d[device_id];
}
hipblasLtHandle_t obtain(int device_id)
{
hipblasLtHandle_t h = get(device_id);
if (h)
{
return h;
}
h = handle_pool.obtain(device_id);
set(device_id, h);
return h;
}
void set(int device_id, hipblasLtHandle_t h)
{
if (d.empty())
{
d.resize(handle_pool.get_size());
}
d[device_id] = h;
}
~HandleCache()
{
if (!d.empty())
{
handle_pool.store(d);
}
}
private:
std::vector<hipblasLtHandle_t> d;
} cached_handles;
class csv_helper
{
public:
struct start {};
struct end {};
csv_helper(std::ostream& os, char sep_val) : m_os{ os }, m_sep_val(sep_val), m_start(true), m_sep("") {}
csv_helper& operator << (const start&)
{
m_start = true;
return *this;
}
csv_helper& operator << (const end&)
{
m_sep="";
m_start = false;
return *this;
}
template< typename T>
csv_helper& operator<<(const T& v)
{
m_os << m_sep << v;
if (m_start)
{
m_start = false;
m_sep = m_sep_val;
}
return *this;
}
private:
std::ostream& m_os;
char m_sep_val;
bool m_start;
std::string m_sep;
};
template<typename T>
class NameMapper
{
public:
NameMapper(const std::unordered_map<T, std::string_view>& name_map): map(name_map) {}
const std::string_view &getName(const T &val) {
return map.at(val);
}
T getValue(const std::string& name, const char *label="")
{
for (auto iter = map.begin(); iter != map.end(); ++iter)
{
if (name == iter->second) return iter->first;
}
NVTE_ERROR("Invalid ", label, " name: ", name);
}
protected:
const std::unordered_map<T, std::string_view> &map;
};
static std::unordered_map<hipblasltDatatype_t, std::string_view> type_name_map = {
{HIPBLASLT_R_32F, "float32"},
{HIPBLASLT_R_16F, "float16"},
{HIPBLASLT_R_16B, "bfloat16"},
{HIPBLASLT_R_8F_E4M3, "float8e4m3"},
{HIPBLASLT_R_8F_E5M2, "float8e5m2"},
};
static NameMapper<hipblasltDatatype_t> typeNameMapper(type_name_map);
static std::unordered_map<hipblasOperation_t, std::string_view> trans_name_map = {
{HIPBLAS_OP_N, "N"},
{HIPBLAS_OP_T, "T"}
};
static NameMapper<hipblasOperation_t> transposeNameMapper(trans_name_map);
static std::unordered_map<hipblasLtEpilogue_t, std::string_view> epi_name_map = {
{HIPBLASLT_EPILOGUE_DEFAULT, "-"},
{HIPBLASLT_EPILOGUE_BIAS, "bias"},
{HIPBLASLT_EPILOGUE_GELU_AUX, "geluaux"},
{HIPBLASLT_EPILOGUE_GELU_AUX_BIAS, "geluauxbias"},
{HIPBLASLT_EPILOGUE_DGELU, "dgelu"},
{HIPBLASLT_EPILOGUE_DGELU_BGRAD, "dgelubgrad"},
{HIPBLASLT_EPILOGUE_BGRADB, "bgradb"}
};
static NameMapper<hipblasLtEpilogue_t> epilogueNameMapper(epi_name_map);
static std::unordered_map<hipblasLtComputeType_t, std::string_view> comp_name_map = {
{HIPBLASLT_COMPUTE_F32, "f32"}
};
static NameMapper<hipblasLtComputeType_t> computeNameMapper(comp_name_map);
static class GemmAlgoCache {
public:
struct Key {
int deviceCap;
hipblasltDatatype_t a_type, b_type, d_type, bias_type;
int m, n, k;
int lda, ldb, ldd;
hipblasOperation_t transa, transb;
hipblasLtEpilogue_t epilogue;
Key(int deviceCap_,
hipblasltDatatype_t a_type_, hipblasltDatatype_t b_type_,
hipblasltDatatype_t d_type_, hipblasltDatatype_t bias_type_,
int m_, int n_, int k_, int lda_, int ldb_, int ldd_,
hipblasOperation_t transa_, hipblasOperation_t transb_,
hipblasLtEpilogue_t epilogue_):
deviceCap(deviceCap_),
a_type(a_type_), b_type(b_type_),
d_type(d_type_), bias_type(bias_type_),
m(m_), n(n_), k(k_), lda(lda_), ldb(ldb_), ldd(ldd_),
transa(transa_), transb(transb_),
epilogue(epilogue_) {}
Key() {}
bool operator==(const Key &val) const
{
return ((deviceCap == val.deviceCap)
&& (a_type == val.a_type) && (b_type == val.b_type)
&& (d_type == val.d_type) && (bias_type == val.bias_type)
&& (m == val.m) && (n == val.n) && (k == val.k)
&& (lda == val.lda) && (ldb == val.ldb) && (ldd == val.ldd)
&& (transa == val.transa) && (transb == val.transb)
&& (epilogue == val.epilogue) );
}
struct Comp
{
bool operator()(const Key& lhs, const Key& rhs) const
{
return ::std::string_view((const char*)&lhs, sizeof(lhs)) < ::std::string_view((const char*)&rhs, sizeof(rhs));
}
};
};
void init()
{
std::lock_guard<std::mutex> lock(mt);
int device_count = 0;
NVTE_CHECK_CUDA(hipGetDeviceCount(&device_count));
dev_cap.resize(device_count);
for (int i=0; i<device_count; i++)
{
hipDeviceProp_t prop;
NVTE_CHECK_CUDA(hipGetDeviceProperties(&prop, i));
dev_cap[i] = prop.major*100 + prop.minor;
}
load_();
save_();
}
inline int device_cap(int device_id)
{
if (dev_cap.empty())
init();
return dev_cap[device_id];
}
struct Algo {
std::optional<hipblasLtMatmulAlgo_t> algo;
int64_t algoId;
int index;
size_t ws_size_min;
size_t ws_size_max;
Algo(): algo(), index(-1), algoId(), ws_size_min(0), ws_size_max(0) {}
Algo(int idx, int64_t id, size_t ws_min, size_t ws_max): algo(), index(idx), algoId(id), ws_size_min(ws_min), ws_size_max(ws_max) {}
inline bool hasId() const { return index >= 0; }
static inline int64_t getAlgoId(const hipblasLtMatmulAlgo_t &algo)
{
return *(const int64_t*)&algo;
}
};
bool find(const Key &cfg, size_t ws_size, Algo &algo)
{
std::lock_guard<std::mutex> lock(mt);
if (auto *pentry = find_(cfg, ws_size, ws_size); pentry != nullptr)
{
algo = *pentry;
return true;
}
return false;
}
void store(const Key &cfg, const Algo &algo)
{
size_t ws_size_min = algo.ws_size_min;
size_t ws_size_max = algo.ws_size_max;
NVTE_CHECK(ws_size_max >= ws_size_min, "Invalid WS size");
std::lock_guard<std::mutex> lock(mt);
//Remove overlapping with existing entries;
while (auto* pentry = find_(cfg, ws_size_min, ws_size_max)) {
if (pentry->ws_size_min <= ws_size_min && pentry->ws_size_max >= ws_size_max)
{
*pentry = algo;
save_();
return;
}
if (ws_size_max > pentry->ws_size_max)
{
ws_size_min = pentry->ws_size_max + 1;
}
else if (ws_size_min < pentry->ws_size_min)
{
ws_size_max = pentry->ws_size_min - 1;
}
else
{
//Should never be here
NVTE_ERROR("Cannot merge WS size range");
}
}
//Merge to adjusted entry if possible
auto* pentry = find_(cfg, ws_size_min - 1, ws_size_min);
if (pentry && pentry->algoId == algo.algoId)
{
pentry->algo = algo.algo;
pentry->ws_size_max = ws_size_max;
save_();
}
else
{
auto it = d.emplace(cfg, algo);
it->second.ws_size_min = ws_size_min;
it->second.ws_size_max = ws_size_max;
save_(it->first, it->second);
}
}
protected:
Algo* find_(const Key &cfg, size_t ws_min, size_t ws_max)
{
const auto key_range = d.equal_range(cfg);
for (auto i = key_range.first; i != key_range.second; i++)
{
if (ws_min <= i->second.ws_size_max && ws_max >= i->second.ws_size_min)
{
return &i->second;
}
}
return nullptr;
}
void header_(std::ostream& ofs)
{
csv_helper fs(ofs, csv_sep);
fs << "dev_cap" << "m" << "n" << "k" << "trans_a" << "trans_b"
<< "type_a" << "type_b" << "type_d" << "bias_type"
<< "lda" << "ldb" << "ldd" << "epi" << "comp" << "scale"
<< "ws_min" << "ws_max" << "algo_id" << "aidx";
}
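/* A row in this format might look like (hypothetical values):
 *   904,4096,4096,1024,T,N,float16,float16,float16,-,1024,1024,4096,-,f32,float32,0,33554432,123456789,3
 */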
void load_()
{
const char* env = std::getenv("TE_HIPBLASLT_ALGO_LOAD");
if (env == nullptr || env[0] == '\0')
{
return;
}
std::ifstream ifs{env};
if (!ifs.is_open())
{
std::cerr << "Could not load autotune results storage " << env << "\n";
return;
}
std::cout << "Loading autotune results from " << env << "\n";
Key cfg;
std::string line;
std::getline(ifs, line); // the first line with legend
{
std::ostringstream hline;
header_(hline);
if (hline.str() != line) {
std::cerr << "Incorrect algo storage legend. Expected " << hline.str() << "\n";
return;
}
}
while(std::getline(ifs, line))
{
line.erase(0, line.find_first_not_of(" \t\n\r\f\v"));
if (auto pos = line.find_last_not_of(" \t\n\r\f\v"); pos != std::string::npos)
{
line.resize(pos+1);
}
if (line.empty() || line[0] == '#') continue;
std::istringstream is(line);
char c;
std::string type_a, type_b, type_d, bias_type, trans_a, trans_b, epi, comp, scale;
int64_t algo_id;
int algo_idx;
size_t ws_min, ws_max;
is >> std::skipws;
is >> cfg.deviceCap >> c >> cfg.m >> c >> cfg.n >> c >> cfg.k >> c;
//Filter out entries for devices not present on the current system
bool b_found = false;
for (size_t i = 0; i < dev_cap.size(); i++)
{
if (dev_cap[i] == cfg.deviceCap)
{
b_found = true;
break;
}
}
if (!b_found) continue;
std::getline(is, trans_a, csv_sep);
std::getline(is, trans_b, csv_sep);
std::getline(is, type_a, csv_sep);
std::getline(is, type_b, csv_sep);
std::getline(is, type_d, csv_sep);
std::getline(is, bias_type, csv_sep);
is >> cfg.lda >> c >> cfg.ldb >> c >> cfg.ldd >> c;
std::getline(is, epi, csv_sep);
std::getline(is, comp, csv_sep);
std::getline(is, scale, csv_sep);
is >> ws_min >> c >> ws_max >> c >> algo_id >> c >> algo_idx;
if (is.bad())
{
std::cerr << "Parsing CSV line failed: " << line << "\n";
return;
}
if (ws_min > ws_max)
{
std::cout << "[WARNING] Invalid WS size at " << line << "\n";
continue;
}
cfg.a_type = typeNameMapper.getValue(type_a, "type_a");
cfg.b_type = typeNameMapper.getValue(type_b, "type_b");
cfg.d_type = typeNameMapper.getValue(type_d, "type_d");
cfg.bias_type = (bias_type == "-") ? (hipblasltDatatype_t)-1 : typeNameMapper.getValue(bias_type, "bias_type");
cfg.transa = transposeNameMapper.getValue(trans_a, "trans_a");
cfg.transb = transposeNameMapper.getValue(trans_b, "trans_b");
cfg.epilogue = epilogueNameMapper.getValue(epi, "epi");
//Check and filter out compute and scale types
if (computeNameMapper.getValue(comp, "comp") != HIPBLASLT_COMPUTE_F32 || typeNameMapper.getValue(scale, "scale") != HIPBLASLT_R_32F)
{
continue;
}
if (find_(cfg, ws_min, ws_max))
{
std::cout << "[WARNING] Duplicated/overlapped entry in algo cache\n";
continue;
}
d.emplace(cfg, Algo(algo_idx, algo_id, ws_min, ws_max));
}
}
bool can_save_(bool reopen = false)
{
if (!save_fs)
{
const char* temp = std::getenv("TE_HIPBLASLT_ALGO_SAVE");
if (temp == nullptr || temp[0] == '\0')
{
return false;
}
save_fs_name = temp;
pid_t pid = getpid();
size_t pos = 0;
while ((pos = save_fs_name.find("%i", pos)) != std::string::npos) {
save_fs_name.replace(pos, 2, std::to_string(pid));
}
save_fs = std::make_unique<std::ofstream>();
std::cout << "Saving autotune results to " << save_fs_name << "\n";
}
if (reopen)
{
if (save_fs->is_open())
{
save_fs->close();
}
save_fs->open(save_fs_name, std::ios_base::trunc);
}
if (save_fs->is_open() && !save_fs->bad())
{
return true;
}
else
{
if (reopen) std::cerr << "Could not open autotune results storage " << save_fs_name << "\n";
return false;
}
}
void save_()
{
if (!can_save_(true))
{
return;
}
header_(*save_fs);
*save_fs << "\n";
for (const auto &elem: d)
{
save_(elem.first, elem.second);
}
}
void save_(const Key &cfg, const Algo &algo)
{
if (!can_save_())
{
return;
}
csv_helper csv(*save_fs, csv_sep);
csv << cfg.deviceCap << cfg.m << cfg.n << cfg.k
<< transposeNameMapper.getName(cfg.transa) << transposeNameMapper.getName(cfg.transb)
<< typeNameMapper.getName(cfg.a_type) << typeNameMapper.getName(cfg.b_type) << typeNameMapper.getName(cfg.d_type)
<< ((cfg.bias_type == (hipblasltDatatype_t)-1) ? "-" : typeNameMapper.getName(cfg.bias_type))
<< cfg.lda << cfg.ldb << cfg.ldd << epilogueNameMapper.getName(cfg.epilogue)
<< computeNameMapper.getName(HIPBLASLT_COMPUTE_F32) << typeNameMapper.getName(HIPBLASLT_R_32F)
<< algo.ws_size_min << algo.ws_size_max << algo.algoId << algo.index << csv_helper::end() << "\n";
}
private:
std::vector<int> dev_cap;
constexpr static char csv_sep = ',';
std::unique_ptr<std::ofstream> save_fs;
std::string save_fs_name;
std::mutex mt;
/* Map of problem config to tuple of ws_size and Algo
* When searching, elements matching Key are filtered
* for requested WS size be between Algo.ws_size and pair.first
*/
std::multimap<Key, Algo, Key::Comp> d;
} algoCache;
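/* Cache persistence is controlled through environment variables:
 *   TE_HIPBLASLT_ALGO_LOAD=<csv file>  load previously tuned algos at startup
 *   TE_HIPBLASLT_ALGO_SAVE=<csv file>  save tuning results ("%i" expands to the PID)
 */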
static inline int getIntEnv(const char *name, int defval, int minval)
{
int val = defval;
const char* env = std::getenv(name);
if (env != nullptr && env[0] != '\0')
{
val = atoi(env);
if (val < minval)
{
val = minval;
}
}
return val;
}
} //namespace
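/* Algo selection knobs read in hipblaslt_gemm below:
 *   TE_HIPBLASLT_ALGO_SELECTION    index of the first heuristic result to consider (default 0)
 *   TE_HIPBLASLT_TUNING_RUN_COUNT  >0 enables on-device profiling with that many iterations per algo
 *   TE_HIPBLASLT_TUNING_ALGO_COUNT number of candidate algos to profile (default 16)
 *   TE_HIPBLASLT_LOG_TUNING        =1 prints tuning decisions
 */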
void hipblaslt_gemm(const Tensor *inputA,
const Tensor *inputB,
Tensor *outputD,
const Tensor *inputBias,
Tensor *outputPreGelu,
int m, int n, int k,
int lda, int ldb, int ldd,
hipblasOperation_t transa,
hipblasOperation_t transb,
bool grad,
void* workspace,
size_t workspaceSize,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
const Tensor *inputCounter,
hipStream_t stream
) {
void *A = inputA->data.dptr;
void *A_scale_inverse = inputA->scale_inv.dptr;
void *B = inputB->data.dptr;
void *B_scale_inverse = inputB->scale_inv.dptr;
void *D = outputD->data.dptr;
void *bias_ptr = inputBias->data.dptr;
const bool bias = bias_ptr != nullptr;
void *pre_gelu_out = outputPreGelu->data.dptr;
const bool gelu = pre_gelu_out != nullptr;
const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
is_fp8_dtype(inputB->data.dtype);
const hipblasltDatatype_t A_type = get_hipblaslt_dtype(inputA->data.dtype);
const hipblasltDatatype_t B_type = get_hipblaslt_dtype(inputB->data.dtype);
const hipblasltDatatype_t D_type = get_hipblaslt_dtype(outputD->data.dtype);
const hipblasltDatatype_t bias_type = get_hipblaslt_dtype(inputBias->data.dtype);
NVTE_CHECK(!is_fp8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
"FP8 input to GEMM requires inverse of scale!");
NVTE_CHECK(!is_fp8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr,
"FP8 input to GEMM requires inverse of scale!");
// check consistency of arguments:
// if fp8 is desired, context cannot be null
// fp8 + gelu fusion + fp8 aux is unavailable right now.
if (use_fp8) {
NVTE_CHECK(!gelu, "fp8 gemm + gelu fusion is unavailable right now!");
}
float one = 1.0;
float zero = 0.0;
float beta = (accumulate) ? one : zero;
int device_id;
NVTE_CHECK_CUDA(hipGetDevice(&device_id));
hipblasLtHandle_t handle = cached_handles.get(device_id);
if (handle == nullptr)
{
handle = cached_handles.obtain(device_id);
}
hipblasLtMatmulDesc_t operationDesc = nullptr;
hipblasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
hipblasLtMatmulPreference_t preference = nullptr;
hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT;
int64_t ld_gelumat = (int64_t) ldd;
// Always use fp32 accumulation here (no TF32 mode is selected on this path)
hipblasLtComputeType_t gemm_compute_type = HIPBLASLT_COMPUTE_F32;
// Create matrix descriptors. Not setting any extra attributes.
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Adesc, A_type,
transa == HIPBLAS_OP_N ? m : k,
transa == HIPBLAS_OP_N ? k : m,
lda));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Bdesc, B_type,
transb == HIPBLAS_OP_N ? k : n,
transb == HIPBLAS_OP_N ? n : k,
ldb));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIPBLASLT_R_32F));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSA,
&transa, sizeof(transa)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSB,
&transb, sizeof(transb)));
// set fp8 attributes -- input and output types should already be set to fp8 as appropriate
// Note: gelu fusion isn't available right now, and we don't need
// amax(D) either (next op is high precision).
if (use_fp8) {
// Split accumulator.
const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1;
/*
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_FAST_ACCUM, //TODO: We don't have fast accum mode yet
&fastAccuMode,
sizeof(fastAccuMode)));
*/
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&A_scale_inverse,
sizeof(A_scale_inverse)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&B_scale_inverse,
sizeof(B_scale_inverse)));
if (bias) {
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE,
&bias_type, sizeof(bias_type)));
}
}
if (bias && gelu) {
if (grad) {
epilogue = HIPBLASLT_EPILOGUE_DGELU_BGRAD;
} else {
epilogue = HIPBLASLT_EPILOGUE_GELU_AUX_BIAS;
}
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias_ptr, sizeof(bias_ptr)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
&pre_gelu_out, sizeof(pre_gelu_out)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
&ld_gelumat, sizeof(ld_gelumat)));
} else if (bias) {
if (grad) {
// grad output is always input B
epilogue = HIPBLASLT_EPILOGUE_BGRADB;
} else {
epilogue = HIPBLASLT_EPILOGUE_BIAS;
}
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias_ptr, sizeof(bias_ptr)));
} else if (gelu) {
if (grad) {
epilogue = HIPBLASLT_EPILOGUE_DGELU;
} else {
epilogue = HIPBLASLT_EPILOGUE_GELU_AUX;
}
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
&pre_gelu_out, sizeof(pre_gelu_out)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
&ld_gelumat, sizeof(ld_gelumat)));
}
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue, sizeof(epilogue)));
GemmAlgoCache::Key gemm_cfg(algoCache.device_cap(device_id), A_type, B_type, D_type,
use_fp8 ? bias_type : (hipblasltDatatype_t)-1,
m, n, k, lda, ldb, ldd, transa, transb, epilogue );
GemmAlgoCache::Algo cached_algo;
if (!algoCache.find(gemm_cfg, workspaceSize, cached_algo) || !cached_algo.algo.has_value())
{
int firstAlgo = getIntEnv("TE_HIPBLASLT_ALGO_SELECTION", 0, 0);
int tuneLoopCount = getIntEnv("TE_HIPBLASLT_TUNING_RUN_COUNT", 0, 0);
int algoTuneCount = 1;
std::vector<hipblasLtMatmulHeuristicResult_t> algoArr;
bool logTuning = getIntEnv("TE_HIPBLASLT_LOG_TUNING", 0, 0) != 0;
if (tuneLoopCount)
{
/* HIPBLASLT may return hundreds of algos for some configs
* Limit amount by default. User may override with env
*/
static const int defaultAlgoCount = 16;
algoTuneCount = getIntEnv("TE_HIPBLASLT_TUNING_ALGO_COUNT", defaultAlgoCount, 1);
}
algoTuneCount += firstAlgo;
int algoTotalCount = cached_algo.hasId() ? std::max(algoTuneCount, (cached_algo.index + 1)) : algoTuneCount;
algoArr.resize(algoTotalCount);
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceCreate(&preference));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceSetAttribute(
preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspaceSize, sizeof(workspaceSize)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Ddesc,
Ddesc, preference, algoTotalCount, algoArr.data(),
&algoTotalCount));
algoArr.resize(algoTotalCount);
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceDestroy(preference));
//If cached algo exists in persistent storage we just need to find matching hipblasLtMatmulAlgo_t
if (cached_algo.hasId())
{
int idx = (cached_algo.index < algoTotalCount) ? cached_algo.index : 0;
for (int i=0; i<algoTotalCount; i++)
{
const auto &algo = algoArr[idx];
if (algo.state == HIPBLAS_STATUS_SUCCESS)
{
if (cached_algo.algoId == cached_algo.getAlgoId(algo.algo))
{
cached_algo.algo = algo.algo;
if (algo.workspaceSize != cached_algo.ws_size_min || idx != cached_algo.index)
{
cached_algo.ws_size_min = algo.workspaceSize;
cached_algo.index = idx;
algoCache.store(gemm_cfg, cached_algo);
}
break;
}
}
idx = (idx + 1) % algoTotalCount;
}
if (logTuning && !cached_algo.algo.has_value())
{
std::cout << "[WARNING] Cannot find cached algoId " << cached_algo.algoId << " in hipBLASLt results" << std::endl;
}
}
    // No suitable entry in the autotune cache, or no matching algo among the hipBLASLt results
if (!cached_algo.algo.has_value())
{
int bestAlgo = -1;
algoTuneCount = std::min(algoTuneCount, algoTotalCount);
if (tuneLoopCount > 0)
{
if (logTuning)
std::cout << "[INFO] Perform hipBLASLt algo selection on GPU" << device_id
<< " in range [" << firstAlgo << "-" << (algoTuneCount - 1) << "] with "
<< tuneLoopCount << " loops " << std::endl;
hipStream_t profilingStream;
NVTE_CHECK_CUDA(hipStreamCreateWithFlags(&profilingStream, hipStreamNonBlocking));
using tuning_clock = std::chrono::steady_clock;
        tuning_clock::now(); // the first call takes a little longer, so do it once outside the loop
tuning_clock::duration bestTime = tuning_clock::duration::max();
for (int algo=firstAlgo; algo<algoTuneCount; algo++)
{
if (algoArr[algo].state != HIPBLAS_STATUS_SUCCESS)
{
continue;
}
// Warm-up call
NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle,
operationDesc,
static_cast<const void*>(&one), /* alpha */
A, /* A */
Adesc,
B, /* B */
Bdesc,
static_cast<const void*>(&beta), /* beta */
D, /* C */
Ddesc,
D, /* D */
Ddesc,
&algoArr[algo].algo, /* algo */
workspace, /* workspace */
workspaceSize,
profilingStream)); /* stream */
NVTE_CHECK_CUDA(hipStreamSynchronize(profilingStream));
//Profiling loop
tuning_clock::time_point startTime = tuning_clock::now();
for (int loop=0; loop<tuneLoopCount; loop++)
{
NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle,
operationDesc,
static_cast<const void*>(&one), /* alpha */
A, /* A */
Adesc,
B, /* B */
Bdesc,
static_cast<const void*>(&beta), /* beta */
D, /* C */
Ddesc,
D, /* D */
Ddesc,
&algoArr[algo].algo, /* algo */
workspace, /* workspace */
workspaceSize,
profilingStream)); /* stream */
}
NVTE_CHECK_CUDA(hipStreamSynchronize(profilingStream));
tuning_clock::duration algoTime = tuning_clock::now() - startTime;
if (algoTime < bestTime)
{
bestAlgo = algo;
bestTime = algoTime;
}
}
NVTE_CHECK_CUDA(hipStreamDestroy(profilingStream));
if (bestAlgo >= 0)
{
        if (logTuning)
          std::cout << "[INFO] Selected hipBLASLt algo " << bestAlgo << " with time "
                    << std::chrono::duration_cast<std::chrono::nanoseconds>(bestTime).count() / tuneLoopCount
                    << " ns" << std::endl;
}
}
else if (firstAlgo < algoTuneCount)
{
bestAlgo = firstAlgo;
}
if (bestAlgo < 0) {
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Ddesc));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Bdesc));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Adesc));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
throw std::runtime_error("Unable to find any suitable algorithms");
}
cached_algo.algo = algoArr[bestAlgo].algo;
cached_algo.index = bestAlgo;
cached_algo.algoId = cached_algo.getAlgoId(algoArr[bestAlgo].algo);
cached_algo.ws_size_min = algoArr[bestAlgo].workspaceSize;
cached_algo.ws_size_max = workspaceSize;
if (logTuning)
std::cout << "[INFO] Use hipBLASLt algo [" << bestAlgo << "] " << cached_algo.algoId << std::endl;
algoCache.store(gemm_cfg, cached_algo);
}
}
// D = alpha * (A * B) + beta * C
NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle,
operationDesc,
static_cast<const void*>(&one), /* alpha */
A, /* A */
Adesc,
B, /* B */
Bdesc,
static_cast<const void*>(&beta), /* beta */
D, /* C */
Ddesc,
D, /* D */
Ddesc,
&cached_algo.algo.value(), /* algo */
workspace, /* workspace */
workspaceSize,
stream)); /* stream */
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Ddesc));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Bdesc));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Adesc));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
}
#endif //USE_HIPBLASLT
#ifdef USE_ROCBLAS // Use rocblas + kernel, no fusion
void rocblas_gemm(const Tensor *inputA,
const Tensor *inputB,
Tensor *outputD,
const Tensor *inputBias,
Tensor *outputPreGelu,
int m, int n, int k,
int lda, int ldb, int ldd,
rocblas_operation transa,
rocblas_operation transb,
bool grad,
void* workspace,
size_t workspaceSize,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
const Tensor *inputCounter,
hipStream_t stream
) {
void *A = inputA->data.dptr;
void *A_scale_inverse = inputA->scale_inv.dptr;
void *B = inputB->data.dptr;
void *B_scale_inverse = inputB->scale_inv.dptr;
void *C = outputD->data.dptr;
void *D = outputD->data.dptr;
void *D_scale = outputD->scale.dptr;
void *D_amax = outputD->amax.dptr;
void *bias_ptr = inputBias->data.dptr;
const bool bias = bias_ptr != nullptr;
void *pre_gelu_out = outputPreGelu->data.dptr;
const bool gelu = pre_gelu_out != nullptr;
const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
is_fp8_dtype(inputB->data.dtype);
const rocblas_datatype A_type = get_rocblas_dtype(inputA->data.dtype);
const rocblas_datatype B_type = get_rocblas_dtype(inputB->data.dtype);
const rocblas_datatype D_type = get_rocblas_dtype(outputD->data.dtype);
const rocblas_datatype bias_type = get_rocblas_dtype(inputBias->data.dtype);
const rocblas_datatype gelu_type = get_rocblas_dtype(outputPreGelu->data.dtype);
  // Check consistency of arguments:
  // fp8 GEMM with gelu fusion and an fp8 aux output is unavailable right now.
if (use_fp8 && gelu) {
NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype),
"fp8 Aux output for gemm + gelu fusion not supported!");
}
if (is_fp8_dtype(outputD->data.dtype)) {
NVTE_CHECK(!accumulate,
"Accumulation mode not supported with FP8 GEMM output!");
}
// fp8 + grad unavailable in upstream
NVTE_CHECK(!(use_fp8 && grad), "fp8 + grad not supported!");
float one = 1.0;
float zero = 0.0;
float beta = (accumulate) ? one : zero;
float alpha = 1.0;
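  // With FP8 inputs, A and B hold scaled values, so folding both scale
  // inverses into alpha dequantizes the product and D comes out in real
  // units. Note that the blocking hipMemcpy calls below synchronize the host
  // with the device.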
if (use_fp8) {
float A_scale_inv, B_scale_inv;
(void)hipMemcpy(&A_scale_inv, A_scale_inverse, sizeof(float), hipMemcpyDeviceToHost);
(void)hipMemcpy(&B_scale_inv, B_scale_inverse, sizeof(float), hipMemcpyDeviceToHost);
alpha = A_scale_inv * B_scale_inv;
}
rocblas_handle handle;
NVTE_CHECK_ROCBLAS(rocblas_create_handle(&handle));
NVTE_CHECK_ROCBLAS(rocblas_set_stream(handle, stream));
// extract the stream order alloc env
bool stream_order_alloc = false;
  if (const char* env_p = std::getenv("ROCBLAS_STREAM_ORDER_ALLOC")) {
    if (std::string(env_p) == "1")
      stream_order_alloc = true;
  }
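  /* Example (hypothetical shell invocation): opt in to stream-ordered
   * allocation of the temporary buffers used below (requires ROCm >= 5.3):
   *   ROCBLAS_STREAM_ORDER_ALLOC=1 <application>
   */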
int64_t ld_gelumat = (int64_t) ldd;
NVTE_CHECK((A_type==rocblas_datatype_f16_r && B_type==rocblas_datatype_f16_r && D_type==rocblas_datatype_f16_r) ||
(A_type==rocblas_datatype_bf16_r && B_type==rocblas_datatype_bf16_r && D_type==rocblas_datatype_bf16_r) ||
(A_type==rocblas_datatype_f32_r && B_type==rocblas_datatype_f32_r && D_type==rocblas_datatype_f32_r) ||
(A_type==rocblas_datatype_f8_r && B_type==rocblas_datatype_f8_r && D_type==rocblas_datatype_f32_r) ||
(A_type==rocblas_datatype_f8_r && B_type==rocblas_datatype_f8_r && D_type==rocblas_datatype_f16_r) ||
(A_type==rocblas_datatype_f8_r && B_type==rocblas_datatype_f8_r && D_type==rocblas_datatype_bf16_r) ||
(A_type==rocblas_datatype_f8_r && B_type==rocblas_datatype_f8_r && D_type==rocblas_datatype_f8_r) ||
(A_type==rocblas_datatype_f8_r && B_type==rocblas_datatype_f8_r && D_type==rocblas_datatype_bf8_r) ||
(A_type==rocblas_datatype_f8_r && B_type==rocblas_datatype_bf8_r && D_type==rocblas_datatype_f32_r) ||
(A_type==rocblas_datatype_f8_r && B_type==rocblas_datatype_bf8_r && D_type==rocblas_datatype_f16_r) ||
(A_type==rocblas_datatype_f8_r && B_type==rocblas_datatype_bf8_r && D_type==rocblas_datatype_bf16_r) ||
(A_type==rocblas_datatype_f8_r && B_type==rocblas_datatype_bf8_r && D_type==rocblas_datatype_f8_r) ||
(A_type==rocblas_datatype_f8_r && B_type==rocblas_datatype_bf8_r && D_type==rocblas_datatype_bf8_r) ||
(A_type==rocblas_datatype_bf8_r && B_type==rocblas_datatype_f8_r && D_type==rocblas_datatype_f32_r) ||
(A_type==rocblas_datatype_bf8_r && B_type==rocblas_datatype_f8_r && D_type==rocblas_datatype_f16_r) ||
(A_type==rocblas_datatype_bf8_r && B_type==rocblas_datatype_f8_r && D_type==rocblas_datatype_bf16_r)||
(A_type==rocblas_datatype_bf8_r && B_type==rocblas_datatype_f8_r && D_type==rocblas_datatype_f8_r)||
(A_type==rocblas_datatype_bf8_r && B_type==rocblas_datatype_f8_r && D_type==rocblas_datatype_bf8_r),
"Only the following combinations of data types are enabled now!\n\
1. input: fp32, output: fp32.\n\
2. input: fp16, output: fp16.\n\
3. input: bf16, output: bf16.\n\
4. input: fp8/bf8, output: fp8/bf8, fp16/bf16, fp32");
  // If D is not fp32, we need a temp buffer for the GEMM result before applying epilogues; otherwise we can apply epilogues in-place.
  // - with bias or gelu, allocate an fp32 D_temp if the output is not fp32
  // - with fp8/bf8 input (use_fp8) and bf16 output, an fp32 D_temp is needed, as rocblas does not support this case (fp8/bf8 input with fp16/fp32 output is supported)
  // - with use_fp8 and fp8/bf8 output, an fp32 D_temp is needed to support the amax and scale operations
void* D_temp;
if (((bias || gelu) && (D_type==rocblas_datatype_f16_r ||D_type==rocblas_datatype_bf16_r))||
(use_fp8 && (D_type==rocblas_datatype_bf16_r||D_type==rocblas_datatype_f8_r||D_type==rocblas_datatype_bf8_r))) {
if(! stream_order_alloc){
NVTE_CHECK_CUDA( hipMalloc(&D_temp, sizeof(float)*m*n) );
}else{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA( hipMallocAsync(&D_temp, sizeof(float)*m*n, stream) );
#else
NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
#endif
}
}else {
D_temp = D;
}
// When Ti=To=fp16 and there is no bias or gelu, D_temp points to D and we would like it to be fp16
rocblas_datatype D_temp_type = rocblas_datatype_f32_r;
if (!(bias || gelu) && (A_type==rocblas_datatype_f16_r && B_type==rocblas_datatype_f16_r && D_type==rocblas_datatype_f16_r)) {
D_temp_type = rocblas_datatype_f16_r;
}
// When Ti=To=bf16 and there is no bias or gelu, D_temp points to D and we would like it to be bf16
if (!(bias || gelu) && (A_type==rocblas_datatype_bf16_r && B_type==rocblas_datatype_bf16_r && D_type==rocblas_datatype_bf16_r)) {
D_temp_type = rocblas_datatype_bf16_r;
}
  // When Ti is fp8 or bf8, To=fp16, and there is no bias or gelu, D_temp points to D and we would like it to be fp16, as rocblas supports this case.
if ((!(bias||gelu))&& (use_fp8 && D_type==rocblas_datatype_f16_r)) {
D_temp_type = rocblas_datatype_f16_r;
}
if(accumulate && (D_temp!=D || D_temp_type!=D_type)){
DType output_dtype = get_transformer_engine_dtype(D_type);
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output_dtype, OType,
      // a separately allocated D_temp is always fp32
detail::identity_kernelLauncher<OType, float>(reinterpret_cast<const OType*>(D),
reinterpret_cast<float*>(D_temp),
m*n,
stream);
);
}
// D = alpha * (A * B) + beta * C
if (use_fp8) {
rocblas_computetype computeType = rocblas_compute_type_f32;
NVTE_CHECK_ROCBLAS(rocblas_gemm_ex3(handle, transa, transb, m, n, k, &alpha,
A, A_type, lda,
B, B_type, ldb,
&beta, D_temp, D_temp_type, ldd, D_temp, D_temp_type, ldd,
computeType, rocblas_gemm_algo::rocblas_gemm_algo_standard,0,0));
}else {
rocblas_datatype computeType = rocblas_datatype_f32_r;
uint32_t flags = rocblas_gemm_flags_none;
if((A_type==rocblas_datatype_f16_r && B_type==rocblas_datatype_f16_r) && grad){
flags = rocblas_gemm_flags_fp16_alt_impl;
}
NVTE_CHECK_ROCBLAS(rocblas_gemm_ex(handle, transa, transb, m, n, k, &alpha,
A, A_type, lda,
B, B_type, ldb,
&beta, D_temp, D_temp_type, ldd, D_temp, D_temp_type, ldd,
computeType, rocblas_gemm_algo::rocblas_gemm_algo_standard,0,flags));
}
NVTE_CHECK_ROCBLAS(rocblas_destroy_handle(handle));
int batch_size, input_dim, output_dim;
if (bias && gelu) {
if (grad) {
// epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;
// Apply GELU gradient to D_temp and store in D
// Apply bias gradient to D (D is already the result of GELU gradient) and store in bias_ptr;
// This case is NN
      // D_temp is of shape (m, n) in column major and thus of shape (n, m) in row major
// The bias vector length is m. So it will be reduced along axis 0 in row major
      // TODO: The cublasLt doc is not very clear w.r.t. the bias gradient here.
      // It does not explicitly say that it goes through the GELU gradient first; we will
      // need to confirm this in the future. As of now, this implementation of the bias
      // gradient takes the GELU gradient result in lower precision (D). It might be better
      // to take the GELU gradient result in fp32, but as that requires kernel changes we
      // should only do it once we confirm that this is the right form of the epilogue.
// This is for linear1 -> gelu -> linear2
// compute dX = dY * W for linear2
// gemm_ex(A=W, B=dY)
batch_size = n;
input_dim = m; // input dimension of the second linear layer is the output dimension of the first linear layer
output_dim = k;
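      // Worked example (hypothetical sizes): linear1: 1024 -> 4096, linear2:
      // 4096 -> 1024, batch 8. Computing dH = dY * W2 via gemm_ex(A=W2, B=dY)
      // gives m=4096 (hid), n=8 (batch), k=1024 (out); D_temp is (4096, 8)
      // column-major, i.e. (8, 4096) row-major, and the linear1 bias gradient
      // has length m=4096, reduced over the batch axis.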
DType output_dtype = get_transformer_engine_dtype(D_type);
DType gelu_dtype = get_transformer_engine_dtype(gelu_type);
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output_dtype, OType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(gelu_dtype, GType,
detail::gelu_backward_kernelLauncher<OType, GType>(reinterpret_cast<const float*>(D_temp),
reinterpret_cast<OType*>(D),
reinterpret_cast<const GType*>(pre_gelu_out),
batch_size,
input_dim,
stream);
);
);
void* bias_tmp;
if (bias_type != rocblas_datatype_f32_r) {
if(! stream_order_alloc){
NVTE_CHECK_CUDA( hipMalloc(&bias_tmp, sizeof(float)*input_dim) ); // The bias gradient is for the first linear layer
}else{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA( hipMallocAsync(&bias_tmp, sizeof(float)*input_dim, stream) );
#else
NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
#endif
}
}else {
bias_tmp = bias_ptr;
}
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output_dtype, OType,
detail::bias_gradient_kernelLauncher<OType>(reinterpret_cast<const OType*>(D),
reinterpret_cast<float*>(bias_tmp),
batch_size,
input_dim,
stream_order_alloc,
stream);
);
if (bias_type != rocblas_datatype_f32_r) {
DType bias_dtype = get_transformer_engine_dtype(bias_type);
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(bias_dtype, BType,
detail::identity_kernelLauncher<float, BType>(reinterpret_cast<const float*>(bias_tmp),
reinterpret_cast<BType*>(bias_ptr),
input_dim,
stream);
);
if(! stream_order_alloc){
NVTE_CHECK_CUDA( hipFree(bias_tmp) );
}else{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA( hipFreeAsync(bias_tmp, stream) );
#else
NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
#endif
}
}
} else {
// epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
// Add bias_ptr to D_temp and store in pre_gelu_out, and apply GELU to the pre_gelu_output and then store in D
      // D_temp is of shape (m, n) in column major and thus of shape (n, m) in row major
// gemm_ex(A=W, B=X, transA=T)
batch_size = n;
input_dim = k;
output_dim = m;
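      // Worked example (hypothetical sizes): a 1024 -> 4096 layer with batch 8:
      // gemm_ex(A=W, B=X, transA=T) gives m=4096 (out), n=8 (batch), k=1024 (in);
      // the length-4096 bias is broadcast across the batch rows of the
      // (8, 4096) row-major D_temp before GELU is applied.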
DType output_dtype = get_transformer_engine_dtype(D_type);
DType bias_dtype = get_transformer_engine_dtype(bias_type);
DType gelu_dtype = get_transformer_engine_dtype(gelu_type);
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output_dtype, OType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(gelu_dtype, GType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(bias_dtype, BType,
detail::add_bias_gelu_kernelLauncher<OType, GType, BType>(reinterpret_cast<const float*>(D_temp),
reinterpret_cast<OType*>(D),
reinterpret_cast<GType*>(pre_gelu_out),
reinterpret_cast<const BType*>(bias_ptr),
reinterpret_cast<float*>(D_amax),
reinterpret_cast<const float*>(D_scale),
batch_size,
output_dim,
stream);
);
);
);
}
}else if (bias) {
if (grad) {
// grad output is always input B
// epilogue = CUBLASLT_EPILOGUE_BGRADB;
// Apply bias gradient to matrix B and store in bias_ptr, reduce along the k dimension, output bias length is n
      // As B is transposed, it is of shape (n, k) in column major, i.e. (k, n) in row major.
      // The bias gradient vector length is n, so it is reduced along axis 0 in row major.
      // The backward pass calculates the bias gradient along with dW = dY^T * X
// gemm_ex(A=X, B = dY, transB=T)
batch_size = k;
input_dim = m;
output_dim = n;
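      // Worked example (hypothetical sizes): for the same 1024 -> 4096 layer
      // with batch 8, gemm_ex(A=X, B=dY, transB=T) gives m=1024 (in),
      // n=4096 (out), k=8 (batch); dbias has length n=4096 and is reduced
      // over the batch (k) dimension of dY.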
void * bias_tmp;
if (bias_type != rocblas_datatype_f32_r) {
if(! stream_order_alloc){
NVTE_CHECK_CUDA( hipMalloc(&bias_tmp, sizeof(float)*output_dim) );
}else{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA( hipMallocAsync(&bias_tmp, sizeof(float)*output_dim, stream) );
#else
NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
#endif
}
}else {
bias_tmp = bias_ptr;
}
DType input_dtype = get_transformer_engine_dtype(B_type);
DType output_dtype = get_transformer_engine_dtype(D_type);
DType bias_dtype = get_transformer_engine_dtype(bias_type);
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(input_dtype, IType,
detail::bias_gradient_kernelLauncher<IType>(reinterpret_cast<const IType*>(B),
reinterpret_cast<float*>(bias_tmp),
batch_size,
output_dim,
stream_order_alloc,
stream);
);
if (bias_type != rocblas_datatype_f32_r) {
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(bias_dtype, BType,
detail::identity_kernelLauncher<float, BType>(reinterpret_cast<const float*>(bias_tmp),
reinterpret_cast<BType*>(bias_ptr),
output_dim,
stream);
);
if(! stream_order_alloc){
NVTE_CHECK_CUDA( hipFree(bias_tmp) );
}else{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA( hipFreeAsync(bias_tmp, stream) );
#else
NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
#endif
}
}
if (D_type == rocblas_datatype_f16_r || D_type == rocblas_datatype_bf16_r) {
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output_dtype, OType,
detail::identity_kernelLauncher<float, OType>(reinterpret_cast<const float*>(D_temp),
reinterpret_cast<OType*>(D),
input_dim*output_dim,
stream);
);
}
} else {
// epilogue = CUBLASLT_EPILOGUE_BIAS;
// Broadcast bias and add it to D_temp and store in D. The bias vector length is m
      // D_temp is of shape (m, n) in column major and thus of shape (n, m) in row major
// gemm_ex(A=W, B=X, transA=T)
batch_size = n;
input_dim = k;
output_dim = m;
DType output_dtype = get_transformer_engine_dtype(D_type);
DType bias_dtype = get_transformer_engine_dtype(bias_type);
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output_dtype, OType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(bias_dtype, BType,
detail::add_bias_kernelLauncher<OType, BType>(reinterpret_cast<const float*>(D_temp),
reinterpret_cast<OType*>(D),
reinterpret_cast<const BType*>(bias_ptr),
reinterpret_cast<float*>(D_amax),
reinterpret_cast<const float*>(D_scale),
batch_size,
output_dim,
stream);
);
);
}
}else if (gelu) {
if (grad) {
// epilogue = CUBLASLT_EPILOGUE_DGELU;
// Take input from pre_gelu_out and apply GELU gradients to D_temp and store result in D
      // D_temp is of shape (m, n) in column major and thus of shape (n, m) in row major
// gemm_ex(A=W, B=dY)
batch_size = n;
input_dim = m;
output_dim = k;
DType output_dtype = get_transformer_engine_dtype(D_type);
DType gelu_dtype = get_transformer_engine_dtype(gelu_type);
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output_dtype, OType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(gelu_dtype, GType,
detail::gelu_backward_kernelLauncher<OType, GType>(reinterpret_cast<const float*>(D_temp),
reinterpret_cast<OType*>(D),
reinterpret_cast<const GType*>(pre_gelu_out),
batch_size,
input_dim,
stream);
);
);
} else {
// epilogue = CUBLASLT_EPILOGUE_GELU_AUX;
// Store (quantized) D_temp in pre_gelu_out, and apply GELU to D_temp then store in D
      // D_temp is of shape (m, n) in column major and thus of shape (n, m) in row major
// gemm_ex(A=W, B=X, transA=T)
batch_size = n;
input_dim = k;
output_dim = m;
DType gelu_dtype = get_transformer_engine_dtype(gelu_type);
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(gelu_dtype, GType,
detail::identity_kernelLauncher<float, GType>(reinterpret_cast<const float*>(D_temp),
reinterpret_cast<GType*>(pre_gelu_out),
batch_size*output_dim,
stream);
);
DType output_dtype = get_transformer_engine_dtype(D_type);
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output_dtype, OType,
detail::gelu_forward_kernelLauncher<OType>(reinterpret_cast<const float*>(D_temp),
reinterpret_cast<OType*>(D),
reinterpret_cast<float*>(D_amax),
reinterpret_cast<const float*>(D_scale),
batch_size,
output_dim,
stream);
);
}
} else { // No epilogue - !(bias || gelu)
if (use_fp8 && (D_type==rocblas_datatype_bf16_r || D_type == rocblas_datatype_f8_r || D_type == rocblas_datatype_bf8_r)) {
DType output_dtype = get_transformer_engine_dtype(D_type);
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output_dtype, OType,
detail::identity_output_kernelLauncher<OType>(reinterpret_cast<const float*>(D_temp),
reinterpret_cast<OType*>(D),
reinterpret_cast<float*>(D_amax),
reinterpret_cast<const float*>(D_scale),
m*n,
stream);
);
}
}
if (((bias || gelu) && (D_type==rocblas_datatype_f16_r ||D_type==rocblas_datatype_bf16_r))||
(use_fp8 && (D_type==rocblas_datatype_bf16_r || D_type==rocblas_datatype_f8_r || D_type==rocblas_datatype_bf8_r))) {
if(! stream_order_alloc){
NVTE_CHECK_CUDA( hipFree(D_temp) );
}else{
#if HIP_VERSION >= 50300000
NVTE_CHECK_CUDA( hipFreeAsync(D_temp, stream) );
#else
NVTE_ERROR("Stream order allocation is supported on ROCm 5.3 and above.");
#endif
}
}
}
#endif //USE_ROCBLAS
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda,
int ldb, int ldd, bool transa, bool transb, bool grad,
void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
int math_sm_count, int m_split, int n_split, bool gemm_producer,
const Tensor *inputCounter, hipStream_t stream)
{
  /* If no backend is specified via env variable, use hipBLASLt unless it is disabled.
     If the hipBLASLt backend is enabled and requested, use it regardless of the rocBLAS status.
     Otherwise use rocBLAS.
  */
bool use_hipblaslt = std::getenv("NVTE_USE_HIPBLASLT") != nullptr;
bool use_rocblas = std::getenv("NVTE_USE_ROCBLAS") != nullptr;
#if !defined(USE_HIPBLASLT) && !defined(USE_ROCBLAS)
#error GEMM backend is not specified
#elif !defined(USE_HIPBLASLT)
if (use_hipblaslt)
{
use_hipblaslt = false;
std::cout << "[NOTICE] hipBLASLt is not enabled, NVTE_USE_HIPBLASLT env is ignored\n";
}
#elif !defined(USE_ROCBLAS)
if (use_rocblas)
{
use_rocblas = false;
std::cout << "[NOTICE] rocBLAS is not enabled, NVTE_USE_ROCBLAS env is ignored\n";
}
#else
if (use_hipblaslt && use_rocblas)
{
use_rocblas = false;
std::cout << "[NOTICE] Two GEMM backend are enabled, hipBLASLt will be used\n";
}
#endif
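  /* Example (hypothetical shell invocation): force the rocBLAS path in a
   * build with both backends compiled in:
   *   NVTE_USE_ROCBLAS=1 <application>
   * If both NVTE_USE_HIPBLASLT and NVTE_USE_ROCBLAS are set, hipBLASLt wins.
   */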
#ifdef USE_HIPBLASLT
if (use_hipblaslt || !use_rocblas)
{
hipblaslt_gemm(inputA, inputB, outputD, inputBias, outputPreGelu,
m, n, k, lda, ldb, ldd,
(transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
(transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
grad,
workspace, workspaceSize, accumulate, use_split_accumulator,
math_sm_count, m_split, n_split, gemm_producer,
inputCounter, stream);
return;
}
#endif
#ifdef USE_ROCBLAS
{
rocblas_gemm(inputA, inputB, outputD, inputBias, outputPreGelu,
m, n, k, lda, ldb, ldd,
(transa) ? rocblas_operation_transpose : rocblas_operation_none,
(transb) ? rocblas_operation_transpose : rocblas_operation_none,
grad,
workspace, workspaceSize, accumulate, use_split_accumulator,
math_sm_count, m_split, n_split, gemm_producer,
inputCounter, stream);
}
#endif
}
} //namespace transformer_engine
\ No newline at end of file
......@@ -109,6 +109,21 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor* A, const NVTETensor* B, NVT
NVTETensor* workspace, bool accumulate,
bool use_split_accumulator, int math_sm_count,
cudaStream_t stream);
#ifdef __HIP_PLATFORM_AMD__
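/*! \brief Batched GEMM entry points added for the ROCm/DCU port.
 *  Semantics are assumed to mirror the multi-stream and single-GEMM
 *  functions above, applied to num_gemms (resp. batch_count) batched
 *  problems.
 */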
void nvte_multi_stream_cublas_batchgemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D,
const NVTETensor* bias, NVTETensor* pre_gelu_out,
const int num_gemms, bool transa, bool transb, bool grad,
NVTETensor* workspace, bool accumulate,
bool use_split_accumulator, int math_sm_count,
cudaStream_t stream);
void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, int batch_count, cudaStream_t stream);
#endif
#ifdef __cplusplus
} // extern "C"
#endif
......@@ -116,8 +131,14 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor* A, const NVTETensor* B, NVT
/*! \namespace transformer_engine
*/
namespace transformer_engine {
#ifdef __HIP_PLATFORM_AMD__
// On DCU, two streams perform better than four
constexpr int num_streams = 2;
// Dedicated stream count for batched GEMM
constexpr int num_batchgemm_streams = 1;
#else
constexpr int num_streams = 4;
#endif
} // namespace transformer_engine
......
......@@ -39,10 +39,12 @@ Compute always in FP32
namespace transformer_engine {
namespace normalization {
#ifndef __HIP_PLATFORM_AMD__
cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) {
return training ? cudnn_frontend::NormFwdPhase_t::TRAINING
: cudnn_frontend::NormFwdPhase_t::INFERENCE;
}
#endif
TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType,
NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype,
......@@ -195,6 +197,10 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
_training(training),
_norm_stage(NormStage),
_norm_type(NormType) {
#ifdef USE_ROCM
  NVTE_ERROR("cuDNN normalization backend is not supported on ROCm yet.");
#else
static_assert(CUDNN_FRONTEND_VERSION >= 10601,
"CUDNN_FRONTEND_VERSION should be at least 1.6.1!");
......@@ -378,9 +384,14 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
}
// Build the graph
this->_build();
#endif
}
void CudnnNormalizationPlan::_build() {
#ifdef USE_ROCM
  NVTE_ERROR("cuDNN normalization backend is not supported on ROCm yet.");
#else
NVTE_CHECK(_graph.validate().is_good());
NVTE_CHECK(_graph.build_operation_graph(_handle).is_good());
NVTE_CHECK(_graph
......@@ -390,15 +401,25 @@ void CudnnNormalizationPlan::_build() {
NVTE_CHECK(_graph.check_support(_handle).is_good());
NVTE_CHECK(
_graph.build_plans(_handle, cudnn_frontend::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good());
#endif
}
std::vector<size_t> CudnnNormalizationPlan::getWorkspaceShape() const {
#ifdef USE_ROCM
  NVTE_ERROR("cuDNN normalization backend is not supported on ROCm yet.");
  return {};  // unreachable; keeps missing-return warnings quiet
#else
return {static_cast<size_t>(_graph.get_workspace_size())};
#endif
}
void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr,
void* mean_dptr, void* eps_dptr, void* rsigma_dptr,
void* workspace_dptr, cudaStream_t stream) {
#ifdef USE_ROCM
  NVTE_ERROR("cuDNN normalization backend is not supported on ROCm yet.");
#else
// Binding data pointers to graph tensors
_variant_pack = {{_x, x_dptr}, {_eps, eps_dptr}};
......@@ -433,12 +454,17 @@ void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr,
// Execute the computation
NVTE_CHECK_CUDNN(cudnnSetStream(_handle, stream));
NVTE_CHECK(_graph.execute(_handle, _variant_pack, workspace_dptr).is_good());
#endif
}
void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_dptr,
void* rsigma_dptr, void* dx_dptr, void* dz_dptr,
void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr,
cudaStream_t stream) {
#ifdef USE_ROCM
  NVTE_ERROR("cuDNN normalization backend is not supported on ROCm yet.");
#else
// Binding data pointers to graph tensors
_variant_pack = {
{_x, x_dptr}, {_rsigma, rsigma_dptr}, {_dz, dz_dptr}, {_dgamma, dgamma_dptr}, {_dx, dx_dptr}};
......@@ -455,6 +481,7 @@ void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_
// Execute the computation
NVTE_CHECK_CUDNN(cudnnSetStream(_handle, stream));
NVTE_CHECK(_graph.execute(_handle, _variant_pack, workspace_dptr).is_good());
#endif
}
NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan(
......@@ -491,13 +518,21 @@ NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan(
}
bool& _cudnn_norm_fwd_flag() {
#ifdef USE_ROCM
  // A bool& cannot bind to a literal; keep a static flag pinned to false on ROCm.
  static bool flag = false;
  return flag;
#else
static bool flag = transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN");
return flag;
#endif
}
bool& _cudnn_norm_bwd_flag() {
#ifdef USE_ROCM
  // A bool& cannot bind to a literal; keep a static flag pinned to false on ROCm.
  static bool flag = false;
  return flag;
#else
static bool flag = transformer_engine::getenv<bool>("NVTE_NORM_BWD_USE_CUDNN");
return flag;
#endif
}
bool use_cudnn_norm_fwd() { return _cudnn_norm_fwd_flag(); }
......@@ -508,10 +543,18 @@ bool use_cudnn_norm_bwd() { return _cudnn_norm_bwd_flag(); }
void nvte_enable_cudnn_norm_fwd(bool enable) {
NVTE_API_CALL(nvte_enable_cudnn_norm_fwd);
#ifdef USE_ROCM
  // The cuDNN path is unavailable on ROCm; force the forward flag off regardless of `enable`.
  transformer_engine::normalization::_cudnn_norm_fwd_flag() = false;
#else
transformer_engine::normalization::_cudnn_norm_fwd_flag() = enable;
#endif
}
void nvte_enable_cudnn_norm_bwd(bool enable) {
NVTE_API_CALL(nvte_enable_cudnn_norm_bwd);
#ifdef USE_ROCM
transformer_engine::normalization::_cudnn_norm_bwd_flag() = false;
#else
transformer_engine::normalization::_cudnn_norm_bwd_flag() = enable;
#endif
}