Commit c520cba3 authored by yuguo's avatar yuguo
Browse files

[DCU] Preliminary adaptation

parent 5b6ef054
......@@ -6,6 +6,7 @@ from typing import Iterable, Optional
import pytest
import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
......@@ -258,7 +259,7 @@ class TestFP8Recipe:
# Compute scale
max_val = {
"forward": 448.0,
"forward": 448.0 if not IS_HIP_EXTENSION else 240.0,
"backward": 57344.0,
}[stage]
ref_scale = (max_val / ref_amax) / (2**margin)
......
......@@ -9,6 +9,7 @@ from contextlib import nullcontext
import torch
import pytest
import os
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.fp8 import (
fp8_autocast,
......@@ -51,6 +52,12 @@ def create_meta(scale_factor: float, size: int = 1):
meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor
return meta
if IS_HIP_EXTENSION:
    from functools import cache

    @cache
    def use_hipblaslt() -> bool:
        """Return True when the hipBLASLt GEMM backend is in effect.

        hipBLASLt is the default: it is selected unless NVTE_USE_ROCBLAS is
        set in the environment while NVTE_USE_HIPBLASLT is not.  The result
        is cached, so the environment is only consulted once per process.
        """
        return ("NVTE_USE_HIPBLASLT" in os.environ
                or "NVTE_USE_ROCBLAS" not in os.environ)
def custom_amax_to_scale(
amax: torch.Tensor,
......@@ -982,6 +989,10 @@ def test_sanity_gradient_accumulation_fusion(
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
if IS_HIP_EXTENSION:
if not use_hipblaslt():
pytest.skip("CUDA graph capture not supported with rocBLAS path")
config = model_configs[model]
if fp8_recipe is not None:
......
......@@ -4,44 +4,100 @@
cmake_minimum_required(VERSION 3.21)
option(USE_ROCM "Use ROCm" OFF)
option(USE_HIPBLASLT "Use HIPBLASLT" ON)
# The aotriton and CK fused-attention backends are temporarily unsupported; use rocBLAS instead
option(USE_ROCBLAS "Use ROCBLAS" OFF)
if(NOT USE_ROCM)
if(((EXISTS "/opt/dtk/") OR (EXISTS $ENV{ROCM_PATH})) AND NOT (EXISTS "/bin/nvcc"))
message("hcu detected.")
set(USE_ROCM ON)
endif()
endif()
if (USE_ROCM)
add_compile_definitions(__HIP_CLANG_ONLY__=1)
if (NOT USE_HIPBLASLT AND NOT USE_ROCBLAS)
message(FATAL_ERROR "Need specify at least one GEMM library to use: HIPBLASLT or ROCBLAS")
endif()
unset(USE_CUDA)
else()
set(USE_CUDA TRUE)
endif()
# Language options
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
if(USE_CUDA)
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120)
else ()
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
endif()
endif()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
endif()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G")
endif()
endif()
# Hide non-necessary symbols in shared object.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version")
# Hide non-necessary symbols in shared object.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libtransformer_engine.version")
# Transformer Engine library
project(transformer_engine LANGUAGES CUDA CXX)
# Transformer Engine library
project(transformer_engine LANGUAGES CUDA CXX)
# CUDA Toolkit
find_package(CUDAToolkit REQUIRED)
if (CUDAToolkit_VERSION VERSION_LESS 12.0)
# CUDA Toolkit
find_package(CUDAToolkit REQUIRED)
if (CUDAToolkit_VERSION VERSION_LESS 12.0)
message(FATAL_ERROR "CUDA 12.0+ is required, but found CUDA ${CUDAToolkit_VERSION}")
endif()
endif()
# cuDNN frontend API
set(CUDNN_FRONTEND_INCLUDE_DIR
# cuDNN frontend API
set(CUDNN_FRONTEND_INCLUDE_DIR
"${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include")
if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}")
if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}")
message(FATAL_ERROR
"Could not find cuDNN frontend API at ${CUDNN_FRONTEND_INCLUDE_DIR}. "
"Try running 'git submodule update --init --recursive' "
"within the Transformer Engine source.")
endif()
include(${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)
else()
set(CMAKE_CXX_STANDARD 17)
project(transformer_engine LANGUAGES HIP CXX)
# Disable Asserts In Code (Can't use asserts on HIP stack.)
add_definitions(-DNDEBUG)
add_definitions(-DUSE_ROCM)
# Change clang++ to hipcc
SET(CMAKE_CXX_COMPILER "${ROCM_PATH}/bin/hipcc")
if(NOT DEFINED ENV{NVTE_ROCM_ARCH})
SET(CMAKE_HIP_ARCHITECTURES gfx906;gfx926;gfx928;gfx936)
else()
SET(CMAKE_HIP_ARCHITECTURES $ENV{NVTE_ROCM_ARCH})
endif()
# build error will be dup-ed parallel-jobs times
# set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -parallel-jobs=4")
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -g")
endif()
list(APPEND CMAKE_MODULE_PATH "/opt/dtk")
endif()
set(message_line "-------------------------------------------------------------")
message("${message_line}")
message(STATUS "USE_ROCM ${USE_ROCM}")
if(USE_ROCM)
message(STATUS "CMAKE_HIP_ARCHITECTURES: ${CMAKE_HIP_ARCHITECTURES}")
message(STATUS "USE_HIPBLASLT ${USE_HIPBLASLT} USE_ROCBLAS ${USE_ROCBLAS}")
endif()
include(${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)
# Python
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
......@@ -49,7 +105,9 @@ find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
# Configure Transformer Engine library
include_directories(${PROJECT_SOURCE_DIR}/..)
set(transformer_engine_SOURCES)
list(APPEND transformer_engine_SOURCES
if(USE_CUDA)
list(APPEND transformer_engine_SOURCES
cudnn_utils.cpp
transformer_engine.cpp
common.cu
......@@ -92,17 +150,114 @@ list(APPEND transformer_engine_SOURCES
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu
comm_gemm_overlap/comm_gemm_overlap.cpp)
add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
target_include_directories(transformer_engine PUBLIC
"${CMAKE_CURRENT_SOURCE_DIR}/include")
add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
else()
list(APPEND transformer_engine_SOURCES
cudnn_utils.cpp
transformer_engine.cpp
common.cu
transpose/cast_transpose.cu
transpose/transpose.cu
transpose/cast_transpose_fusion.cu
transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu
activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
gemm/cublaslt_gemm.cu
normalization/common.cpp
normalization/layernorm/ln_api.cpp
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_api.cpp
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
permutation/permutation.cu
util/cast.cu
util/padding.cu
util/cuda_driver.cpp
util/cuda_nvml.cpp
util/cuda_runtime.cpp
util/rtc.cpp
swizzle/swizzle.cu
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
fused_rope/fused_rope.cu
recipe/current_scaling.cu
recipe/delayed_scaling.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu
comm_gemm_overlap/comm_gemm_overlap.cpp)
# process source code files
message("${message_line}")
message(STATUS "CMAKE_CURRENT_SOURCE_DIR: ${CMAKE_CURRENT_SOURCE_DIR}")
message(STATUS "PROJECT_SOURCE_DIR: ${PROJECT_SOURCE_DIR}")
set(TE ${CMAKE_CURRENT_SOURCE_DIR}/../..)
set(THIRDPARTY ${TE}/3rdparty)
list(APPEND CMAKE_MODULE_PATH "${THIRDPARTY}/hipify_torch/cmake")
include(Hipify)
message(STATUS "CMAKE_MODULE_PATH: ${CMAKE_MODULE_PATH}")
set(header_include_dir
${CMAKE_CURRENT_SOURCE_DIR}/comm_gemm_overlap/userbuffers
${CMAKE_CURRENT_SOURCE_DIR}/activation
${CMAKE_CURRENT_SOURCE_DIR}/include
${CMAKE_CURRENT_SOURCE_DIR}/transpose
${CMAKE_CURRENT_SOURCE_DIR}/util
${CMAKE_CURRENT_SOURCE_DIR}/normalization
${CMAKE_CURRENT_SOURCE_DIR}/normalization/rmsnorm
${CMAKE_CURRENT_SOURCE_DIR}/normalization/layernorm
${CMAKE_CURRENT_SOURCE_DIR})
message(STATUS "HIPIFY CUDA_SOURCE_DIR: ${CMAKE_CURRENT_SOURCE_DIR}")
message(STATUS "HIPIFY HEADER_INCLUDE_DIR: ${header_include_dir}")
hipify(CUDA_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}
HEADER_INCLUDE_DIR ${header_include_dir}
IGNORES "*/amd_detail/*"
IGNORES "*/fused_attn/*"
CUSTOM_MAP_FILE "${TE}/hipify_custom_map.json"
)
get_hipified_list("${transformer_engine_SOURCES}" te_hip_sources)
message("${message_line}")
message(STATUS "nvte hipified sources: ${te_hip_sources}")
add_library(transformer_engine SHARED ${te_hip_sources})
endif()
# Configure dependencies
target_link_libraries(transformer_engine PUBLIC
if (USE_CUDA)
target_include_directories(transformer_engine PUBLIC
"${CMAKE_CURRENT_SOURCE_DIR}/include")
# Configure dependencies
target_link_libraries(transformer_engine PUBLIC
CUDA::cublas
CUDA::cudart)
target_include_directories(transformer_engine PRIVATE
target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
else()
# Aotriton is currently unsupported
set(AotritonAndCk_fused_attn "unsupported")
find_package(hip)
list(APPEND transformer_engine_LINKER_LIBS hip::host hip::device roctx64)
if(USE_HIPBLASLT)
find_package(hipblaslt)
find_package(hipblas REQUIRED PATHS ${ROCM_PATH})
target_compile_definitions(transformer_engine PUBLIC USE_HIPBLASLT)
list(APPEND transformer_engine_LINKER_LIBS roc::hipblaslt hipblas)
endif()
if(USE_ROCBLAS)
find_package(rocblas)
target_compile_definitions(transformer_engine PUBLIC USE_ROCBLAS)
list(APPEND transformer_engine_LINKER_LIBS roc::rocblas)
endif()
target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS})
endif()
# Compiling Userbuffers with native MPI bootstrapping requires linking against MPI
option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF)
......@@ -113,8 +268,10 @@ if (NVTE_UB_WITH_MPI)
target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI)
endif()
# Hack to enable dynamic loading in cuDNN frontend
target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING)
if (USE_CUDA)
# Hack to enable dynamic loading in cuDNN frontend
target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING)
endif()
# Helper functions to make header files with C++ strings
function(make_string_header STRING STRING_NAME)
......@@ -130,17 +287,34 @@ function(make_string_header_from_file file_ STRING_NAME)
endfunction()
# Header files with C++ strings
list(GET CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES 0 cuda_include_path)
make_string_header("${cuda_include_path}"
if(USE_CUDA)
list(GET CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES 0 cuda_include_path)
make_string_header("${cuda_include_path}"
string_path_cuda_include)
make_string_header_from_file(transpose/rtc/cast_transpose_fusion.cu
make_string_header_from_file(transpose/rtc/cast_transpose_fusion.cu
string_code_transpose_rtc_cast_transpose_fusion_cu)
make_string_header_from_file(transpose/rtc/cast_transpose.cu
make_string_header_from_file(transpose/rtc/cast_transpose.cu
string_code_transpose_rtc_cast_transpose_cu)
make_string_header_from_file(transpose/rtc/transpose.cu
make_string_header_from_file(transpose/rtc/transpose.cu
string_code_transpose_rtc_transpose_cu)
make_string_header_from_file(utils.cuh
make_string_header_from_file(utils.cuh
string_code_utils_cuh)
else()
make_string_header_from_file(utils_hip.cuh
string_code_utils_cuh)
make_string_header_from_file(transpose/rtc/cast_transpose_fusion.hip
string_code_transpose_rtc_cast_transpose_fusion_cu)
make_string_header_from_file(transpose/rtc/cast_transpose.hip
string_code_transpose_rtc_cast_transpose_cu)
make_string_header_from_file(transpose/rtc/transpose.hip
string_code_transpose_rtc_transpose_cu)
make_string_header_from_file(amd_detail/hip_float8.h
string_code_amd_detail_hip_float8_h)
make_string_header_from_file(amd_detail/hip_f8_impl.h
string_code_amd_detail_hip_f8_impl_h)
endif()
make_string_header_from_file(util/math.h
string_code_util_math_h)
target_include_directories(transformer_engine PRIVATE
......@@ -160,8 +334,22 @@ if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
PROPERTIES
COMPILE_OPTIONS "--use_fast_math")
endif()
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
if(USE_CUDA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
else()
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -O3")
set(HIP_HCC_FLAGS "${CMAKE_HIP_FLAGS} -mavx2 -mf16c -mfma -std=c++17")
# Ask hcc to generate device code during compilation so we can use
# host linker to link.
set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -fno-gpu-rdc -Wno-defaulted-function-deleted")
foreach(rocm_arch ${CMAKE_HIP_ARCHITECTURES})
# if CMAKE_CXX_FLAGS has --offload-arch set already, better to rm first
set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} --offload-arch=${rocm_arch}")
endforeach()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${HIP_HCC_FLAGS}")
endif()
# Number of parallel build jobs
if(ENV{MAX_JOBS})
......
......@@ -124,6 +124,9 @@ def _load_nvrtc():
# Preload the cuDNN and NVRTC shared libraries before the Transformer Engine
# core library, so that its dynamic-loading hooks can resolve them.  Skipped
# while the project itself is being built, unless this is a release build.
if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
    try:
        _CUDNN_LIB_CTYPES = _load_cudnn()
        _NVRTC_LIB_CTYPES = _load_nvrtc()
    except OSError:
        # Best-effort preload; NOTE(review): presumably these CUDA-side
        # libraries may be absent on ROCm builds — confirm against the
        # surrounding loader code.
        pass
_TE_LIB_CTYPES = _load_library()
/*************************************************************************
* Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
*
* License for AMD contributions = MIT. See LICENSE for more information
************************************************************************/
#include <hip/hip_runtime.h>
namespace hip_f8_impl {
// Count the leading zero bits of a 32-bit value, usable from both host and
// device code.  NOTE(review): __builtin_clz(0) is undefined behavior — the
// only caller (cast_from_f8) guarantees a non-zero argument.
HIP_HOST_DEVICE inline int clz(uint32_t x) {
#ifdef __HIP_DEVICE_COMPILE__
    return __clz(x);          // device compilation: use the GPU intrinsic
#else
    return __builtin_clz(x);  // host compilation: use the compiler builtin
#endif
}
// Software emulation of a float/half -> 8-bit float (f8) conversion.
//
// Template parameters:
//   wm, we            - mantissa / exponent bit widths of the target f8
//                       format (wm+we must equal 7; the 8th bit is the sign)
//   T                 - source type; only float and __half are accepted
//   negative_zero_nan - NANOO encoding: NaN/Inf collapse to the single
//                       pattern 0x80 and the exponent bias is one larger;
//                       otherwise an IEEE-like encoding is produced
//   clip              - on overflow, saturate to the largest finite f8 value
//                       instead of returning the signed-infinity pattern
// Runtime parameters:
//   stoch, rng        - when stoch is true, the caller-supplied random bits
//                       rng are added before truncation (stochastic
//                       rounding); otherwise round-to-nearest-even is used.
// Returns the raw 8-bit encoding.
template <int wm, int we, typename T, bool negative_zero_nan, bool clip>
HIP_HOST_DEVICE
uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) {
  constexpr bool is_half = std::is_same<T,__half>::value;
  constexpr bool is_float = std::is_same<T,float>::value;
  static_assert(wm+we==7, "wm+we==7");
  static_assert(is_half || is_float, "Only half and float can be cast to f8");

  // Source mantissa width: 23 bits for float, 10 for half.
  const int mfmt = (sizeof(T)==4) ? 23 : 10;
  uint32_t x;
  if(sizeof(T)==4)
    x = reinterpret_cast<uint32_t&>(_x);
  else
    x = reinterpret_cast<uint16_t&>(_x);

  uint32_t y, head, mantissa;  // NOTE(review): y is declared but never used
  int exponent, bias;
  uint32_t sign;

  // Split the source bit pattern into sign / biased exponent / mantissa.
  if(sizeof(T)==4) {
    head = x & 0xFF800000;
    mantissa = x & 0x7FFFFF;
    exponent = (head>>23) & 0xFF;
    sign = head >> 31;
    bias = 127;
  } else {
    head = x & 0xFC00;
    mantissa = x & 0x3FF;
    exponent = (head>>10) & 0x1F;
    sign = head >> 15;
    bias = 15;
  }

  // f8 pattern for +/-infinity: sign bit plus an all-ones exponent field.
  uint32_t signed_inf = (sign<<7) + (((1<<we)-1)<<wm);

  // Deal with inf and NaNs
  if(negative_zero_nan) {
    // NANOO mode: NaN and both infinities map to the single pattern 0x80.
    if(sizeof(T)==4) {
      if((x & 0x7F800000) == 0x7F800000)
        return 0x80;
    } else {
      //if(__hisinf(x) || __hisnan(x))
      if((x & 0x7C00)==0x7C00)
        return 0x80;
    }
  }
  else {
    // IEEE-like mode: infinity keeps its sign; NaN sets a mantissa bit.
    if(sizeof(T)==4) {
      if((x & 0x7F800000) == 0x7F800000)
        return signed_inf + (mantissa!=0 ? 1 : 0);
    } else {
      if((x & 0x7C00)==0x7C00)
        return signed_inf + (mantissa!=0 ? 1 : 0);
    }
  }
  if(x==0)
    return 0;

  // First need to check if it is normal or denorm as there is a difference of
  // implicit 1. Then need to adjust the exponent to align with the F8 exponent,
  // in the meanwhile, shift the mantissa. Then for stochastic rounding, add rng
  // to mantissa and truncate. And for RNE, no need to add rng. Then probably
  // need to check whether there is carry and adjust exponent and mantissa again.

  // For IEEE bias mode, the bias is 2^(k-1) - 1 where k is the width of
  // exponent bits; NANOO mode uses a bias one larger.
  const int f8_bias = ( 1<<(we-1) ) - 1 + ( negative_zero_nan ? 1 : 0 );
  const int f8_denormal_act_exponent = 1 - f8_bias; //actual exponent of f8 denormal
  // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
  // f8_exponent is the converted f8 exponent with bias encoding
  // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
  // the difference needs to be adjusted and mantissa shifted
  int act_exponent, f8_exponent, exponent_diff;

  if (exponent == 0) { // fp32/fp16 is in denormal.
    /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16 here.
       In this case, f8 is usually in denormal. But there could be exceptions.
       fp16 denormal has exponent bias 15 while bf8 with NANOO has exponent bias 16.
       It means that there are some numbers in fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15.
       fp16 numbers where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal.
       In this case, the fp16 mantissa should be shift left by 1 */
    act_exponent = exponent - bias + 1;
    exponent_diff = f8_denormal_act_exponent - act_exponent; // actual exponent is exponent-bias+1 as it is denormal
  }
  else { // fp32/fp16 is normal with implicit 1
    act_exponent = exponent - bias;
    if (act_exponent <= f8_denormal_act_exponent) {
      /* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
         For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
         actual exponent is -7, it is actually larger due to the implicit 1,
         Therefore it needs to be adjust to -6 and mantissa shift right by 1.
         So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
      exponent_diff = f8_denormal_act_exponent - act_exponent;
    }
    else { //both fp32/fp16 and f8 are in normal range
      exponent_diff = 0; // exponent_diff=0 does not mean there is no difference for this case,
                         //act_exponent could be larger. Just that it does not need shift mantissa
    }
    mantissa += (1 << mfmt); //Add the implicit 1 into mantissa
  }

  bool midpoint;
  if (exponent_diff<=wm+1)
    // The determination of midpoint only makes sense when wm+1 could compensate the difference in exponent.
    // Why wm+1 instead of wm? It is because in addition to the wm bits to be left as f8 mantissa, there is
    // also the implicit 1 (There is not always implicit 1 but it does not matter).
    midpoint = (mantissa & ( (1 << (mfmt-wm+exponent_diff)) - 1 )) == ( 1 << (mfmt-wm+exponent_diff-1) );
  else
    midpoint = false;

  /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we shift right
     as shift right could rip off some residual part and make something not midpoint look like midpoint.
     For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than midpoint,
     but after shift right by 4 bits, it would look like midpoint.
  */
  if (exponent_diff>0)
    // Clip the exponent_diff as the right shift when exponent_diff > 31 is undefined behavior
    mantissa >>= (exponent_diff > 31 ? 31 : exponent_diff);
  else if (exponent_diff == -1)
    // NOTE(review): only a difference of exactly -1 is shifted here; more
    // negative values appear unreachable given the ranges above — confirm.
    mantissa <<= -exponent_diff;
  bool implicit_one = mantissa & (1 << mfmt);
  //if there is no implicit 1, it means the f8 is denormal and need to adjust to denorm exponent
  f8_exponent = (act_exponent+exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one?0:1);

  //Now we have the exponent and mantissa adjusted
  uint32_t drop_mask = (1 << (mfmt-wm)) - 1;
  //bool midpoint = (mantissa & drop_mask) == ( 1 << (mfmt-wm-1) );
  bool odd = mantissa & (1<< (mfmt-wm)); // if the least significant bit that is not truncated is 1
  // Rounding: add rng (stochastic), or a value that implements
  // round-to-nearest-even at the midpoint, then truncate via drop_mask.
  mantissa += (stoch ? rng : (midpoint?(odd?mantissa:mantissa-1 ) :mantissa) ) & drop_mask;

  //Now we deal with overflow
  if (f8_exponent == 0) {
    if ((1 << mfmt) & mantissa) {
      f8_exponent = 1; //denormal overflow to become normal, promote exponent
      //mantissa &= (1<<mfmt) -1 ; //No need to make 1 implicit now as it will be addressed later
    }
  }
  else {
    if ((1 << (mfmt+1)) & mantissa) {
      mantissa >>= 1;
      f8_exponent++;
      //mantissa &= (1<<mfmt) -1 ; // No need to make 1 implicit now as it will be addressed later
    }
  }

  mantissa >>= (mfmt-wm);

  // above range: quantize to maximum possible float of the same sign
  const int max_exp = (1<<we)-(negative_zero_nan ? 1 : 2);
  if(f8_exponent > max_exp) {
    if(clip) {
      mantissa = (1<<wm)-1;
      f8_exponent = max_exp;
    } else {
      return signed_inf;
    }
  }

  // Zero result: NANOO has a single zero (0x00); IEEE-like keeps the sign.
  if(f8_exponent == 0 && mantissa == 0)
    return negative_zero_nan? 0 : (sign<<7);
  mantissa &= (1<<wm)-1;
  return (sign << 7) | (f8_exponent << wm) | mantissa;
}
/* RTC does not have std::conditional so implement it here.
   Minimal drop-in replacement: selects T when B is true, F otherwise. */
template<bool B, class T, class F>
struct conditional { typedef T type; };

template<class T, class F>
struct conditional<false, T, F> { typedef F type; };
// Software emulation of an 8-bit float (f8) -> float/half conversion.
//
// Template parameters:
//   wm, we            - mantissa / exponent bit widths of the source f8 format
//   T                 - destination type; only float and __half are accepted
//   negative_zero_nan - NANOO decoding (0x80 -> NaN, bias one larger) vs.
//                       IEEE-like decoding (0x80 -> -0, all-ones exponent ->
//                       Inf/NaN)
// Returns the decoded value of type T.
template <int wm, int we, typename T, bool negative_zero_nan>
HIP_HOST_DEVICE
T cast_from_f8(uint8_t x) {
  constexpr bool is_half = std::is_same<T,__half>::value;
  constexpr bool is_float = std::is_same<T,float>::value;
  // NOTE(review): is_bf16 is computed (and wmo below has a bf16 branch) but
  // the static_assert forbids hip_bfloat16 — apparently dead; confirm intent.
  constexpr bool is_bf16 = std::is_same<T,hip_bfloat16>::value;
  static_assert(is_half || is_float, "only half and float are supported");

  // Exponent / mantissa widths of the destination type.
  constexpr int weo = is_half ? 5 : 8;
  constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);

  // Special values of T, constructed from their exact bit patterns.
  T fInf, fNegInf, fNaN, fNeg0;
  if(is_half) {
    const uint16_t ihInf = 0x7C00;
    const uint16_t ihNegInf = 0xFC00;
    const uint16_t ihNaN = 0x7C01;
    const uint16_t ihNeg0 = 0x8000;
    fInf = reinterpret_cast<const __half&>(ihInf);
    fNegInf = reinterpret_cast<const __half&>(ihNegInf);
    fNaN = reinterpret_cast<const __half&>(ihNaN);
    fNeg0 = reinterpret_cast<const __half&>(ihNeg0);
  } else if(is_float) {
    const uint32_t ifInf = 0x7F800000;
    const uint32_t ifNegInf = 0xFF800000;
    const uint32_t ifNaN = 0x7F800001;
    const uint32_t ifNeg0 = 0x80000000;
    fInf = reinterpret_cast<const float&>(ifInf);
    fNegInf = reinterpret_cast<const float&>(ifNegInf);
    fNaN = reinterpret_cast<const float&>(ifNaN);
    fNeg0 = reinterpret_cast<const float&>(ifNeg0);
  }

  if(x==0)
    return 0;

  // Unpack the f8 fields.
  uint32_t sign = x>>7;
  uint32_t mantissa = x & ((1<<wm)-1);
  int exponent = (x & 0x7F) >> wm;

  // Special encodings.
  if(negative_zero_nan) {
    if(x==0x80)
      return fNaN;
  } else {
    if(x==0x80)
      return fNeg0;
    if(exponent == ((1<<we)-1))
      return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
  }

  typename conditional<sizeof(T)==2, uint16_t, uint32_t>::type retval;
  // bf8 (E5M2) in IEEE mode is field-compatible with fp16: widening the
  // pattern by 8 bits is the whole conversion.
  if(we==5 && is_half && !negative_zero_nan) {
    retval = x<<8;
    return reinterpret_cast<const T&>(retval);
  }

  // Exponent rebias between the f8 format and the destination format.
  const int exp_low_cutoff = (1<<(weo-1)) - (1<<(we-1)) + 1 - (negative_zero_nan ? 1 : 0);

  //subnormal input
  if(exponent == 0) {
    //guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
    // Normalize: shift until the implicit-1 position, adjusting the exponent.
    int sh = 1 + clz(mantissa) - (32-wm);
    mantissa <<= sh;
    exponent += 1-sh;
    /*
    exponent++;
    while(mantissa<(1<<wm)) {
      mantissa <<= 1;
      exponent--;
    }
    */
    mantissa &= ((1<<wm)-1);
  }
  exponent += exp_low_cutoff-1;
  mantissa <<= wmo - wm;

  // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
  if(exponent<=0) {
    mantissa |= 1<<wmo;
    mantissa >>= 1-exponent;
    exponent = 0;
  }

  // Reassemble the destination bit pattern.
  if(sizeof(T)==2)
    retval = (sign<<15) | (exponent<<10) | mantissa;
  else
    retval = (sign<<31) | (exponent<<23) | mantissa;
  return reinterpret_cast<const T&>(retval);
}
} // namespace hip_f8_impl
\ No newline at end of file
/*************************************************************************
* Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
*
* License for AMD contributions = MIT. See LICENSE for more information
************************************************************************/
#pragma once

// FP8 header version 0.3, 2021/05/11

#include <hip/hip_runtime.h>

// Qualifier shorthands for functions compiled for host, device, or both.
#define HIP_HOST_DEVICE __host__ __device__
#define HIP_DEVICE __device__
#define HIP_HOST __host__

// Largest finite magnitudes of the two f8 formats (NANOO encoding):
// E5M2 ("bf8") and E4M3 ("fp8").
#define E5M2_AMAX 57344.0
#define E4M3_AMAX 240.0

namespace hip_f8_impl {

// Forward declarations; the definitions live in hip_f8_impl.h, included below.
template <int wm, int we, typename T, bool negative_zero_nan, bool clip>
HIP_HOST_DEVICE
uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0);

template <int wm, int we, typename T, bool negative_zero_nan>
HIP_HOST_DEVICE
T cast_from_f8(uint8_t x);

} // namespace hip_f8_impl

#include "hip_f8_impl.h"

// Which f8 flavor a hip_f8 value carries (sign:exponent:mantissa widths).
enum class hip_f8_type {
  bf8 = 0, // 1:5:2
  fp8 = 1  // 1:4:3
};

// Rounding behavior used by the float/half -> f8 constructors.
enum class hip_f8_rounding_mode {
  standard,   // round-to-nearest-even
  stochastic  // randomized rounding driven by caller-supplied bits
};
// bias mode bit implementation
//
// For MI100 simulation purpose, we keep a copy of it on the host and device
// (MI300 HW implementation will be different)
//
// The bias mode should only be accessed via its get/set routines.
// The set routine sets both copies to the same value, keeping them in sync
// The get routine will return the device copy for device functions and
// the host copy for host functions
//
// "bias mode optimal"
// => "bias mode bit" = 1
// => bias = 16 for 152, 8 for 143
// => NAN/INF are represented as negative_zero
//
// "bias mode ieee"
// => "bias mode bit" = 0
// => bias = 15 for 152, 7 for 143
// => NAN/INF are represented as per IEEE conventions
#ifndef __HIPCC_RTC__
// Host-side copy of the bias-mode bit (true = "optimal"/NANOO mode).
static bool hip_f8_bias_mode_bit_host = true;

static inline __host__ bool get_hip_f8_bias_mode() {
  return hip_f8_bias_mode_bit_host;
}
#endif // __HIPCC_RTC__

#ifdef __HIPCC__
// Device-side copy; the host setters below keep it in sync with the host copy.
static __device__ bool hip_f8_bias_mode_bit_device = true;

static inline __device__ bool get_hip_f8_bias_mode() {
  return hip_f8_bias_mode_bit_device;
}

#ifndef __HIPCC_RTC__
// Single-thread kernel that writes the device-side copy.
static __global__ void set_hip_f8_bias_mode_bit(bool v) {
  hip_f8_bias_mode_bit_device = v;
}

// Select IEEE bias mode ("bias mode bit" = 0).
// NOTE(review): the kernel launch is not followed by a synchronize here —
// confirm callers do not race device reads against this update.
static void set_hip_f8_bias_mode_ieee() {
  hipLaunchKernelGGL(set_hip_f8_bias_mode_bit, dim3(1), dim3(1), 0, 0, false);
  hip_f8_bias_mode_bit_host = false;
}

// Select "optimal" (NANOO) bias mode ("bias mode bit" = 1).
static void set_hip_f8_bias_mode_optimal() {
  hipLaunchKernelGGL(set_hip_f8_bias_mode_bit, dim3(1), dim3(1), 0, 0, true);
  hip_f8_bias_mode_bit_host = true;
}
#endif // __HIPCC_RTC__
#endif // __HIPCC__
template<hip_f8_type T>
struct hip_f8 {
uint8_t data;
// default constructor
HIP_HOST_DEVICE hip_f8() = default;
// constructor from bits
explicit HIP_HOST_DEVICE hip_f8(uint8_t v) {
data = v;
}
// constructor from float
#ifdef __gfx942__
explicit HIP_DEVICE hip_f8(float v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0) {
union {
float fval;
uint32_t i32val;
uint8_t i8val[4];
} val;
uint32_t ival = 0;
val.fval = v;
if (T == hip_f8_type::bf8) { // bf8
if ((val.i32val & 0x7F800000) != 0x7F800000) // propagate NAN/INF, no clipping
val.fval = __builtin_amdgcn_fmed3f(val.fval, E5M2_AMAX, -E5M2_AMAX);
if (rm == hip_f8_rounding_mode::standard) { // RNE rounding
ival = __builtin_amdgcn_cvt_pk_bf8_f32(
val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
data = val.i8val[0];
}
else { //stochastic rounding
ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
data = val.i8val[0]; // little endian
}
}
else { // fp8
if ((val.i32val & 0x7F800000) != 0x7F800000) /// propagate NAN/INF, no clipping
val.fval = __builtin_amdgcn_fmed3f(val.fval, E4M3_AMAX, -E4M3_AMAX);
if (rm == hip_f8_rounding_mode::standard) { // RNE rounding
ival = __builtin_amdgcn_cvt_pk_fp8_f32(
val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
data = val.i8val[0];
}
else { //stochastic rounding
ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
data = val.i8val[0]; // little endian
}
}
}
#ifndef __HIPCC_RTC__
explicit HIP_HOST //Code host still uses SW simulated conversion on gfx942
hip_f8(float v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0) {
if (T == hip_f8_type::bf8) {
if (get_hip_f8_bias_mode()) {
data = hip_f8_impl::cast_to_f8<2, 5, float, true/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
} else {
data = hip_f8_impl::cast_to_f8<2, 5, float, false/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
}
} else /* fp8*/ {
if (get_hip_f8_bias_mode()) {
data = hip_f8_impl::cast_to_f8<3, 4, float, true/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
} else {
data = hip_f8_impl::cast_to_f8<3, 4, float, false/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
}
}
}
#endif
#else // #ifndef __gfx942__
explicit HIP_HOST_DEVICE // On architectures other than gfx942, both host and device still use SW simulated conversion
hip_f8(float v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0) {
if (T == hip_f8_type::bf8) {
if (get_hip_f8_bias_mode()) {
data = hip_f8_impl::cast_to_f8<2, 5, float, true/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
} else {
data = hip_f8_impl::cast_to_f8<2, 5, float, false/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
}
} else /* fp8*/ {
if (get_hip_f8_bias_mode()) {
data = hip_f8_impl::cast_to_f8<3, 4, float, true/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
} else {
data = hip_f8_impl::cast_to_f8<3, 4, float, false/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
}
}
}
#endif // #ifdef __gfx942__
// constructor from half
#ifdef __gfx942__
explicit HIP_DEVICE hip_f8(half v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0)
: hip_f8((float)v, rm, rng)
{
}
#ifndef __HIPCC_RTC__
explicit HIP_HOST //Code host still uses SW simulated conversion on gfx942
hip_f8(half v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0) {
if (T == hip_f8_type::bf8) {
if (get_hip_f8_bias_mode()) {
data = hip_f8_impl::cast_to_f8<2, 5, half, true/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
} else {
data = hip_f8_impl::cast_to_f8<2, 5, half, false/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
}
} else /* fp8*/ {
if (get_hip_f8_bias_mode()) {
data = hip_f8_impl::cast_to_f8<3, 4, half, true/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
} else {
data = hip_f8_impl::cast_to_f8<3, 4, half, false/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
}
}
}
#endif
#else // #ifndef __gfx942__
explicit HIP_HOST_DEVICE // On architectures other than gfx942, both host and device still use SW simulated conversion
hip_f8(half v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0) {
if (T == hip_f8_type::bf8) {
if (get_hip_f8_bias_mode()) {
data = hip_f8_impl::cast_to_f8<2, 5, half, true/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
} else {
data = hip_f8_impl::cast_to_f8<2, 5, half, false/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
}
} else /* fp8*/ {
if (get_hip_f8_bias_mode()) {
data = hip_f8_impl::cast_to_f8<3, 4, half, true/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
} else {
data = hip_f8_impl::cast_to_f8<3, 4, half, false/*negative_zero_nan*/, true/*clip*/>(v, (rm == hip_f8_rounding_mode::stochastic), rng);
}
}
}
#endif // #ifdef __gfx942__
// constructor from hip_bfloat16
explicit HIP_HOST_DEVICE hip_f8(hip_bfloat16 v, hip_f8_rounding_mode r=hip_f8_rounding_mode::standard, uint32_t rng=0);
// convert to float
#ifdef __gfx942__
HIP_DEVICE operator float() const {
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // dependent of endian
} val;
// assign 8bit data in position [7:0]
val.i32val = 0;
val.i8val[3] = data; // little endian
// upcast
if(T == hip_f8_type::bf8)
val.fval = __builtin_amdgcn_cvt_f32_bf8(val.i32val, 3); // 0 pos
else // fp8
val.fval = __builtin_amdgcn_cvt_f32_fp8(val.i32val, 3); // 0 pos
return val.fval;
}
#ifndef __HIPCC_RTC__
explicit inline HIP_HOST //Code host still uses SW simulated conversion on gfx942
operator float() const {
  // Software decode to float.  The <2,5> (bf8) or <3,4> (fp8) template
  // parameters come from the format tag T; the bias mode chooses the
  // negative_zero_nan variant at runtime.
  const bool nzn = get_hip_f8_bias_mode();
  if (T == hip_f8_type::bf8) {
    return nzn ? hip_f8_impl::cast_from_f8<2, 5, float, true/*negative_zero_nan*/>(data)
               : hip_f8_impl::cast_from_f8<2, 5, float, false/*negative_zero_nan*/>(data);
  }
  /* fp8 */
  return nzn ? hip_f8_impl::cast_from_f8<3, 4, float, true/*negative_zero_nan*/>(data)
             : hip_f8_impl::cast_from_f8<3, 4, float, false/*negative_zero_nan*/>(data);
}
#endif
#else // #ifdef __gfx942__
explicit inline HIP_HOST_DEVICE // On architectures other than gfx942, both host and device still use SW simulated conversion
operator float() const {
  // Software decode to float; template parameters <2,5>/<3,4> follow the
  // format tag T, and bias mode selects the negative_zero_nan variant.
  const bool nzn = get_hip_f8_bias_mode();
  if (T == hip_f8_type::bf8) {
    return nzn ? hip_f8_impl::cast_from_f8<2, 5, float, true/*negative_zero_nan*/>(data)
               : hip_f8_impl::cast_from_f8<2, 5, float, false/*negative_zero_nan*/>(data);
  }
  /* fp8 */
  return nzn ? hip_f8_impl::cast_from_f8<3, 4, float, true/*negative_zero_nan*/>(data)
             : hip_f8_impl::cast_from_f8<3, 4, float, false/*negative_zero_nan*/>(data);
}
#endif // #ifdef __gfx942__
// convert to half
#ifdef __gfx942__
explicit HIP_DEVICE inline operator half() const {
  // Reuse the hardware float upconversion above, then narrow the result.
  const float wide = float(*this);
  return __half(wide);
}
#ifndef __HIPCC_RTC__
explicit inline HIP_HOST //Code host still uses SW simulated conversion on gfx942
operator half() const {
  // Host-side software decode to half; <2,5>/<3,4> follow the format tag T
  // and the bias mode selects the negative_zero_nan variant.
  const bool nzn = get_hip_f8_bias_mode();
  if (T == hip_f8_type::bf8) {
    return nzn ? hip_f8_impl::cast_from_f8<2, 5, half, true/*negative_zero_nan*/>(data)
               : hip_f8_impl::cast_from_f8<2, 5, half, false/*negative_zero_nan*/>(data);
  }
  /* fp8 */
  return nzn ? hip_f8_impl::cast_from_f8<3, 4, half, true/*negative_zero_nan*/>(data)
             : hip_f8_impl::cast_from_f8<3, 4, half, false/*negative_zero_nan*/>(data);
}
#endif
#else // #ifndef __gfx942__
explicit inline HIP_HOST_DEVICE // On architectures other than gfx942, both host and device still use SW simulated conversion
operator half() const {
  // Software decode to half; <2,5>/<3,4> follow the format tag T and the
  // bias mode selects the negative_zero_nan variant.
  const bool nzn = get_hip_f8_bias_mode();
  if (T == hip_f8_type::bf8) {
    return nzn ? hip_f8_impl::cast_from_f8<2, 5, half, true/*negative_zero_nan*/>(data)
               : hip_f8_impl::cast_from_f8<2, 5, half, false/*negative_zero_nan*/>(data);
  }
  /* fp8 */
  return nzn ? hip_f8_impl::cast_from_f8<3, 4, half, true/*negative_zero_nan*/>(data)
             : hip_f8_impl::cast_from_f8<3, 4, half, false/*negative_zero_nan*/>(data);
}
#endif // #ifdef __gfx942__
// convert to hip_bfloat16
explicit inline HIP_HOST_DEVICE operator hip_bfloat16() const;
// check for zero
inline HIP_HOST_DEVICE bool is_zero() const {
  // The bias-mode (negative_zero_nan) encoding has a single zero (0x00);
  // the standard encoding additionally has a negative zero at 0x80.
  if (!get_hip_f8_bias_mode()) {
    return (data == 0x00) || (data == 0x80);
  }
  return data == 0x00;
}
// check for nan
inline HIP_HOST_DEVICE bool is_nan() const {
  if (get_hip_f8_bias_mode()) {
    // Bias mode has a single NaN encoding.
    return data == 0x80;
  }
  // Standard encoding: NaN bit patterns are symmetric in the sign bit, so
  // compare the sign-stripped magnitude.  bf8 NaNs are 0x7d..0x7f (and
  // 0xfd..0xff); fp8 NaNs are 0x79..0x7f (and 0xf9..0xff).
  const int magnitude = data & 0x7f;
  return magnitude > ((T == hip_f8_type::bf8) ? 0x7c : 0x78);
}
// check for inf
inline HIP_HOST_DEVICE bool is_inf() const {
  if (get_hip_f8_bias_mode()) {
    // Bias mode: 0x80 (the same code is_nan() reports in this mode).
    return data == 0x80;
  }
  // Standard encoding: +inf and -inf differ only in the sign bit
  // (bf8: 0x7c/0xfc, fp8: 0x78/0xf8), so compare the sign-stripped magnitude.
  const int magnitude = data & 0x7f;
  return magnitude == ((T == hip_f8_type::bf8) ? 0x7c : 0x78);
}
};
#ifdef __HIPCC__
// Packed vector of four f8 values of format T, one per byte of `data`.
// Declaration only: the conversion constructors/operators are defined elsewhere.
template<hip_f8_type T>
struct hip_f8x4 {
// define some convenience types
typedef float float32x2 __attribute__((ext_vector_type(2)));
typedef float float32x4 __attribute__((ext_vector_type(4)));
typedef _Float16 halfx2 __attribute__((ext_vector_type(2)));
typedef _Float16 halfx4 __attribute__((ext_vector_type(4)));
// NOTE(review): bfloat16 lanes are carried as raw uint16_t bit patterns
// (ext_vector_type needs a builtin element type) — presumably reinterpreted
// by the converters; confirm against their definitions.
typedef uint16_t hip_bfloat16x2 __attribute__((ext_vector_type(2)));
typedef uint16_t hip_bfloat16x4 __attribute__((ext_vector_type(4)));
uint32_t data; // four packed f8 bytes
// default constructor
HIP_HOST_DEVICE hip_f8x4() = default;
// constructor from bits
HIP_HOST_DEVICE hip_f8x4(uint32_t v);
// constructor from float (scalar lanes; missing lanes default to 0)
HIP_HOST_DEVICE hip_f8x4(float v0, float v1=0, float v2=0, float v3=0, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0);
HIP_HOST_DEVICE hip_f8x4(float32x2 v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0);
HIP_HOST_DEVICE hip_f8x4(float32x4 v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0);
// constructor from half
HIP_HOST_DEVICE hip_f8x4(half v0, half v1=0, half v2=0, half v3=0, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0);
HIP_HOST_DEVICE hip_f8x4(halfx2 v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0);
HIP_HOST_DEVICE hip_f8x4(halfx4 v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0);
// constructor from hip_bfloat16
HIP_HOST_DEVICE hip_f8x4(hip_bfloat16 v0, hip_bfloat16 v1=hip_bfloat16(0.0f), hip_bfloat16 v2=hip_bfloat16(0.0f), hip_bfloat16 v3=hip_bfloat16(0.0f), hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0);
HIP_HOST_DEVICE hip_f8x4(hip_bfloat16x2 v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0);
HIP_HOST_DEVICE hip_f8x4(hip_bfloat16x4 v, hip_f8_rounding_mode rm=hip_f8_rounding_mode::standard, uint32_t rng=0);
// convert to float32x4
inline HIP_HOST_DEVICE operator float32x4() const;
// convert to halfx4
inline HIP_HOST_DEVICE operator halfx4() const;
// convert to hip_bfloat16x4
inline HIP_HOST_DEVICE operator hip_bfloat16x4() const;
};
// Packed vector of eight f8 values, represented as two hip_f8x4 lanes.
template<hip_f8_type T>
struct hip_f8x8 {
// define some convenience types
typedef hip_f8x4<T> f8x8 __attribute__((ext_vector_type(2)));
f8x8 data; // two hip_f8x4 words = 8 f8 values
// default constructor
HIP_HOST_DEVICE hip_f8x8() = default;
// do we need to define other constructors or any conversion routines here?
};
// If we do not end up needing either any constructors or conversion routines for the above type, then
// we can simplify the above type to the following
#if USE_SIMPLER_HIP_F8x8
template <hip_f8_type T>
using hip_f8x8 = hip_f8x4<T> __attribute__((ext_vector_type(2)));
#endif
typedef float hip_float32x4 __attribute__((ext_vector_type(4)));
typedef float hip_float32x16 __attribute__((ext_vector_type(16)));
// these are device-specific and we don't expect them to exist unless we're compiling with hip-clang for gfx942.
template<hip_f8_type T_A, hip_f8_type T_B>
__device__ hip_float32x4 mfma_f32_16x16x32(hip_f8x8<T_A> a, hip_f8x8<T_B> b, hip_float32x4 c);
template<hip_f8_type T_A, hip_f8_type T_B>
__device__ hip_float32x16 mfma_f32_32x32x16(hip_f8x8<T_A> a, hip_f8x8<T_B> b, hip_float32x16 c);
#endif //__HIPCC__
\ No newline at end of file
......@@ -107,8 +107,12 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
This is needed only for Hopper, which uses persistent CTA execution.
*/
int max_connection = transformer_engine::getenv<int>("CUDA_DEVICE_MAX_CONNECTIONS", 8);
#ifdef USE_ROCM
int runtime_version = 6;
#else
int runtime_version = 0;
cudaRuntimeGetVersion(&runtime_version);
#endif
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, 0);
if (runtime_version >= 12030 && deviceProp.major == 9 && max_connection > 1) {
......@@ -277,7 +281,11 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
assert(rs_output.size(0) == _ubuf.size(0) / _tp_size);
assert(rs_output.element_size() == 2);
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
#ifdef USE_ROCM
reducescatter2_userbuff_fp8<hip_f8<hip_f8_type::bf8>>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0,
#else
reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0,
#endif
comm_elements, _ub_comm, _stream_comm,
(cudaEvent_t)_comm_launch_event);
} else {
......
......@@ -361,7 +361,11 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks,
NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->flags, 2 * GPU_PAGE_SIZE));
NVTE_CHECK_CUDA(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE));
(*comm)->flags =
#ifdef USE_ROCM
reinterpret_cast<int *>((reinterpret_cast<uintptr_t>((*comm)->flags) + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK
#else
reinterpret_cast<int *>(((CUdeviceptr)(*comm)->flags + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
#endif
using namespace std;
......@@ -624,7 +628,12 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
} else {
#endif
if (alloc) {
#ifdef USE_ROCM
// Ref to RCCL
NVTE_CHECK_CUDA(hipExtMallocWithFlags(gpubuff, bytes, hipDeviceMallocFinegrained));
#else
NVTE_CHECK_CUDA(cudaMalloc(gpubuff, bytes));
#endif
NVTE_CHECK_CUDA(cudaMemset(*gpubuff, 0, bytes));
}
......
......@@ -9,8 +9,10 @@
#include <cuda_runtime.h>
#if __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#define half_dtype nv_bfloat16
#else
#include <cuda_fp16.h>
#define half_dtype half
#endif
......@@ -24,6 +26,18 @@
#define MAX_THREADS 1024
#ifdef __HIP_PLATFORM_AMD__
#define ATOMIC_CONSUMER(chunk) \
if (counters) { \
if (threadIdx.x == 0 && blockIdx.x == 0) { \
while (0 != (atomicCAS(((unsigned int *)counters) + chunk, 0, 0))) { \
} \
((unsigned int *)counters)[chunk] = 1; \
__threadfence_system(); \
} \
if (blockIdx.x == 0) __syncthreads(); \
}
#else
#define ATOMIC_CONSUMER(chunk) \
if (counters) { \
if (threadIdx.x == 0 && blockIdx.x == 0) { \
......@@ -34,6 +48,7 @@
} \
if (blockIdx.x == 0) __syncthreads(); \
}
#endif
#define ATOMIC_PRODUCER(chunk) \
if (counters) { \
......@@ -1025,7 +1040,13 @@ __global__ void __launch_bounds__(MAX_THREADS)
// reset counter for next producer.
((unsigned int *)counters)[0] = 1;
#ifdef __HIP_PLATFORM_AMD__
__threadfence_system();
// __threadfence();
// __syncthreads()
#else
asm volatile("fence.sc.gpu;\n");
#endif
}
}
__syncthreads();
......@@ -1116,7 +1137,13 @@ __global__ void __launch_bounds__(MAX_THREADS)
// reset counter for next producer.
((unsigned int *)counters)[chunk_i] = 1;
#ifdef __HIP_PLATFORM_AMD__
__threadfence_system();
// __threadfence();
// __syncthreads()
#else
asm volatile("fence.sc.gpu;\n");
#endif
}
}
__syncthreads();
......@@ -1357,6 +1384,33 @@ __global__ void __launch_bounds__(MAX_THREADS)
}
} // fp16 inplace allgather kernel (Volta,Hopper)
#ifdef __HIP_PLATFORM_AMD__
#define SETUP_LAUNCH_CONFIG(sms, threads, stream) \
cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \
cudaLaunchAttribute attribute_ub[2]; \
attribute_ub[1].id = cudaLaunchAttributeClusterDimension; \
attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1; \
attribute_ub[1].val.clusterDim.y = 1; \
attribute_ub[1].val.clusterDim.z = 1; \
attribute_ub[0].id = cudaLaunchAttributeCooperative; \
cfg.attrs = attribute_ub; \
cfg.numAttrs = 1; // dtk unsupport hipLaunchAttributeClusterDimension
#define ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event)
#define NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH 2
#define SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, threads, stream, comm_launch_event) \
cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \
cudaLaunchAttribute attribute_ub[NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH] = {}; \
ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event) \
attribute_ub[1].id = cudaLaunchAttributeClusterDimension; \
attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1; \
attribute_ub[1].val.clusterDim.y = 1; \
attribute_ub[1].val.clusterDim.z = 1; \
attribute_ub[0].id = cudaLaunchAttributeCooperative; \
cfg.attrs = attribute_ub; \
cfg.numAttrs = 1; // dtk unsupport hipLaunchAttributeClusterDimension
#else
#define SETUP_LAUNCH_CONFIG(sms, threads, stream) \
cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \
cudaLaunchAttribute attribute_ub[2]; \
......@@ -1389,6 +1443,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
attribute_ub[0].id = cudaLaunchAttributeCooperative; \
cfg.attrs = attribute_ub; \
cfg.numAttrs = NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH;
#endif
#define callranks_ag(x) \
if (ar_nvsize == x) { \
......@@ -1932,6 +1987,53 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const
}
}
#ifdef __HIP_PLATFORM_AMD__
template void reducescatter2_userbuff_stridedoutput_fp8<hip_f8<hip_f8_type::bf8>>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event);
template void reducescatter2_userbuff_stridedoutput_fp8<hip_f8<hip_f8_type::fp8>>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event);
// Convenience wrapper over the strided-output FP8 reduce-scatter: treats the
// buffer as a single contiguous span (rowelements = elements, colelements = 1,
// strideelements = 0).
template <typename fp8type>
void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset,
const int elements, communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event) {
reducescatter2_userbuff_stridedoutput_fp8<fp8type>(output, scale, handler, offset, elements, 1, 0,
comm, stream, comm_launch_event);
}
template void reducescatter2_userbuff_fp8<hip_f8<hip_f8_type::bf8>>(void *output, float *scale,
const int handler, const int offset,
const int elements, communicator *comm,
cudaStream_t stream,
cudaEvent_t comm_launch_event);
template void reducescatter2_userbuff_fp8<hip_f8<hip_f8_type::fp8>>(void *output, float *scale,
const int handler, const int offset,
const int elements, communicator *comm,
cudaStream_t stream,
cudaEvent_t comm_launch_event);
template void reducescatter2_userbuff_strided_atomic_fp8<hip_f8<hip_f8_type::fp8>>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_strided_atomic_fp8<hip_f8<hip_f8_type::bf8>>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_strided_multiatomic_fp8<hip_f8<hip_f8_type::fp8>>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_strided_multiatomic_fp8<hip_f8<hip_f8_type::bf8>>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
#else
template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e5m2>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
......@@ -1977,6 +2079,7 @@ template void reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e5m2>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
#endif
__global__ void kuserbuffers_pullsend(int myrank, int peer, int *send_id, int *flagptr) {
atomicAdd_system(flagptr, 1);
......@@ -2196,7 +2299,13 @@ __global__ void __launch_bounds__(MAX_THREADS)
// Decrement atomic val to signal current output tile finish
if (counters) {
((unsigned int *)counters)[0] = 0;
#ifdef __HIP_PLATFORM_AMD__
__threadfence_system();
// __threadfence();
// __syncthreads()
#else
asm volatile("fence.sc.gpu;\n");
#endif
}
}
}
......@@ -2267,7 +2376,13 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat
// Decrement atomic val to signal current output tile finish
if (counters) {
((unsigned int *)counters)[recv_chunk_id /*chunk_i+1*/] = 0;
#ifdef __HIP_PLATFORM_AMD__
__threadfence_system();
// __threadfence();
// __syncthreads()
#else
asm volatile("fence.sc.gpu;\n");
#endif
}
}
......@@ -2545,7 +2660,13 @@ static __global__ void producer_kernel(void *atomic_ptr, int chunk_i) {
// COMM kernel need to explicitely flash gmem.
// GEMM kernel already executed, and can not see gmem
// change without COMM kernel explicitely make change
#ifdef __HIP_PLATFORM_AMD__
__threadfence_system();
// __threadfence();
// __syncthreads()
#else
asm volatile("fence.sc.gpu;\n");
#endif
}
// consumer
......@@ -2555,7 +2676,13 @@ static __global__ void consumer_kernel(void *atomic_ptr, int chunk_i) {
while (0 != (atomicCAS((unsigned int *)atomic_ptr + chunk_i, 0, 0))) {
}
((unsigned int *)atomic_ptr)[chunk_i] = 1;
#ifdef __HIP_PLATFORM_AMD__
__threadfence_system();
// __threadfence();
// __syncthreads()
#else
asm volatile("fence.sc.gpu;\n");
#endif
}
}
......@@ -2567,7 +2694,13 @@ static __global__ void consumer_batch_kernel(void *atomic_ptr, int first_chunk_i
while (0 != (atomicCAS((unsigned int *)atomic_ptr + i, 0, 0))) {
}
((unsigned int *)atomic_ptr)[i] = 1;
#ifdef __HIP_PLATFORM_AMD__
__threadfence_system();
// __threadfence();
// __syncthreads()
#else
asm volatile("fence.sc.gpu;\n");
#endif
}
}
}
......@@ -2661,12 +2794,21 @@ void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_in
num_aligned_elements_per_input, tot_input_size);
}
#ifdef __HIP_PLATFORM_AMD__
template void reduce_fp8_in_bf16_out<hip_f8<hip_f8_type::fp8>>(void *inputs, void *output, float *scale,
int num_inputs, int input_size,
cudaStream_t stream);
template void reduce_fp8_in_bf16_out<hip_f8<hip_f8_type::bf8>>(void *inputs, void *output, float *scale,
int num_inputs, int input_size,
cudaStream_t stream);
#else
template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(void *inputs, void *output, float *scale,
int num_inputs, int input_size,
cudaStream_t stream);
template void reduce_fp8_in_bf16_out<__nv_fp8_e5m2>(void *inputs, void *output, float *scale,
int num_inputs, int input_size,
cudaStream_t stream);
#endif
template <int nvec>
__global__ void __launch_bounds__(MAX_THREADS / 4)
......
......@@ -105,14 +105,22 @@ struct communicator {
int memflags[NVTE_MAX_REGIONS]; // UC,MC, user/lib allocated
#ifdef __HIP_PLATFORM_AMD__
hipMemGenericAllocationHandle_t *uchandles[NVTE_MAX_REGIONS];
#else
CUmemGenericAllocationHandle *uchandles[NVTE_MAX_REGIONS];
#endif
void *ucbase_ptr[NVTE_MAX_REGIONS]; // only for cuMem allocated memory
size_t mem_size[NVTE_MAX_REGIONS];
bool mem_dealloc[NVTE_MAX_REGIONS];
void *mc_ptr[NVTE_MAX_REGIONS];
void *mc_baseptr;
#ifdef __HIP_PLATFORM_AMD__
hipMemGenericAllocationHandle_t mc_handle;
#else
CUmemGenericAllocationHandle mc_handle;
#endif
size_t mc_offset, mc_maxsize;
int use_mc; // 1: use MC if available, 0: override not to use MC
......
......@@ -36,6 +36,9 @@ void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) {
}
void checkCuDriverContext(CUstream stream) {
#ifdef __HIP_PLATFORM_AMD__
return;
#else
CUcontext ctx;
const CUresult driver_status = cuda_driver::call("cuStreamGetCtx", stream, &ctx);
switch (driver_status) {
......@@ -54,8 +57,10 @@ void checkCuDriverContext(CUstream stream) {
cuda_driver::call("cuGetErrorString", driver_status, &desc_NVTE_CHECK_CUDA_DRIVER);
NVTE_ERROR("CUDA Error: ", desc_NVTE_CHECK_CUDA_DRIVER);
}
#endif
}
#ifndef __HIP_PLATFORM_AMD__
CUtensorMapDataType get_CUtensorMapDataType(DType dtype) {
static const std::unordered_map<DType, CUtensorMapDataType> dtypeMapping = {
{DType::kByte, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8},
......@@ -127,11 +132,16 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
// Any element that is outside of bounds will be set to zero by the TMA transfer.
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
}
#endif
// Whether the current device reports compute capability >= 100 (presumably
// major*10+minor from cuda::sm_arch — confirm against its definition).
// Always false on the ROCm/HIP build, which has no CUDA compute capabilities.
bool is_supported_by_CC_100() {
#ifdef __HIP_PLATFORM_AMD__
return false;
#else
int deviceComputeCapability = cuda::sm_arch(cuda::current_device());
return deviceComputeCapability >= 100;
#endif
}
} // namespace transformer_engine
......@@ -6,8 +6,9 @@
#ifndef TRANSFORMER_ENGINE_COMMON_COMMON_H_
#define TRANSFORMER_ENGINE_COMMON_COMMON_H_
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
......@@ -217,9 +218,15 @@ using int32 = int32_t;
using int64 = int64_t;
using fp32 = float;
using fp16 = half;
#ifndef __HIP_PLATFORM_AMD__
using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
#else
using bf16 = hip_bfloat16;
using fp8e4m3 = hip_f8<hip_f8_type::fp8>;
using fp8e5m2 = hip_f8<hip_f8_type::bf8>;
#endif
#if CUDA_VERSION >= 12080
using fp8e8m0 = __nv_fp8_e8m0;
#endif
......@@ -239,9 +246,15 @@ TRANSFORMER_ENGINE_TYPE_NAME(int32_t)
TRANSFORMER_ENGINE_TYPE_NAME(int64_t)
TRANSFORMER_ENGINE_TYPE_NAME(float)
TRANSFORMER_ENGINE_TYPE_NAME(half)
#ifdef __HIP_PLATFORM_AMD__
TRANSFORMER_ENGINE_TYPE_NAME(hip_bfloat16)
TRANSFORMER_ENGINE_TYPE_NAME(hip_f8<hip_f8_type::fp8>)
TRANSFORMER_ENGINE_TYPE_NAME(hip_f8<hip_f8_type::bf8>)
#else
TRANSFORMER_ENGINE_TYPE_NAME(nv_bfloat16)
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e4m3)
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2)
#endif
#if CUDA_VERSION >= 12080
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0)
#endif
......@@ -510,6 +523,7 @@ void update_tensor_scale_inv(Tensor *t, cudaStream_t stream);
void checkCuDriverContext(CUstream stream);
#ifndef __HIP_PLATFORM_AMD__
CUtensorMapDataType get_CUtensorMapDataType(DType dtype);
// Set up parameters to create TMA descriptor.
......@@ -517,6 +531,7 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
const uint32_t shmemX, const uint32_t stride_elems,
const uint32_t offset_elems, const size_t type_size);
#endif
bool is_supported_by_CC_100();
......
......@@ -11,6 +11,7 @@
namespace transformer_engine {
#ifndef USE_ROCM
// get cuDNN data type
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
......@@ -60,6 +61,7 @@ cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t)
void nvte_cudnn_handle_init() {
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
}
#endif
} // namespace transformer_engine
......
......@@ -7,9 +7,11 @@
#ifndef TRANSFORMER_ENGINE_CUDNN_UTILS_H_
#define TRANSFORMER_ENGINE_CUDNN_UTILS_H_
#ifndef __HIP_PLATFORM_AMD__
#include <cudnn.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>
#endif
#include <cstdint>
#include <mutex>
......@@ -18,6 +20,7 @@
namespace transformer_engine {
#ifndef __HIP_PLATFORM_AMD__
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t);
......@@ -40,6 +43,7 @@ class cudnnExecutionPlanManager {
private:
cudnnHandle_t handle_ = nullptr;
};
#endif
} // namespace transformer_engine
......
......@@ -4,9 +4,15 @@
* See LICENSE for license information.
************************************************************************/
#ifndef __HIP_PLATFORM_AMD__
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda.h>
#else
#include <iostream>
#include "hipblas_gemm.h"
#include "rocm_gemm.hip"
#endif // #ifndef __HIP_PLATFORM_AMD__
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>
......@@ -17,6 +23,7 @@
#include "../util/logging.h"
#include "common/util/cuda_runtime.h"
#ifndef __HIP_PLATFORM_AMD__
namespace {
cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
......@@ -137,9 +144,19 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
}
} // namespace
#endif // __HIP_PLATFORM_AMD__
namespace transformer_engine {
#ifdef __HIP_PLATFORM_AMD__
//Forward declaration. The implementation is in rocm_gemm.cu
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda,
int ldb, int ldd, bool transa, bool transb, bool grad,
void* workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
int math_sm_count, int m_split, int n_split, bool gemm_producer,
const Tensor *inputCounter, hipStream_t stream);
#else // Use cublasLt
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda,
int ldb, int ldd, cublasOperation_t transa, cublasOperation_t transb, bool grad,
......@@ -424,6 +441,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Adesc));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescDestroy(operationDesc));
}
#endif // __HIP_PLATFORM_AMD__
static std::once_flag init_flag;
static cudaStream_t compute_streams[num_streams];
......@@ -437,6 +455,19 @@ static void init_streams_and_events() {
}
}
// Add for batchgemm
static std::once_flag init_flag_batchgemm;
static cudaStream_t compute_streams_batchgemm[num_batchgemm_streams];
static cudaEvent_t cublas_event_batchgemm[num_batchgemm_streams];
// Warning: only call once per device!
// Creates the batch-GEMM compute streams (non-blocking, priority -1) and one
// completion event per stream.  Must run only once per device; callers guard
// it with std::call_once on init_flag_batchgemm.
static void init_streams_and_events_batchgemm() {
for (int i = 0; i < num_batchgemm_streams; i++) {
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams_batchgemm[i], cudaStreamNonBlocking, -1));
NVTE_CHECK_CUDA(cudaEventCreate(&cublas_event_batchgemm[i]));
}
}
} // namespace transformer_engine
void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
......@@ -477,10 +508,57 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
NVTE_ERROR("TT layout not allowed.");
}
#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT
const char *NVTE_BLASLT_BLAS = std::getenv("NVTE_FORCE_BLASLT");
const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
is_fp8_dtype(inputB->data.dtype);
if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr!=nullptr) || (use_fp8) || (NVTE_BLASLT_BLAS != nullptr && NVTE_BLASLT_BLAS[0] == '1')){
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#else
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#endif //USE_HIPBLASLT
#else
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
(transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad,
#endif //__HIP_PLATFORM_AMD__
#ifdef __HIP_PLATFORM_AMD__
transa, transb,
#else
(transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N,
#endif //__HIP_PLATFORM_AMD__
grad,
wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
math_sm_count, 0, 0, false, nullptr, stream);
#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT
}
else{
hipblas_gemm(inputA,
inputB,
outputD,
biasTensor,
outputGelu,
m, n, k,
lda, ldb, ldd,
(transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
(transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
grad, wspace->data.dptr,
wspace->data.shape[0],
accumulate, use_split_accumulator,
math_sm_count,
0,
0,
false,
nullptr,
stream);
}
#endif //USE_HIPBLASLT
#endif //__HIP_PLATFORM_AMD__
}
void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
......@@ -491,10 +569,12 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
cudaStream_t stream) {
NVTE_API_CALL(nvte_cublas_atomic_gemm);
#ifndef __HIP_PLATFORM_AMD__
int cudart_version;
NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&cudart_version));
NVTE_CHECK(cudart_version >= 12020, "Cuda version 12.2 is required for atomic gemm.");
NVTE_CHECK(cublasLtGetVersion() >= 120205, "Cublas version 12.2.5 is required for atomic gemm.");
#endif
using namespace transformer_engine;
const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
......@@ -529,10 +609,103 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
NVTE_ERROR("TT layout not allowed.");
}
#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT
const char *NVTE_BLASLT_BLAS = std::getenv("NVTE_FORCE_BLASLT");
const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
is_fp8_dtype(inputB->data.dtype);
if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr!=nullptr) || (use_fp8) || (NVTE_BLASLT_BLAS != nullptr && NVTE_BLASLT_BLAS[0] == '1')){
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#else
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#endif //USE_HIPBLASLT
#else
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
(transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad,
#endif //__HIP_PLATFORM_AMD__
#ifdef __HIP_PLATFORM_AMD__
transa, transb,
#else
(transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N,
#endif //__HIP_PLATFORM_AMD__
grad,
wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream);
#ifdef __HIP_PLATFORM_AMD__
#ifdef USE_HIPBLASLT
}
else{
hipblas_gemm(
inputA,
inputB,
outputD,
biasTensor,
outputGelu,
m, n, k,
lda, ldb, ldd,
(transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
(transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
grad, wspace->data.dptr,
wspace->data.shape[0],
accumulate, use_split_accumulator,
math_sm_count,
m_split,
n_split,
gemm_producer,
inputCounter,
stream);
}
#endif //USE_HIPBLASLT
#endif //__HIP_PLATFORM_AMD__
}
// GEMM entry point that always dispatches to the cublasLt-style backend
// (cublas_gemm), bypassing the hipBLAS fallback selection that
// nvte_cublas_gemm performs on ROCm builds.  Size/leading-dimension
// derivation mirrors nvte_cublas_gemm; the TT layout is rejected.
void nvte_cublaslt_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, cudaStream_t stream) {
NVTE_API_CALL(nvte_cublaslt_gemm);
using namespace transformer_engine;
// Reinterpret the opaque NVTETensor handles as internal Tensor structs.
const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
Tensor *outputD = reinterpret_cast<Tensor *>(D);
const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
Tensor *wspace = reinterpret_cast<Tensor *>(workspace);
// Logical GEMM sizes; shape[0]/shape[1] swap roles under transposition.
const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1];
const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0];
const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0];
int lda, ldb, ldd;
if (transa && !transb) { // TN
lda = k;
ldb = k;
ldd = m;
} else if (!transa && !transb) { // NN
lda = m;
ldb = k;
ldd = m;
} else if (!transa && transb) { // NT
lda = m;
ldb = n;
ldd = m;
} else { // TT
NVTE_ERROR("TT layout not allowed.");
}
// ROCm's cublas_gemm takes raw bools; the CUDA path takes cublasOperation_t.
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
#ifdef __HIP_PLATFORM_AMD__
transa, transb,
#else
(transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N,
#endif //__HIP_PLATFORM_AMD__
grad,
wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
math_sm_count, 0, 0, false, nullptr, stream);
}
void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
......@@ -552,12 +725,30 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams[s], cublas_event[0]));
}
const char *NVTE_BLAS_MULSTREAM = std::getenv("NVTE_FORCE_BLAS_MULSTREAM");
const char *NVTE_BLASLT_BLAS = std::getenv("NVTE_FORCE_BLASLT");
bool NVTE_FORCE_BLASLT_MULSTREAM;
if(NVTE_BLAS_MULSTREAM==nullptr){
NVTE_FORCE_BLASLT_MULSTREAM = true;
} elif((NVTE_BLASLT_BLAS != nullptr && NVTE_BLASLT_BLAS[0] == '1') && (NVTE_BLAS_MULSTREAM != nullptr && NVTE_BLAS_MULSTREAM[0] == '1')){
NVTE_ERROR("NVTE_FORCE_BLAS_MULSTREAM and NVTE_FORCE_BLASLT can't be set at the same time.");
} else{
NVTE_FORCE_BLASLT_MULSTREAM = false;
}
if (NVTE_FORCE_BLASLT_MULSTREAM){
for (int i = 0; i < num_gemms; i++) {
nvte_cublaslt_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
compute_streams[i % num_streams]);
}
} else{
for (int i = 0; i < num_gemms; i++) {
nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
compute_streams[i % num_streams]);
}
}
// record events on compute streams
for (int s = 0; s < num_stream_used; s++) {
......@@ -568,3 +759,107 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event[s]));
}
}
#ifdef __HIP_PLATFORM_AMD__
void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
const NVTETensor *bias, NVTETensor *pre_gelu_out,
const int num_gemms, bool transa, bool transb, bool grad,
NVTETensor *workspace, bool accumulate,
bool use_split_accumulator, int math_sm_count,
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_stream_cublas_batchgemm);
using namespace transformer_engine;
static_assert(num_gemms % num_batchgemm_streams == 0,
"Need num_gemms mod num_batchgemm_streams == 0.");
static int batch_count = num_gemms / num_batchgemm_streams;
// Inits streams and events (once, globally)
std::call_once(init_flag_batchgemm, init_streams_and_events_batchgemm);
int num_stream_used = num_batchgemm_streams;
// wait for current stream to finish
NVTE_CHECK_CUDA(cudaEventRecord(cublas_event_batchgemm[0], stream));
for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams_batchgemm[s], cublas_event_batchgemm[0]));
}
for (int i = 0; i < num_stream_used; i++) {
nvte_cublas_batchgemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
workspace[i % num_batchgemm_streams], accumulate, use_split_accumulator, math_sm_count,
batch_count, compute_streams_batchgemm[i % num_batchgemm_streams]);
}
// record events on compute streams
for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(cudaEventRecord(cublas_event_batchgemm[s], compute_streams_batchgemm[s]));
}
// wait for all compute streams to finish
for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event_batchgemm[s]));
}
}
// Strided-batched GEMM entry point (ROCm path). Each operand holds
// batch_count equally sized sub-matrices packed along the outer dimension.
void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                           NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                           NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                           int math_sm_count, int batch_count, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_batchgemm);
  using namespace transformer_engine;
  const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
  const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
  Tensor *outputD = reinterpret_cast<Tensor *>(D);
  const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
  Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
  Tensor *wspace = reinterpret_cast<Tensor *>(workspace);
  // BUG FIX: reject TT up front — the original left m/n/k uninitialized for
  // this layout before erroring out in the leading-dimension branch.
  if (transa && transb) {
    NVTE_ERROR("TT layout not allowed.");
  }
  // The original NN/NT/TN branches computed identical expressions; fold them
  // into one. The batched dimension is divided out of the outer shape.
  const int m = transa ? inputA->data.shape[0] / batch_count : inputA->data.shape[1];
  const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0] / batch_count;
  const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0] / batch_count;
  // Leading dimensions for the supported layouts (TN / NN / NT).
  const int lda = transa ? k : m;
  const int ldb = transb ? n : k;
  const int ldd = m;
  hipblas_batchgemm(inputA,
                    inputB,
                    outputD,
                    biasTensor,
                    outputGelu,
                    m, n, k,
                    lda, ldb, ldd,
                    (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                    (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                    grad, wspace->data.dptr,
                    wspace->data.shape[0],
                    accumulate, use_split_accumulator,
                    math_sm_count,
                    0,
                    0,
                    false,
                    nullptr,
                    batch_count,
                    stream);
}
#endif
\ No newline at end of file
/*************************************************************************
* Copyright (c) 2022-2024, S3000 qianyj. All rights reserved.
************************************************************************/
#include <hip/hip_runtime.h>
#include "hipblas_gemm.h"
#include "../common_hip.h"
#include "../util/logging.h"
namespace {
// Translate a TE dtype into the corresponding hipBLAS datatype enum.
// Only the three GEMM-capable floating-point types are supported; anything
// else is a hard error.
hipblasDatatype_t get_hip_dtype(const transformer_engine::DType t) {
  using transformer_engine::DType;
  if (t == DType::kFloat16) {
    return HIPBLAS_R_16F;
  }
  if (t == DType::kFloat32) {
    return HIPBLAS_R_32F;
  }
  if (t == DType::kBFloat16) {
    return HIPBLAS_R_16B;
  }
  NVTE_ERROR("Invalid type");
}
}  // namespace
// Define a static handle manager
static HipblasHandleManager handleManager;
namespace transformer_engine {
void hipblas_gemm(const Tensor *inputA,
const Tensor *inputB,
Tensor *outputD,
const Tensor *inputBias,
Tensor *outputPreGelu,
int m, int n, int k,
int lda, int ldb, int ldd,
hipblasOperation_t transa,
hipblasOperation_t transb,
bool grad,
void* workspace,
size_t workspaceSize,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
const Tensor *inputCounter,
hipStream_t stream) {
// Use static handles
int device_id;
hipGetDevice(&device_id);
hipblasHandle_t handle = handleManager.get(device_id);
void *A = inputA->data.dptr;
// void *A_scale_inverse = inputA->scale_inv.dptr;
void *B = inputB->data.dptr;
// void *B_scale_inverse = inputB->scale_inv.dptr;
void *C = outputD->data.dptr;
void *D = outputD->data.dptr;
// Select the calculation accuracy
hipblasDatatype_t A_type = get_hip_dtype(inputA->data.dtype);
hipblasDatatype_t B_type = get_hip_dtype(inputB->data.dtype);
hipblasDatatype_t D_type = get_hip_dtype(outputD->data.dtype);
hipblasDatatype_t computeType = HIPBLAS_R_32F; // default acc is float32
// setting computetype
// if (/* condition for mixed precision */) {
// computeType = HIPBLAS_R_16F; //
// }
// hipblasComputeType_t gemm_compute_type = HIPBLAS_COMPUTE_32F;
// const char *env_tf32 = std::getenv("NVTE_BLASLT_TF32");
// if (env_tf32 != nullptr && env_tf32[0] == '1') {
// if (A_type == HIPBLAS_R_32F && B_type == HIPBLAS_R_32F && D_type == HIPBLAS_R_32F) {
// gemm_compute_type = HIPBLAS_COMPUTE_32F_FAST_TF32;
// }
float one = 1.0f;
float zero = 0.0f;
float beta = accumulate ? one : zero;
hipblasSetStream(handle, stream);
// execute multiply
hipblasStatus_t status = hipblasGemmEx(
handle,
transa, // transa
transb, // transb
m,
n,
k,
static_cast<const void*>(&one),
A,
A_type,
lda,
B,
B_type,
ldb,
static_cast<const void*>(&beta),
D,
D_type,
ldd,
computeType,
HIPBLAS_GEMM_DEFAULT);
if (status != HIPBLAS_STATUS_SUCCESS) {
NVTE_ERROR("hipblasGemmEx execution failed");
}
}
void hipblas_batchgemm(const Tensor *inputA,
const Tensor *inputB,
Tensor *outputD,
const Tensor *inputBias,
Tensor *outputPreGelu,
int m, int n, int k,
int lda, int ldb, int ldd,
hipblasOperation_t transa,
hipblasOperation_t transb,
bool grad,
void* workspace,
size_t workspaceSize,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
const Tensor *inputCounter,
int batch_count,
hipStream_t stream) {
// Use static handles
int device_id;
hipGetDevice(&device_id);
hipblasHandle_t handle = handleManager.get(device_id);
void *A = inputA->data.dptr;
// void *A_scale_inverse = inputA->scale_inv.dptr;
void *B = inputB->data.dptr;
// void *B_scale_inverse = inputB->scale_inv.dptr;
void *C = outputD->data.dptr;
void *D = outputD->data.dptr;
// Select the calculation accuracy
hipblasDatatype_t A_type = get_hip_dtype(inputA->data.dtype);
hipblasDatatype_t B_type = get_hip_dtype(inputB->data.dtype);
hipblasDatatype_t D_type = get_hip_dtype(outputD->data.dtype);
hipblasDatatype_t computeType = HIPBLAS_R_32F; // default acc is float32
float one = 1.0f;
float zero = 0.0f;
float beta = accumulate ? one : zero;
hipblasSetStream(handle, stream);
// execute multiply
// calculate stride
const long long int strideA = m*k;
const long long int strideB = k*n;
const long long int strideD = m*n;
hipblasStatus_t status = hipblasGemmStridedBatchedEx(
handle,
transa, // transa
transb, // transb
m,
n,
k,
static_cast<const void*>(&one),
A,
A_type,
lda,
strideA,
B,
B_type,
ldb,
strideB,
static_cast<const void*>(&beta),
D,
D_type,
ldd,
strideD,
batch_count,
computeType,
HIPBLAS_GEMM_DEFAULT);
if (status != HIPBLAS_STATUS_SUCCESS) {
NVTE_ERROR("hipblasGemmEx execution failed");
}
}
} // namespace transformer_engine
\ No newline at end of file
/*************************************************************************
* Copyright (c) 2022-2024, S3000 qianyj. All rights reserved.
************************************************************************/
/*! \file hipblas_gemmn.h
* \brief Functions for blas instead blaslt in pure gemm
*/
#ifndef TRANSFORMER_ENGINE_COMMON_HIPBLAS_GEMM_H_
#define TRANSFORMER_ENGINE_COMMON_HIPBLAS_GEMM_H_
#include <hip/hip_runtime.h>

#ifdef USE_HIPBLASLT
#include <hipblas/hipblas.h>
#else
#include <rocblas/rocblas.h>
#endif

#include <iostream>
#include <mutex>
#include <stdexcept>
#include <unordered_map>

#include "../common_hip.h"
#ifdef USE_HIPBLASLT
// Caches one hipBLAS handle per device; handle creation and lookup are
// serialized by an internal mutex.
class HipblasHandleManager {
 public:
  HipblasHandleManager() {}
  ~HipblasHandleManager() {
    // Release every cached handle (one per device) when the manager dies.
    for (auto& device_pair : handles_map_) {
      hipblasDestroy(device_pair.second);
    }
  }
  // Return the handle for `device_id`, creating it on first use.
  // Throws std::runtime_error if hipblasCreate fails.
  hipblasHandle_t get(int device_id) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto device_it = handles_map_.find(device_id);
    if (device_it != handles_map_.end()) {
      return device_it->second;
    }
    hipblasHandle_t handle;
    hipblasStatus_t status = hipblasCreate(&handle);
    if (status != HIPBLAS_STATUS_SUCCESS) {
      throw std::runtime_error("Failed to create HIPBLAS handle");
    }
    handles_map_[device_id] = handle;
    return handle;
  }

 private:
  // BUG FIX: copying was allowed, which would double-destroy the cached
  // handles; delete copy/assignment (the rocBLAS variant already does).
  HipblasHandleManager(const HipblasHandleManager&) = delete;
  HipblasHandleManager& operator=(const HipblasHandleManager&) = delete;

  std::unordered_map<int, hipblasHandle_t> handles_map_;  // device_id -> handle
  std::mutex mutex_;
};
namespace transformer_engine {
// Single GEMM via plain hipBLAS (hipblasGemmEx): D = op(A) * op(B) [+ D when
// accumulate]. The bias / pre-GELU / workspace / split parameters mirror the
// common TE GEMM interface; see the implementation for which are honored.
void hipblas_gemm(const Tensor *inputA,
                  const Tensor *inputB,
                  Tensor *outputD,
                  const Tensor *inputBias,
                  Tensor *outputPreGelu,
                  int m, int n, int k,
                  int lda, int ldb, int ldd,
                  hipblasOperation_t transa,
                  hipblasOperation_t transb,
                  bool grad,
                  void* workspace,
                  size_t workspaceSize,
                  bool accumulate,
                  bool use_split_accumulator,
                  int math_sm_count,
                  int m_split,
                  int n_split,
                  bool gemm_producer,
                  const Tensor *inputCounter,
                  hipStream_t stream);
// Strided-batched GEMM via hipblasGemmStridedBatchedEx: batch_count packed
// sub-matrices per operand, all sharing the same m/n/k and leading dims.
void hipblas_batchgemm(const Tensor *inputA,
                       const Tensor *inputB,
                       Tensor *outputD,
                       const Tensor *inputBias,
                       Tensor *outputPreGelu,
                       int m, int n, int k,
                       int lda, int ldb, int ldd,
                       hipblasOperation_t transa,
                       hipblasOperation_t transb,
                       bool grad,
                       void* workspace,
                       size_t workspaceSize,
                       bool accumulate,
                       bool use_split_accumulator,
                       int math_sm_count,
                       int m_split,
                       int n_split,
                       bool gemm_producer,
                       const Tensor *inputCounter,
                       int batch_count,
                       hipStream_t stream);
}
#else
// Lazily creates and caches a single rocBLAS handle (rocBLAS fallback path).
class HipblasHandleManager {
 public:
  HipblasHandleManager() : handle_(nullptr) {}
  ~HipblasHandleManager() {
    // Release the handle in the destructor to ensure cleanup.
    if (handle_ != nullptr) {
      rocblas_destroy_handle(handle_);
    }
  }
  // Return the cached handle, creating it on first use.
  // BUG FIX: guarded by a mutex so concurrent first calls cannot race on
  // handle_ creation (the hipBLASLt variant of this class already locks).
  // The redundant assert is gone: createHandle() throws on failure, so
  // handle_ is non-null here.
  rocblas_handle get() {
    std::lock_guard<std::mutex> lock(mutex_);
    if (handle_ == nullptr) {
      createHandle();
    }
    return handle_;
  }

 private:
  rocblas_handle handle_;
  std::mutex mutex_;

  // Create the handle; throws std::runtime_error on failure.
  void createHandle() {
    rocblas_status status = rocblas_create_handle(&handle_);
    if (status != rocblas_status_success) {
      // BUG FIX: the original message said "HIPBLAS" for a rocBLAS handle.
      throw std::runtime_error("Failed to create rocBLAS handle");
    }
  }

  // Copy construction and assignment are prohibited.
  HipblasHandleManager(const HipblasHandleManager&) = delete;
  HipblasHandleManager& operator=(const HipblasHandleManager&) = delete;
};
#endif // #ifdef USE_HIPBLASLT
#endif // TRANSFORMER_ENGINE_COMMON_HIPBLAS_GEMM_H_
\ No newline at end of file
This diff is collapsed.
......@@ -109,6 +109,21 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor* A, const NVTETensor* B, NVT
NVTETensor* workspace, bool accumulate,
bool use_split_accumulator, int math_sm_count,
cudaStream_t stream);
#ifdef __HIP_PLATFORM_AMD__
/*! \brief Launch num_gemms strided-batched GEMMs, distributed over the
 *         internal batch-GEMM streams (ROCm only). A, B, D, bias,
 *         pre_gelu_out and workspace are arrays of num_gemms tensors.
 */
void nvte_multi_stream_cublas_batchgemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D,
                                        const NVTETensor* bias, NVTETensor* pre_gelu_out,
                                        const int num_gemms, bool transa, bool transb, bool grad,
                                        NVTETensor* workspace, bool accumulate,
                                        bool use_split_accumulator, int math_sm_count,
                                        cudaStream_t stream);
/*! \brief Single strided-batched GEMM over batch_count packed sub-matrices
 *         (ROCm only). Mirrors nvte_cublas_gemm plus the batch_count arg.
 */
void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                           NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                           NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                           int math_sm_count, int batch_count, cudaStream_t stream);
#endif
#ifdef __cplusplus
} // extern "C"
#endif
......@@ -116,8 +131,14 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor* A, const NVTETensor* B, NVT
/*! \namespace transformer_engine
*/
namespace transformer_engine {
#ifdef __HIP_PLATFORM_AMD__
// On DCU hardware, using 2 compute streams performs better than 4.
constexpr int num_streams = 2;
// Stream count for the strided-batched GEMM path.
constexpr int num_batchgemm_streams = 1;
#else
constexpr int num_streams = 4;
#endif
} // namespace transformer_engine
......
......@@ -39,10 +39,12 @@ Compute always in FP32
namespace transformer_engine {
namespace normalization {
#ifndef __HIP_PLATFORM_AMD__
// Map the training flag onto the cuDNN frontend forward-phase enum.
cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) {
  if (training) {
    return cudnn_frontend::NormFwdPhase_t::TRAINING;
  }
  return cudnn_frontend::NormFwdPhase_t::INFERENCE;
}
#endif
TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType,
NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype,
......@@ -195,6 +197,10 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
_training(training),
_norm_stage(NormStage),
_norm_type(NormType) {
#ifdef USE_ROCM
static_assert(false,
"Cudnn backend is not surpported in rocm for normalization yet.");
#else
static_assert(CUDNN_FRONTEND_VERSION >= 10601,
"CUDNN_FRONTEND_VERSION should be at least 1.6.1!");
......@@ -378,9 +384,14 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
}
// Build the graph
this->_build();
#endif
}
void CudnnNormalizationPlan::_build() {
#ifdef USE_ROCM
static_assert(false,
"Cudnn backend is not surpported in rocm for normalization yet.");
#else
NVTE_CHECK(_graph.validate().is_good());
NVTE_CHECK(_graph.build_operation_graph(_handle).is_good());
NVTE_CHECK(_graph
......@@ -390,15 +401,25 @@ void CudnnNormalizationPlan::_build() {
NVTE_CHECK(_graph.check_support(_handle).is_good());
NVTE_CHECK(
_graph.build_plans(_handle, cudnn_frontend::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good());
#endif
}
// Workspace shape required by the built cuDNN graph.
std::vector<size_t> CudnnNormalizationPlan::getWorkspaceShape() const {
#ifdef USE_ROCM
  // BUG FIX: static_assert(false, ...) is evaluated unconditionally and
  // breaks EVERY ROCm compile of this file, even when the cuDNN plan is
  // never instantiated. Fail at runtime instead.
  NVTE_ERROR("Cudnn backend is not supported in rocm for normalization yet.");
  return {};  // unreachable; keeps the compiler's return-path check happy
#else
  return {static_cast<size_t>(_graph.get_workspace_size())};
#endif
}
void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr,
void* mean_dptr, void* eps_dptr, void* rsigma_dptr,
void* workspace_dptr, cudaStream_t stream) {
#ifdef USE_ROCM
static_assert(false,
"Cudnn backend is not surpported in rocm for normalization yet.");
#else
// Binding data pointers to graph tensors
_variant_pack = {{_x, x_dptr}, {_eps, eps_dptr}};
......@@ -433,12 +454,17 @@ void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr,
// Execute the computation
NVTE_CHECK_CUDNN(cudnnSetStream(_handle, stream));
NVTE_CHECK(_graph.execute(_handle, _variant_pack, workspace_dptr).is_good());
#endif
}
void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_dptr,
void* rsigma_dptr, void* dx_dptr, void* dz_dptr,
void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr,
cudaStream_t stream) {
#ifdef USE_ROCM
static_assert(false,
"Cudnn backend is not surpported in rocm for normalization yet.");
#else
// Binding data pointers to graph tensors
_variant_pack = {
{_x, x_dptr}, {_rsigma, rsigma_dptr}, {_dz, dz_dptr}, {_dgamma, dgamma_dptr}, {_dx, dx_dptr}};
......@@ -455,6 +481,7 @@ void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_
// Execute the computation
NVTE_CHECK_CUDNN(cudnnSetStream(_handle, stream));
NVTE_CHECK(_graph.execute(_handle, _variant_pack, workspace_dptr).is_good());
#endif
}
NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan(
......@@ -491,13 +518,21 @@ NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan(
}
// Mutable process-wide flag: use cuDNN for forward normalization?
bool& _cudnn_norm_fwd_flag() {
#ifdef USE_ROCM
  // BUG FIX: `return false;` from a function returning bool& does not
  // compile (reference to an rvalue). Keep a static that stays false.
  static bool flag = false;
  return flag;
#else
  static bool flag = transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN");
  return flag;
#endif
}
// Mutable process-wide flag: use cuDNN for backward normalization?
bool& _cudnn_norm_bwd_flag() {
#ifdef USE_ROCM
  // BUG FIX: `return false;` from a function returning bool& does not
  // compile (reference to an rvalue). Keep a static that stays false.
  static bool flag = false;
  return flag;
#else
  static bool flag = transformer_engine::getenv<bool>("NVTE_NORM_BWD_USE_CUDNN");
  return flag;
#endif
}
// Query whether the cuDNN forward-normalization path is currently enabled.
bool use_cudnn_norm_fwd() { return _cudnn_norm_fwd_flag(); }
......@@ -508,10 +543,18 @@ bool use_cudnn_norm_bwd() { return _cudnn_norm_bwd_flag(); }
// Enable/disable the cuDNN forward-normalization path.
void nvte_enable_cudnn_norm_fwd(bool enable) {
  NVTE_API_CALL(nvte_enable_cudnn_norm_fwd);
#ifdef USE_ROCM
  // BUG FIX: this function must control the *forward* flag; the original
  // cleared the backward flag here (copy-paste error). cuDNN norm is
  // unsupported on ROCm, so the flag is forced off regardless of `enable`.
  transformer_engine::normalization::_cudnn_norm_fwd_flag() = false;
#else
  transformer_engine::normalization::_cudnn_norm_fwd_flag() = enable;
#endif
}
// Enable/disable the cuDNN backward-normalization path. On ROCm the flag is
// forced off regardless of `enable`, since the cuDNN backend is unsupported.
void nvte_enable_cudnn_norm_bwd(bool enable) {
  NVTE_API_CALL(nvte_enable_cudnn_norm_bwd);
#ifdef USE_ROCM
  transformer_engine::normalization::_cudnn_norm_bwd_flag() = false;
#else
  transformer_engine::normalization::_cudnn_norm_bwd_flag() = enable;
#endif
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment